WIP chunked KD loss with autograd wrapper

Wing Lian
2025-05-21 12:24:46 -07:00
parent ca70fb7cb0
commit 5cfaac3767
5 changed files with 740 additions and 41 deletions

View File

@@ -34,3 +34,8 @@ class KDPlugin(BasePlugin):
return AxolotlKDTrainer
return None
def pre_model_load(self, cfg):
from .kernels.liger import apply_kernel_to_qwen2
apply_kernel_to_qwen2()

View File

@@ -183,13 +183,124 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
return tokenized_prompt
class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
"""
Strategy for datasets that ship complete, structured KD logprob data
"""
def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
# pylint: disable=duplicate-code
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
# determine the usable top-k width (prune None/empty logprobs left over from the vllm data step)
top_k_vals = [
len(logprobs[i])
for i in range(len(logprobs))
if logprobs[i] is not None and len(logprobs[i])
]
if not top_k_vals:
raise ValueError("No non-zero top-k logprobs found.")
# truncate every position to the smallest observed top-k so all rows share one width
max_top_k = max(top_k_vals)
min_top_k = min(top_k_vals)
top_k = min(max_top_k, min_top_k)
target_logprobs = []
target_token_ids = []
target_mask = []
if input_padding_len < 0:
# logprobs is longer than input_ids, so drop the excess from the left/beginning of
# logprobs and keep the last input_seq_len entries
# NOTE: sample["target_token_ids"] may need the same left-truncation to stay aligned
logprobs = logprobs[-input_seq_len:]
input_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
logprobs = [row[:top_k] for row in logprobs]
# left-pad the teacher targets: for each of the input_padding_len padded positions, append a
# row of -inf logprobs, placeholder token ids, and a zero mask
# (the causal shift for the student happens later in the trainer, so start the range from 0)
for _ in range(0, input_padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for token_pos_logprobs, pos_target_token_ids in zip(
logprobs, sample["target_token_ids"]
):
# Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor(
token_pos_logprobs, dtype=torch.float
)
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# normalize probabilities to sum to 1 in case they aren't already
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
if teacher_probs_t1_sum > 1e-9:
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
else:
teacher_probs_t2 = teacher_probs_t1
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist()
target_logprobs.append(position_logprobs_scaled)
# keep token ids aligned with the truncated top_k width of the logprobs
target_token_ids.append(pos_target_token_ids[:top_k])
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
sample["target_mask"] = target_mask
return sample
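# A minimal sketch of the temperature re-scaling above, assuming a toy, already normalized
# top-k distribution; T1 / T2 stand in for self.gen_temperature and self.kd_temperature:
#
#     import torch
#     T1, T2 = 1.0, 2.0
#     logprobs_t1 = torch.tensor([-0.36, -1.61, -2.30])  # log p_T1(k)
#     probs_t2 = logprobs_t1.exp() ** (T1 / T2)          # p_T1(k)^(T1/T2)
#     probs_t2 = probs_t2 / probs_t2.sum()               # re-normalize (the Z above)
#     logprobs_t2 = probs_t2.log()                       # stored per position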
def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
target_token_ids = prompt.pop("target_token_ids")
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt["target_token_ids"] = target_token_ids
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
class KDStrategyLoader(StrategyLoader):
"""
Load ChatTemplateStrategy with KD support using StrategyLoader.
"""
def _get_strategy_cls(self):
- return ChatTemplateStrategyWithKD
+ return ChatTemplateStrategyWithKDv2
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
strategy_params = super()._get_strategy_params(cfg, ds_cfg)

View File

@@ -56,6 +56,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
return_tensors = self.return_tensors
padding_side = self.tokenizer.padding_side
max_len = 0
# Pad labels and position_ids first
for feature_name, pad_token_id in [
@@ -106,7 +107,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
target_mask_list.append(f.pop("target_mask"))
# Determine max lengths
- max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
+ max_teacher_seq_len = max_len or max(
+ len(seq) for seq in target_logprobs_list
+ )
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
padded_target_logprobs = []
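# Toy illustration of the two maxima above on a ragged batch (values illustrative);
# max_len, set while padding labels/position_ids earlier, takes precedence when non-zero:
#
#     target_logprobs_list = [
#         [[-0.1, -2.3], [-0.5, -1.2, -3.0]],  # 2 positions, k = 2 and 3
#         [[-0.7, -0.9, -1.1]],                # 1 position, k = 3
#     ]
#     max_len = 0
#     max_teacher_seq_len = max_len or max(len(seq) for seq in target_logprobs_list)  # -> 2
#     max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)        # -> 3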

View File

@@ -0,0 +1,576 @@
"""
Liger Kernels for Chunked Top-K Log-Prob Distillation
"""
from typing import Optional, Union, Unpack
import torch
import torch.nn.functional as F
# Assuming LigerFusedLinearDistillationBase is in this path and can be imported
# If not, its structure would need to be replicated or specific utilities copied.
from liger_kernel.chunked_loss.fused_linear_distillation import (
LigerFusedLinearDistillationBase,
)
from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen2.modeling_qwen2 import KwargsForCausalLM
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
"""
Chunked fused linear + top-k KL distillation: logits are computed per chunk from the
hidden states and the lm_head weight, and gradients are computed eagerly in forward.
"""
@staticmethod
def distillation_loss_fn(
student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
temperature: float,
) -> torch.Tensor:
"""
Compute Top-K KL divergence loss for a chunk.
Args:
student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
temperature: Temperature used for scaling.
Returns:
Sum of KL divergence losses for the chunk.
"""
student_logits_temp_scaled = student_logits_temp_scaled.float()
target_logprobs_chunk = target_logprobs_chunk.float()
# Gather student logits for the top-k teacher token IDs
# student_logits_temp_scaled: [chunk_size, vocab_size]
# target_token_ids_chunk: [chunk_size, top_k]
# student_logits_topk_temp_scaled: [chunk_size, top_k]
student_logits_topk_temp_scaled = torch.gather(
student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk
)
# Student log-probabilities for the gathered top-k tokens
student_lse = torch.logsumexp(
student_logits_temp_scaled, dim=-1, keepdim=True
) # [chunk_size, 1]
student_logprobs_topk_temp_scaled = (
student_logits_topk_temp_scaled - student_lse
)
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
target_logprobs_valid = target_logprobs_chunk[valid_mask]
# Teacher probabilities P(y|x_teacher) from logprobs
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
teacher_probs_valid = target_logprobs_valid.exp()
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
# The distillation loss is often formulated as -sum(P_teacher * log P_student)
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
# Here, target_logprobs_valid are log_softmax_teacher.
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
kd_loss_per_token = teacher_probs_valid * (
target_logprobs_valid - student_logprobs_topk_valid
)
kd_loss = kd_loss_per_token.sum()
if temperature != 1.0:
kd_loss = kd_loss * (temperature**2)
return kd_loss
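# Hedged sanity check for the top-k KL above: when top_k covers the whole vocabulary,
# the truncated sum reduces to a full forward KL (names below are illustrative only):
#
#     import torch
#     import torch.nn.functional as F
#     student_logits = torch.tensor([[2.0, 1.0, 0.5]])   # [1, V], temperature already applied
#     teacher_logprobs = torch.log_softmax(torch.tensor([[1.5, 1.2, 0.3]]), dim=-1)
#     token_ids = torch.tensor([[0, 1, 2]])               # "top-k" == vocab here
#     student_logprobs = student_logits - torch.logsumexp(student_logits, dim=-1, keepdim=True)
#     gathered = torch.gather(student_logprobs, -1, token_ids)
#     kl_topk = (teacher_logprobs.exp() * (teacher_logprobs - gathered)).sum()
#     kl_full = F.kl_div(student_logprobs, teacher_logprobs, log_target=True, reduction="sum")
#     assert torch.allclose(kl_topk, kl_full)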
@staticmethod
def _compute_loss_kl_topk(
student_input_chunk: torch.Tensor,
student_weight: torch.Tensor,
# student_bias, the target_* chunks, and the remaining options could instead be bound
# through the closure passed to grad_and_value (or via `partial`); they are kept explicit here for clarity.
target_token_ids_chunk: torch.Tensor,
target_logprobs_chunk: torch.Tensor,
target_mask_chunk: torch.Tensor,
target_chunk: torch.Tensor, # For hard loss (true labels)
student_bias: torch.Tensor = None, # This will be one of the grad targets
# Other params passed via `partial` from `forward`
distillation_loss_fn=None,
ignore_index: int = -100,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
compute_ce_loss: bool = True,
temperature: float = 1.0,
):
# Compute student logits for the chunk from hidden states and LM head
# student_input_chunk: [chunk_size, hidden_dim]
# student_lm_head_weight: [vocab_size, hidden_dim]
# student_logits_chunk: [chunk_size, vocab_size]
student_logits_chunk = F.linear(
student_input_chunk, student_weight, student_bias
)
ce_loss = torch.tensor(
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
)
if compute_ce_loss and weight_hard_loss > 0.0:
ce_loss = F.cross_entropy(
student_logits_chunk.view(-1, student_logits_chunk.shape[-1]),
target_chunk.view(-1),
reduction="sum",
ignore_index=ignore_index,
)
soft_loss = torch.tensor(
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
)
if weight_soft_loss > 0.0:
student_logits_chunk_temp_scaled = student_logits_chunk / temperature
# Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max()
# No explicit padding here; user must ensure vocab alignment or pre-pad student_weight.
soft_loss = distillation_loss_fn(
student_logits_chunk_temp_scaled,
target_token_ids_chunk,
target_logprobs_chunk,
target_mask_chunk,
temperature=temperature,
)
loss = weight_hard_loss * ce_loss + weight_soft_loss * soft_loss
# return loss, (soft_loss, ce_loss, student_logits_chunk) # Aux outputs
return loss, (soft_loss, ce_loss) # Aux outputs
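# The per-chunk objective above, written out (soft_loss already carries the T**2 factor
# applied in distillation_loss_fn; both terms are un-normalized sums over the chunk):
#
#     loss_chunk = weight_hard_loss * ce_sum + weight_soft_loss * kd_sum
#
# forward() later divides the accumulated total by the number of labels != ignore_index.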
@classmethod
def forward(
cls,
ctx,
student_input: torch.Tensor, # [batch_size, seq_len, dim]
student_lm_head_weight: torch.Tensor, # [vocab_size, dim]
target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k]
target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k]
target_mask: torch.Tensor, # [batch_size, seq_len, top_k]
true_labels: torch.Tensor, # [batch_size, seq_len]
student_lm_head_bias: torch.Tensor = None,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
compiled: bool = False,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
):
CHUNK_SIZE = chunk_size
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
grad_inputs_list = []
grad_bias_acc = (
torch.zeros_like(student_lm_head_bias)
if student_lm_head_bias is not None
else None
)
loss_acc = torch.zeros(
(), device=student_input.device, dtype=student_input.dtype
)
# This function will be what torch.func.grad_and_value differentiates.
# It takes student_input_chunk, student_weight (full), student_bias (full) as primals.
# Other necessary data (target_*, etc.) are passed as non-differentiable arguments.
def loss_fn_for_grad(
_student_input_chunk,
_student_lm_head_weight, # full weight
_student_lm_head_bias, # full bias
# Fixed arguments for a given chunk, not differentiated:
_target_token_ids_chunk,
_target_logprobs_chunk,
_target_mask_chunk,
_true_labels_chunk,
):
return cls._compute_loss_kl_topk(
student_input_chunk=_student_input_chunk,
student_weight=_student_lm_head_weight,
target_token_ids_chunk=_target_token_ids_chunk,
target_logprobs_chunk=_target_logprobs_chunk,
target_mask_chunk=_target_mask_chunk,
target_chunk=_true_labels_chunk,
student_bias=_student_lm_head_bias,
distillation_loss_fn=cls.distillation_loss_fn,
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
compute_ce_loss=compute_ce_loss,
temperature=temperature,
)
def accumulate_chunk_grads(
student_input_chunk_ac,
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
):
# student_weight and student_bias are closed over from the outer scope (full tensors)
if student_lm_head_bias is not None:
(
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
(chunk_loss, _aux_outputs),
) = torch.func.grad_and_value(
loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
)(
student_input_chunk_ac,
student_lm_head_weight,
student_lm_head_bias, # primals
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
) # non-primals
grad_bias_acc.add_(chunk_grad_bias)
else:
argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight
(
(chunk_grad_input, chunk_grad_weight), # No grad for bias
(chunk_loss, _aux_outputs),
) = torch.func.grad_and_value(
loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
)(
student_input_chunk_ac,
student_lm_head_weight,
None, # Pass None for student_bias primal
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
)
grad_weight_acc.add_(chunk_grad_weight)
loss_acc.add_(chunk_loss)
return chunk_grad_input
if compiled:
accumulate_chunk_grads_compiled = torch.compile(
accumulate_chunk_grads, dynamic=True, backend="inductor"
) # dynamic=True often helpful
else:
accumulate_chunk_grads_compiled = accumulate_chunk_grads
# Use the same chunking logic as LigerFusedLinearDistillationBase.forward
B, N, D = student_input.shape
K = target_token_ids.shape[-1]
print("student_input shape: " + str(student_input.shape))
print("target_token_ids shape: " + str(target_token_ids.shape))
print("target_logprobs shape: " + str(target_logprobs.shape))
print("target_mask shape: " + str(target_mask.shape))
print("true_labels shape: " + str(true_labels.shape))
student_input_flat = student_input.reshape(-1, student_input.shape[-1])
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
true_labels_flat = true_labels.reshape(-1)
print("student_input_flat shape: " + str(student_input_flat.shape))
print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape))
print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape))
print("target_mask_flat shape: " + str(target_mask_flat.shape))
print("true_labels_flat shape: " + str(true_labels_flat.shape))
num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
_student_input_chunks = torch.chunk(
student_input_flat, chunks=num_chunks, dim=0
)
_target_token_ids_chunks = torch.chunk(
target_token_ids_flat, chunks=num_chunks, dim=0
)
_target_logprobs_chunks = torch.chunk(
target_logprobs_flat, chunks=num_chunks, dim=0
)
_target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0)
_true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0)
for i in range(num_chunks):
grad_input_chunk = accumulate_chunk_grads_compiled(
_student_input_chunks[i],
_target_token_ids_chunks[i],
_target_logprobs_chunks[i],
_target_mask_chunks[i],
_true_labels_chunks[i],
)
grad_inputs_list.append(grad_input_chunk)
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
ctx.hyperparams_count = 7 # scalar hyperparams after student_lm_head_bias in forward's signature (weight_hard_loss ... compute_ce_loss)
ctx.bias_was_none = student_lm_head_bias is None
ctx.orig_dims = (B, N, D, K)
# normalize the accumulated sum by the number of non-ignored label tokens
return loss_acc / (true_labels_flat != ignore_index).sum()
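# Minimal sketch of the grad-in-forward pattern used above, assuming a plain differentiable
# per-chunk loss of (input_chunk, weight); the real code also threads the teacher top-k
# tensors through as non-differentiated arguments and uses has_aux=True:
#
#     import torch
#     def chunk_loss(x, w):
#         return (x @ w.T).square().sum()
#     x, w = torch.randn(4, 8), torch.randn(16, 8)
#     (gx, gw), val = torch.func.grad_and_value(chunk_loss, argnums=(0, 1))(x, w)
#     # gx/gw are accumulated across chunks here and simply replayed
#     # (scaled by grad_output) in backward() below.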
@staticmethod
def backward(ctx, grad_output):
grad_input_flat, grad_weight, grad_bias_maybe = (
ctx.saved_tensors
) # grad_input_flat is (B*N, D)
# Scale gradients by grad_output if it's not 1.0
if not torch.equal(
grad_output,
torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype),
):
grad_input_flat = grad_input_flat * grad_output
grad_weight = grad_weight * grad_output
if grad_bias_maybe is not None:
grad_bias_maybe = grad_bias_maybe * grad_output
# Reshape grad_input_flat to match original student_input shape (B, N, D)
# ctx.orig_dims stores (B, N, D, K)
# We need the first three dimensions for student_input's shape.
# Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors
if (
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
and grad_input_flat.numel() == 0
):
# If original input was empty, gradient should also be empty with correct shape
grad_input_reshaped = torch.zeros(
ctx.orig_dims[0],
ctx.orig_dims[1],
ctx.orig_dims[2],
dtype=grad_input_flat.dtype,
device=grad_input_flat.device,
)
elif grad_input_flat.numel() == 0 and not (
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
):
# This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad)
# but as a safeguard:
grad_input_reshaped = torch.zeros(
ctx.orig_dims[0],
ctx.orig_dims[1],
ctx.orig_dims[2],
dtype=grad_input_flat.dtype,
device=grad_input_flat.device,
)
else:
grad_input_reshaped = grad_input_flat.view(
ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2]
)
nones_for_hyperparams = [None] * ctx.hyperparams_count
grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None
return (
grad_input_reshaped, # Gradient for student_input (reshaped)
grad_weight, # Gradient for student_lm_head_weight
None, # Gradient for target_token_ids
None, # Gradient for target_logprobs
None, # Gradient for target_mask
None, # Gradient for true_labels
grad_bias_return, # Gradient for student_lm_head_bias
*nones_for_hyperparams, # Grads for weight_hard_loss, ..., compute_ce_loss
)
class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
"""
nn.Module wrapper that validates shapes and dispatches to LigerFusedLinearKLTopKLogprobFunction.
"""
def __init__(
self,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
temperature: float = 1.0, # This is the kd_temperature
ignore_index: int = -100,
compiled: bool = True,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
):
super().__init__()
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
raise ValueError("Loss weights must be between 0.0 and 1.0.")
if temperature <= 0:
raise ValueError("Temperature must be positive.")
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.temperature = temperature
self.ignore_index = ignore_index
self.compiled = compiled
self.chunk_size = chunk_size
self.compute_ce_loss = compute_ce_loss
if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
print(
f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero."
)
# self.weight_hard_loss = 0.0 # Or let user manage this
if self.weight_soft_loss == 0.0:
print(
"Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed."
)
def forward(
self,
lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head
student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head
target_token_ids: torch.Tensor,
target_logprobs: torch.Tensor,
target_mask: torch.Tensor,
true_labels: torch.Tensor,
student_bias: torch.Tensor = None,
) -> torch.Tensor:
print(student_hidden_states.shape)
# Input validation
# if student_hidden_states.ndim != 2 or lm_head_weight.ndim != 2:
# raise ValueError("student_input and student_weight must be 2D tensors.")
# if student_hidden_states.shape[1] != lm_head_weight.shape[1]:
# raise ValueError("Hidden dimension mismatch between student_input and student_weight.")
# if student_bias is not None and (student_bias.ndim != 1 or student_bias.shape[0] != lm_head_weight.shape[0]):
# raise ValueError("student_bias shape mismatch.")
if self.weight_soft_loss > 0.0:
expected_len = student_hidden_states.shape[0]
for name, tensor in [
("target_token_ids", target_token_ids),
("target_logprobs", target_logprobs),
("target_mask", target_mask),
]:
# if tensor.ndim != 2:
# raise ValueError(f"{name} must be a 2D tensor.")
if tensor.shape[0] != expected_len:
raise ValueError(
f"Length mismatch: student_input ({expected_len}) vs {name} ({tensor.shape[0]})."
)
if not (
target_token_ids.shape[1]
== target_logprobs.shape[1]
== target_mask.shape[1]
):
raise ValueError(
"Top-k dimension mismatch among target_token_ids, target_logprobs, target_mask."
)
if target_token_ids.max() >= lm_head_weight.shape[0]:
raise ValueError(
f"target_token_ids contains indices ({target_token_ids.max().item()}) "
f"out of bounds for student vocabulary size ({lm_head_weight.shape[0]})."
)
# if self.compute_ce_loss and self.weight_hard_loss > 0.0:
# if true_labels.ndim != 1 or true_labels.shape[0] != student_hidden_states.shape[0]:
# raise ValueError("true_labels shape mismatch or incorrect dimensions.")
return LigerFusedLinearKLTopKLogprobFunction.apply(
student_hidden_states,
lm_head_weight,
target_token_ids,
target_logprobs,
target_mask,
true_labels,
student_bias,
self.weight_hard_loss,
self.weight_soft_loss,
self.ignore_index,
self.temperature,
self.compiled,
self.chunk_size,
self.compute_ce_loss,
)
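# Hedged standalone usage sketch for the module above (shapes and names are illustrative;
# compiled=False just keeps the example eager):
#
#     import torch
#     B, N, D, V, K = 2, 16, 32, 100, 8
#     loss_mod = LigerFusedLinearKLTopKLogprobLoss(temperature=1.0, compiled=False)
#     hidden = torch.randn(B, N, D, requires_grad=True)
#     lm_head_w = torch.randn(V, D, requires_grad=True)
#     tok = torch.randint(0, V, (B, N, K))
#     lp = torch.log_softmax(torch.randn(B, N, K), dim=-1)  # normalized teacher top-k logprobs
#     msk = torch.ones(B, N, K, dtype=torch.long)
#     labels = torch.randint(0, V, (B, N))
#     loss = loss_mod(lm_head_w, hidden, tok, lp, msk, labels)
#     loss.backward()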
# class LigerFusedTopKKDLossFunction(torch.autograd.Function):
# def forward(
# self,
# student_input: torch.Tensor,
# student_weight: torch.Tensor,
# teacher_input: torch.Tensor, # teacher logprobs
# teacher_token_ids: torch.Tensor,
# teacher_mask: torch.Tensor,
# hard_labels: torch.LongTensor,
# ):
# return LigerFusedLinearKLTopKLogprobFunction.apply(
# student_input,
# student_weight,
# teacher_token_ids,
# teacher_input, # teacher logprobs
# teacher_mask,
# hard_labels,
# None,
# self.weight_hard_loss,
# self.weight_soft_loss,
# self.ignore_index,
# self.temperature,
# self.compiled,
# self.chunk_size,
# self.compute_ce_loss,
# )
def kldiv_forward_qwen2(
self,
input_ids: Optional[torch.LongTensor] = None,
target_logprobs: Optional[torch.Tensor] = None,
target_token_ids: Optional[torch.LongTensor] = None,
target_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
"""
Replacement Qwen2ForCausalLM.forward that skips materializing logits and computes the
fused top-k KD loss directly from the final hidden states and the lm_head weight.
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Logits are never materialized here; the fused loss consumes the hidden states and the
# lm_head weight directly, chunk by chunk, avoiding the full [batch, seq, vocab] logits tensor.
# TODO: we can optimize this further by filtering hidden_states on the sequence dimension using labels != -100
# self.loss_function is expected to be LigerFusedLinearKLTopKLogprobLoss (set externally, e.g. by the trainer)
loss = self.loss_function(
self.lm_head.weight,
hidden_states,
target_token_ids,
target_logprobs,
target_mask,
true_labels=labels,
)
return CausalLMOutputWithPast(
loss=loss,
logits=None,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_kernel_to_qwen2():
from transformers.models.qwen2 import modeling_qwen2
modeling_qwen2.Qwen2ForCausalLM.forward = kldiv_forward_qwen2
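# Hedged end-to-end sketch, assuming a Qwen2 student and pre-collated teacher tensors
# (model id and variable names are illustrative, not part of this commit):
#
#     from transformers import Qwen2ForCausalLM
#     apply_kernel_to_qwen2()
#     model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
#     model._loss_function = LigerFusedLinearKLTopKLogprobLoss(0.5, 0.5, temperature=1.0)
#     out = model(
#         input_ids=input_ids,                # [B, N]
#         attention_mask=attention_mask,
#         labels=labels,                      # [B, N], -100 on masked positions
#         target_token_ids=target_token_ids,  # [B, N, K]
#         target_logprobs=target_logprobs,    # [B, N, K]
#         target_mask=target_mask,            # [B, N, K]
#     )
#     out.loss.backward()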

View File

@@ -18,7 +18,7 @@ KD trainer
from axolotl.core.trainers.base import AxolotlTrainer
- from .topk_logprob.forward_kl import ChunkedTopKKDLoss
+ from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
class AxolotlKDTrainer(AxolotlTrainer):
@@ -29,8 +29,11 @@ class AxolotlKDTrainer(AxolotlTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_accepts_loss_kwargs = True
- self.kd_loss_fn = ChunkedTopKKDLoss(
- num_output_chunks=8, kd_temperature=self.args.kd_temperature
+ self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
+ self.args.kd_ce_alpha, # hard label loss
+ self.args.kd_alpha, # kd loss
+ self.args.kd_temperature,
+ compute_ce_loss=bool(self.args.kd_ce_alpha),
)
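# Hedged reading of the wiring above: kd_ce_alpha maps to weight_hard_loss and kd_alpha
# to weight_soft_loss of LigerFusedLinearKLTopKLogprobLoss, e.g.
#
#     LigerFusedLinearKLTopKLogprobLoss(0.1, 0.9, 1.0, compute_ce_loss=True)
#     # ~ loss = 0.1 * CE + 0.9 * top-k KD (per-token sums, normalized by label count)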
def _set_signature_columns_if_needed(self):
@@ -59,11 +62,11 @@ class AxolotlKDTrainer(AxolotlTrainer):
Subclass and override for custom behavior.
"""
- target_logprobs = inputs.pop("target_logprobs")
- target_token_ids = inputs.pop("target_token_ids")
- target_mask = inputs.pop("target_mask")
+ # target_logprobs = inputs.pop("target_logprobs")
+ # target_token_ids = inputs.pop("target_token_ids")
+ # target_mask = inputs.pop("target_mask")
- seq_len = target_token_ids.shape[1]
+ # seq_len = target_token_ids.shape[1]
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
@@ -71,36 +74,37 @@ class AxolotlKDTrainer(AxolotlTrainer):
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
+ return outputs[0]
+ #
+ # # FIXME: account for tokenizer.padding_side
+ # student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
+ #
+ # shift_logits = student_logits.contiguous()
+ # target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
+ # target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
+ # target_mask_for_loss = target_mask[..., 1:, :].contiguous()
+ #
+ # loss_kd = self.kd_loss_fn(
+ # shift_logits,
+ # target_token_ids_for_loss,
+ # target_logprobs_for_loss,
+ # target_mask_for_loss,
+ # num_items_in_batch=num_items_in_batch,
+ # )
+ #
+ # if self.args.kd_ce_alpha > 0:
+ # kd_alpha = self.args.kd_alpha
+ # loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
+ # else:
+ # loss = loss_kd
+ # # Save past state if it exists
+ # # TODO: this needs to be fixed and made cleaner later.
+ # if self.args.past_index >= 0:
+ # self._past = outputs[ # pylint: disable=attribute-defined-outside-init
+ # self.args.past_index
+ # ]
+ #
+ # if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
+ # loss *= self.accelerator.num_processes
- # FIXME: account for tokenizer.padding_side
- student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
- shift_logits = student_logits.contiguous()
- target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
- target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
- target_mask_for_loss = target_mask[..., 1:, :].contiguous()
- loss_kd = self.kd_loss_fn(
- shift_logits,
- target_token_ids_for_loss,
- target_logprobs_for_loss,
- target_mask_for_loss,
- num_items_in_batch=num_items_in_batch,
- )
- if self.args.kd_ce_alpha > 0:
- kd_alpha = self.args.kd_alpha
- loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
- else:
- loss = loss_kd
- # Save past state if it exists
- # TODO: this needs to be fixed and made cleaner later.
- if self.args.past_index >= 0:
- self._past = outputs[ # pylint: disable=attribute-defined-outside-init
- self.args.past_index
- ]
- if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
- loss *= self.accelerator.num_processes
- return (loss, outputs) if return_outputs else loss
+ # return (loss, outputs) if return_outputs else loss