WIP chunked KD loss with autograd wrapper
This commit is contained in:
@@ -34,3 +34,8 @@ class KDPlugin(BasePlugin):
|
||||
|
||||
return AxolotlKDTrainer
|
||||
return None
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
from .kernels.liger import apply_kernel_to_qwen2
|
||||
|
||||
apply_kernel_to_qwen2()
|
||||
|
||||
@@ -183,13 +183,124 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
|
||||
"""
|
||||
Strat for datasets with complete structured KD logprob data
|
||||
"""
|
||||
|
||||
def transform_logprobs(self, sample):
|
||||
"""
|
||||
Transform logprobs to target format for KD training
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
logprobs = sample.pop(self.logprobs_field)
|
||||
target_seq_len = len(logprobs)
|
||||
input_seq_len = len(sample["input_ids"])
|
||||
input_padding_len = input_seq_len - target_seq_len
|
||||
# get non-zero top-k (prune None logprobs from vllm data step)
|
||||
top_k_vals = [
|
||||
len(logprobs[i])
|
||||
for i in range(len(logprobs))
|
||||
if logprobs[i] is not None and len(logprobs[i])
|
||||
]
|
||||
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
|
||||
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
|
||||
top_k = min(max_top_k, min_top_k)
|
||||
if top_k == 0:
|
||||
raise ValueError("No non-zero top-k logprobs found.")
|
||||
|
||||
target_logprobs = []
|
||||
target_token_ids = []
|
||||
target_mask = []
|
||||
|
||||
if input_padding_len < 0:
|
||||
# logprobs is longer than target_seq_len,
|
||||
# so we need to slice from the left/beginning of logprobs
|
||||
logprobs = logprobs[:-input_seq_len]
|
||||
input_padding_len = 0
|
||||
# target_seq_len = input_seq_len
|
||||
|
||||
# truncate the second dimension of the logprobs to top_k
|
||||
logprobs = [row[:top_k] for row in logprobs]
|
||||
|
||||
# fill with -inf for padding_len tokens for top_k tokens
|
||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||
|
||||
# we shift for causal models in the trainer, so start the range from 0
|
||||
for _ in range(0, input_padding_len):
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
for position in range(input_padding_len, input_seq_len):
|
||||
if sample["labels"][position] == -100:
|
||||
target_mask.append([0] * top_k)
|
||||
else:
|
||||
target_mask.append([1] * top_k)
|
||||
|
||||
for token_pos_logprobs, pos_target_token_ids in zip(
|
||||
logprobs, sample["target_token_ids"]
|
||||
):
|
||||
# Convert to a tensor for easier manipulation
|
||||
position_logprobs_tensor = torch.tensor(
|
||||
token_pos_logprobs, dtype=torch.float
|
||||
)
|
||||
|
||||
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
|
||||
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
|
||||
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
|
||||
#
|
||||
# Convert from log to probability
|
||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||
# normalize probabilities to sum to 1 in case they aren't already
|
||||
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
|
||||
if teacher_probs_t1_sum > 1e-9:
|
||||
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
|
||||
if self.kd_temperature != self.gen_temperature:
|
||||
# Exponentiate by factor (T1 / T2)
|
||||
exponent = self.gen_temperature / self.kd_temperature
|
||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||
else:
|
||||
teacher_probs_t2 = teacher_probs_t1
|
||||
# Re-normalize
|
||||
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||
dim=0, keepdim=True
|
||||
)
|
||||
# Convert back to log
|
||||
position_logprobs_tensor = torch.log(teacher_probs_t2)
|
||||
|
||||
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
|
||||
position_logprobs_scaled = position_logprobs_tensor.tolist()
|
||||
|
||||
target_logprobs.append(position_logprobs_scaled)
|
||||
target_token_ids.append(pos_target_token_ids)
|
||||
|
||||
# Update sample with transformed logprobs
|
||||
sample["target_logprobs"] = target_logprobs
|
||||
sample["target_token_ids"] = target_token_ids
|
||||
sample["target_mask"] = target_mask
|
||||
|
||||
return sample
|
||||
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
logprobs = prompt.pop(self.logprobs_field)
|
||||
target_token_ids = prompt.pop("target_token_ids")
|
||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||
tokenized_prompt[self.logprobs_field] = logprobs
|
||||
tokenized_prompt["target_token_ids"] = target_token_ids
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class KDStrategyLoader(StrategyLoader):
|
||||
"""
|
||||
Load ChatTemplateStrategy with KD support using StrategyLoader.
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self):
|
||||
return ChatTemplateStrategyWithKD
|
||||
return ChatTemplateStrategyWithKDv2
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
strategy_params = super()._get_strategy_params(cfg, ds_cfg)
|
||||
|
||||
@@ -56,6 +56,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
||||
return_tensors = self.return_tensors
|
||||
|
||||
padding_side = self.tokenizer.padding_side
|
||||
max_len = 0
|
||||
|
||||
# Pad labels and position_ids first
|
||||
for feature_name, pad_token_id in [
|
||||
@@ -106,7 +107,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
||||
target_mask_list.append(f.pop("target_mask"))
|
||||
|
||||
# Determine max lengths
|
||||
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
|
||||
max_teacher_seq_len = max_len or max(
|
||||
len(seq) for seq in target_logprobs_list
|
||||
)
|
||||
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
|
||||
|
||||
padded_target_logprobs = []
|
||||
|
||||
576
src/axolotl/integrations/kd/kernels/liger.py
Normal file
576
src/axolotl/integrations/kd/kernels/liger.py
Normal file
@@ -0,0 +1,576 @@
|
||||
"""
|
||||
Liger Kernels for Chunked Top-K Log-Prob Distillation
|
||||
"""
|
||||
|
||||
from typing import Optional, Union, Unpack
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Assuming LigerFusedLinearDistillationBase is in this path and can be imported
|
||||
# If not, its structure would need to be replicated or specific utilities copied.
|
||||
from liger_kernel.chunked_loss.fused_linear_distillation import (
|
||||
LigerFusedLinearDistillationBase,
|
||||
)
|
||||
from transformers import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.qwen2.modeling_qwen2 import KwargsForCausalLM
|
||||
|
||||
|
||||
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||
|
||||
@staticmethod
|
||||
def distillation_loss_fn(
|
||||
student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled
|
||||
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
|
||||
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||
temperature: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute Top-K KL divergence loss for a chunk.
|
||||
Args:
|
||||
student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
|
||||
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
|
||||
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
|
||||
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
|
||||
temperature: Temperature used for scaling.
|
||||
Returns:
|
||||
Sum of KL divergence losses for the chunk.
|
||||
"""
|
||||
student_logits_temp_scaled = student_logits_temp_scaled.float()
|
||||
target_logprobs_chunk = target_logprobs_chunk.float()
|
||||
|
||||
# Gather student logits for the top-k teacher token IDs
|
||||
# student_logits_temp_scaled: [chunk_size, vocab_size]
|
||||
# target_token_ids_chunk: [chunk_size, top_k]
|
||||
# student_logits_topk_temp_scaled: [chunk_size, top_k]
|
||||
student_logits_topk_temp_scaled = torch.gather(
|
||||
student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk
|
||||
)
|
||||
|
||||
# Student log-probabilities for the gathered top-k tokens
|
||||
student_lse = torch.logsumexp(
|
||||
student_logits_temp_scaled, dim=-1, keepdim=True
|
||||
) # [chunk_size, 1]
|
||||
student_logprobs_topk_temp_scaled = (
|
||||
student_logits_topk_temp_scaled - student_lse
|
||||
)
|
||||
|
||||
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
|
||||
|
||||
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
|
||||
target_logprobs_valid = target_logprobs_chunk[valid_mask]
|
||||
|
||||
# Teacher probabilities P(y|x_teacher) from logprobs
|
||||
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
|
||||
teacher_probs_valid = target_logprobs_valid.exp()
|
||||
|
||||
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
|
||||
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
|
||||
# The distillation loss is often formulated as -sum(P_teacher * log P_student)
|
||||
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
|
||||
# Here, target_logprobs_valid are log_softmax_teacher.
|
||||
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
|
||||
kd_loss_per_token = teacher_probs_valid * (
|
||||
target_logprobs_valid - student_logprobs_topk_valid
|
||||
)
|
||||
kd_loss = kd_loss_per_token.sum()
|
||||
|
||||
if temperature != 1.0:
|
||||
kd_loss = kd_loss * (temperature**2)
|
||||
|
||||
return kd_loss
|
||||
|
||||
@staticmethod
|
||||
def _compute_loss_kl_topk(
|
||||
student_input_chunk: torch.Tensor,
|
||||
student_weight: torch.Tensor,
|
||||
# Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value
|
||||
# or through `partial`. Let's make them explicit here for clarity.
|
||||
target_token_ids_chunk: torch.Tensor,
|
||||
target_logprobs_chunk: torch.Tensor,
|
||||
target_mask_chunk: torch.Tensor,
|
||||
target_chunk: torch.Tensor, # For hard loss (true labels)
|
||||
student_bias: torch.Tensor = None, # This will be one of the grad targets
|
||||
# Other params passed via `partial` from `forward`
|
||||
distillation_loss_fn=None,
|
||||
ignore_index: int = -100,
|
||||
weight_hard_loss: float = 0.5,
|
||||
weight_soft_loss: float = 0.5,
|
||||
compute_ce_loss: bool = True,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
# Compute student logits for the chunk from hidden states and LM head
|
||||
# student_input_chunk: [chunk_size, hidden_dim]
|
||||
# student_lm_head_weight: [vocab_size, hidden_dim]
|
||||
# student_logits_chunk: [chunk_size, vocab_size]
|
||||
student_logits_chunk = F.linear(
|
||||
student_input_chunk, student_weight, student_bias
|
||||
)
|
||||
|
||||
ce_loss = torch.tensor(
|
||||
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
|
||||
)
|
||||
if compute_ce_loss and weight_hard_loss > 0.0:
|
||||
ce_loss = F.cross_entropy(
|
||||
student_logits_chunk.view(-1, student_logits_chunk.shape[-1]),
|
||||
target_chunk.view(-1),
|
||||
reduction="sum",
|
||||
ignore_index=ignore_index,
|
||||
)
|
||||
|
||||
soft_loss = torch.tensor(
|
||||
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
|
||||
)
|
||||
if weight_soft_loss > 0.0:
|
||||
student_logits_chunk_temp_scaled = student_logits_chunk / temperature
|
||||
|
||||
# Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max()
|
||||
# No explicit padding here; user must ensure vocab alignment or pre-pad student_weight.
|
||||
|
||||
soft_loss = distillation_loss_fn(
|
||||
student_logits_chunk_temp_scaled,
|
||||
target_token_ids_chunk,
|
||||
target_logprobs_chunk,
|
||||
target_mask_chunk,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
loss = weight_hard_loss * ce_loss + weight_soft_loss * soft_loss
|
||||
|
||||
# return loss, (soft_loss, ce_loss, student_logits_chunk) # Aux outputs
|
||||
return loss, (soft_loss, ce_loss) # Aux outputs
|
||||
|
||||
@classmethod
|
||||
def forward(
|
||||
cls,
|
||||
ctx,
|
||||
student_input: torch.Tensor, # [batch_size, seq_len, dim]
|
||||
student_lm_head_weight: torch.Tensor, # [dim, vocab_size]
|
||||
target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k]
|
||||
target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k]
|
||||
target_mask: torch.Tensor, # [batch_size, seq_len, top_k]
|
||||
true_labels: torch.Tensor, # [batch_size, seq_len]
|
||||
student_lm_head_bias: torch.Tensor = None,
|
||||
weight_hard_loss: float = 0.5,
|
||||
weight_soft_loss: float = 0.5,
|
||||
ignore_index: int = -100,
|
||||
temperature: float = 1.0,
|
||||
compiled: bool = False,
|
||||
chunk_size: int = 1024,
|
||||
compute_ce_loss: bool = True,
|
||||
):
|
||||
CHUNK_SIZE = chunk_size
|
||||
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
|
||||
grad_inputs_list = []
|
||||
grad_bias_acc = (
|
||||
torch.zeros_like(student_lm_head_bias)
|
||||
if student_lm_head_bias is not None
|
||||
else None
|
||||
)
|
||||
loss_acc = torch.zeros(
|
||||
(), device=student_input.device, dtype=student_input.dtype
|
||||
)
|
||||
|
||||
# This function will be what torch.func.grad_and_value differentiates.
|
||||
# It takes student_input_chunk, student_weight (full), student_bias (full) as primals.
|
||||
# Other necessary data (target_*, etc.) are passed as non-differentiable arguments.
|
||||
def loss_fn_for_grad(
|
||||
_student_input_chunk,
|
||||
_student_lm_head_weight, # full weight
|
||||
_student_lm_head_bias, # full bias
|
||||
# Fixed arguments for a given chunk, not differentiated:
|
||||
_target_token_ids_chunk,
|
||||
_target_logprobs_chunk,
|
||||
_target_mask_chunk,
|
||||
_true_labels_chunk,
|
||||
):
|
||||
return cls._compute_loss_kl_topk(
|
||||
student_input_chunk=_student_input_chunk,
|
||||
student_weight=_student_lm_head_weight,
|
||||
target_token_ids_chunk=_target_token_ids_chunk,
|
||||
target_logprobs_chunk=_target_logprobs_chunk,
|
||||
target_mask_chunk=_target_mask_chunk,
|
||||
target_chunk=_true_labels_chunk,
|
||||
student_bias=_student_lm_head_bias,
|
||||
distillation_loss_fn=cls.distillation_loss_fn,
|
||||
ignore_index=ignore_index,
|
||||
weight_hard_loss=weight_hard_loss,
|
||||
weight_soft_loss=weight_soft_loss,
|
||||
compute_ce_loss=compute_ce_loss,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
def accumulate_chunk_grads(
|
||||
student_input_chunk_ac,
|
||||
target_token_ids_chunk_ac,
|
||||
target_logprobs_chunk_ac,
|
||||
target_mask_chunk_ac,
|
||||
true_labels_chunk_ac,
|
||||
):
|
||||
# student_weight and student_bias are closed over from the outer scope (full tensors)
|
||||
if student_lm_head_bias is not None:
|
||||
(
|
||||
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
|
||||
(chunk_loss, _aux_outputs),
|
||||
) = torch.func.grad_and_value(
|
||||
loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
|
||||
)(
|
||||
student_input_chunk_ac,
|
||||
student_lm_head_weight,
|
||||
student_lm_head_bias, # primals
|
||||
target_token_ids_chunk_ac,
|
||||
target_logprobs_chunk_ac,
|
||||
target_mask_chunk_ac,
|
||||
true_labels_chunk_ac,
|
||||
) # non-primals
|
||||
grad_bias_acc.add_(chunk_grad_bias)
|
||||
else:
|
||||
argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight
|
||||
(
|
||||
(chunk_grad_input, chunk_grad_weight), # No grad for bias
|
||||
(chunk_loss, _aux_outputs),
|
||||
) = torch.func.grad_and_value(
|
||||
loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
|
||||
)(
|
||||
student_input_chunk_ac,
|
||||
student_lm_head_weight,
|
||||
None, # Pass None for student_bias primal
|
||||
target_token_ids_chunk_ac,
|
||||
target_logprobs_chunk_ac,
|
||||
target_mask_chunk_ac,
|
||||
true_labels_chunk_ac,
|
||||
)
|
||||
|
||||
grad_weight_acc.add_(chunk_grad_weight)
|
||||
loss_acc.add_(chunk_loss)
|
||||
return chunk_grad_input
|
||||
|
||||
if compiled:
|
||||
accumulate_chunk_grads_compiled = torch.compile(
|
||||
accumulate_chunk_grads, dynamic=True, backend="inductor"
|
||||
) # dynamic=True often helpful
|
||||
else:
|
||||
accumulate_chunk_grads_compiled = accumulate_chunk_grads
|
||||
|
||||
# Use the same chunking logic as LigerFusedLinearDistillationBase.forward
|
||||
B, N, D = student_input.shape
|
||||
K = target_token_ids.shape[-1]
|
||||
|
||||
print("student_input shape: " + str(student_input.shape))
|
||||
print("target_token_ids shape: " + str(target_token_ids.shape))
|
||||
print("target_logprobs shape: " + str(target_logprobs.shape))
|
||||
print("target_mask shape: " + str(target_mask.shape))
|
||||
print("true_labels shape: " + str(true_labels.shape))
|
||||
student_input_flat = student_input.reshape(-1, student_input.shape[-1])
|
||||
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
|
||||
target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
|
||||
target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
|
||||
true_labels_flat = true_labels.reshape(-1)
|
||||
print("student_input_flat shape: " + str(student_input_flat.shape))
|
||||
print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape))
|
||||
print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape))
|
||||
print("target_mask_flat shape: " + str(target_mask_flat.shape))
|
||||
print("true_labels_flat shape: " + str(true_labels_flat.shape))
|
||||
num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
|
||||
|
||||
_student_input_chunks = torch.chunk(
|
||||
student_input_flat, chunks=num_chunks, dim=0
|
||||
)
|
||||
_target_token_ids_chunks = torch.chunk(
|
||||
target_token_ids_flat, chunks=num_chunks, dim=0
|
||||
)
|
||||
_target_logprobs_chunks = torch.chunk(
|
||||
target_logprobs_flat, chunks=num_chunks, dim=0
|
||||
)
|
||||
_target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0)
|
||||
_true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0)
|
||||
|
||||
for i in range(num_chunks):
|
||||
grad_input_chunk = accumulate_chunk_grads_compiled(
|
||||
_student_input_chunks[i],
|
||||
_target_token_ids_chunks[i],
|
||||
_target_logprobs_chunks[i],
|
||||
_target_mask_chunks[i],
|
||||
_true_labels_chunks[i],
|
||||
)
|
||||
grad_inputs_list.append(grad_input_chunk)
|
||||
|
||||
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
|
||||
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
|
||||
|
||||
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
|
||||
ctx.hyperparams_count = 7 # Corresponds to number of hyperparams after main tensors in fwd signature
|
||||
ctx.bias_was_none = student_lm_head_bias is None
|
||||
ctx.orig_dims = (B, N, D, K)
|
||||
|
||||
return loss_acc / (true_labels_flat != ignore_index).sum()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_input_flat, grad_weight, grad_bias_maybe = (
|
||||
ctx.saved_tensors
|
||||
) # grad_input_flat is (B*N, D)
|
||||
|
||||
# Scale gradients by grad_output if it's not 1.0
|
||||
if not torch.equal(
|
||||
grad_output,
|
||||
torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype),
|
||||
):
|
||||
grad_input_flat = grad_input_flat * grad_output
|
||||
grad_weight = grad_weight * grad_output
|
||||
if grad_bias_maybe is not None:
|
||||
grad_bias_maybe = grad_bias_maybe * grad_output
|
||||
|
||||
# Reshape grad_input_flat to match original student_input shape (B, N, D)
|
||||
# ctx.orig_dims stores (B, N, D, K)
|
||||
# We need the first three dimensions for student_input's shape.
|
||||
# Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors
|
||||
if (
|
||||
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
|
||||
and grad_input_flat.numel() == 0
|
||||
):
|
||||
# If original input was empty, gradient should also be empty with correct shape
|
||||
grad_input_reshaped = torch.zeros(
|
||||
ctx.orig_dims[0],
|
||||
ctx.orig_dims[1],
|
||||
ctx.orig_dims[2],
|
||||
dtype=grad_input_flat.dtype,
|
||||
device=grad_input_flat.device,
|
||||
)
|
||||
elif grad_input_flat.numel() == 0 and not (
|
||||
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
|
||||
):
|
||||
# This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad)
|
||||
# but as a safeguard:
|
||||
grad_input_reshaped = torch.zeros(
|
||||
ctx.orig_dims[0],
|
||||
ctx.orig_dims[1],
|
||||
ctx.orig_dims[2],
|
||||
dtype=grad_input_flat.dtype,
|
||||
device=grad_input_flat.device,
|
||||
)
|
||||
else:
|
||||
grad_input_reshaped = grad_input_flat.view(
|
||||
ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2]
|
||||
)
|
||||
|
||||
nones_for_hyperparams = [None] * ctx.hyperparams_count
|
||||
grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None
|
||||
|
||||
return (
|
||||
grad_input_reshaped, # Gradient for student_input (reshaped)
|
||||
grad_weight, # Gradient for student_lm_head_weight
|
||||
None, # Gradient for target_token_ids
|
||||
None, # Gradient for target_logprobs
|
||||
None, # Gradient for target_mask
|
||||
None, # Gradient for true_labels
|
||||
grad_bias_return, # Gradient for student_lm_head_bias
|
||||
*nones_for_hyperparams, # Grads for weight_hard_loss, ..., compute_ce_loss
|
||||
)
|
||||
|
||||
|
||||
class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight_hard_loss: float = 0.5,
|
||||
weight_soft_loss: float = 0.5,
|
||||
temperature: float = 1.0, # This is the kd_temperature
|
||||
ignore_index: int = -100,
|
||||
compiled: bool = True,
|
||||
chunk_size: int = 1024,
|
||||
compute_ce_loss: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
|
||||
raise ValueError("Loss weights must be between 0.0 and 1.0.")
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be positive.")
|
||||
|
||||
self.weight_hard_loss = weight_hard_loss
|
||||
self.weight_soft_loss = weight_soft_loss
|
||||
self.temperature = temperature
|
||||
self.ignore_index = ignore_index
|
||||
self.compiled = compiled
|
||||
self.chunk_size = chunk_size
|
||||
self.compute_ce_loss = compute_ce_loss
|
||||
|
||||
if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
|
||||
print(
|
||||
f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero."
|
||||
)
|
||||
# self.weight_hard_loss = 0.0 # Or let user manage this
|
||||
if self.weight_soft_loss == 0.0:
|
||||
print(
|
||||
"Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head
|
||||
student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head
|
||||
target_token_ids: torch.Tensor,
|
||||
target_logprobs: torch.Tensor,
|
||||
target_mask: torch.Tensor,
|
||||
true_labels: torch.Tensor,
|
||||
student_bias: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
print(student_hidden_states.shape)
|
||||
# Input validation
|
||||
# if student_hidden_states.ndim != 2 or lm_head_weight.ndim != 2:
|
||||
# raise ValueError("student_input and student_weight must be 2D tensors.")
|
||||
# if student_hidden_states.shape[1] != lm_head_weight.shape[1]:
|
||||
# raise ValueError("Hidden dimension mismatch between student_input and student_weight.")
|
||||
# if student_bias is not None and (student_bias.ndim != 1 or student_bias.shape[0] != lm_head_weight.shape[0]):
|
||||
# raise ValueError("student_bias shape mismatch.")
|
||||
|
||||
if self.weight_soft_loss > 0.0:
|
||||
expected_len = student_hidden_states.shape[0]
|
||||
for name, tensor in [
|
||||
("target_token_ids", target_token_ids),
|
||||
("target_logprobs", target_logprobs),
|
||||
("target_mask", target_mask),
|
||||
]:
|
||||
# if tensor.ndim != 2:
|
||||
# raise ValueError(f"{name} must be a 2D tensor.")
|
||||
if tensor.shape[0] != expected_len:
|
||||
raise ValueError(
|
||||
f"Length mismatch: student_input ({expected_len}) vs {name} ({tensor.shape[0]})."
|
||||
)
|
||||
if not (
|
||||
target_token_ids.shape[1]
|
||||
== target_logprobs.shape[1]
|
||||
== target_mask.shape[1]
|
||||
):
|
||||
raise ValueError(
|
||||
"Top-k dimension mismatch among target_token_ids, target_logprobs, target_mask."
|
||||
)
|
||||
if target_token_ids.max() >= lm_head_weight.shape[0]:
|
||||
raise ValueError(
|
||||
f"target_token_ids contains indices ({target_token_ids.max().item()}) "
|
||||
f"out of bounds for student vocabulary size ({lm_head_weight.shape[0]})."
|
||||
)
|
||||
|
||||
# if self.compute_ce_loss and self.weight_hard_loss > 0.0:
|
||||
# if true_labels.ndim != 1 or true_labels.shape[0] != student_hidden_states.shape[0]:
|
||||
# raise ValueError("true_labels shape mismatch or incorrect dimensions.")
|
||||
|
||||
return LigerFusedLinearKLTopKLogprobFunction.apply(
|
||||
student_hidden_states,
|
||||
lm_head_weight,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
true_labels,
|
||||
student_bias,
|
||||
self.weight_hard_loss,
|
||||
self.weight_soft_loss,
|
||||
self.ignore_index,
|
||||
self.temperature,
|
||||
self.compiled,
|
||||
self.chunk_size,
|
||||
self.compute_ce_loss,
|
||||
)
|
||||
|
||||
|
||||
# class LigerFusedTopKKDLossFunction(torch.autograd.Function):
|
||||
# def forward(
|
||||
# self,
|
||||
# student_input: torch.Tensor,
|
||||
# student_weight: torch.Tensor,
|
||||
# teacher_input: torch.Tensor, # teacher logprobs
|
||||
# teacher_token_ids: torch.Tensor,
|
||||
# teacher_mask: torch.Tensor,
|
||||
# hard_labels: torch.LongTensor,
|
||||
# ):
|
||||
# return LigerFusedLinearKLTopKLogprobFunction.apply(
|
||||
# student_input,
|
||||
# student_weight,
|
||||
# teacher_token_ids,
|
||||
# teacher_input, # teacher logprobs
|
||||
# teacher_mask,
|
||||
# hard_labels,
|
||||
# None,
|
||||
# self.weight_hard_loss,
|
||||
# self.weight_soft_loss,
|
||||
# self.ignore_index,
|
||||
# self.temperature,
|
||||
# self.compiled,
|
||||
# self.chunk_size,
|
||||
# self.compute_ce_loss,
|
||||
# )
|
||||
|
||||
|
||||
def kldiv_forward_qwen2(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
target_logprobs: Optional[torch.Tensor] = None,
|
||||
target_token_ids: Optional[torch.LongTensor] = None,
|
||||
target_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> CausalLMOutputWithPast:
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
|
||||
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
|
||||
|
||||
loss = self.loss_function(
|
||||
self.lm_head.weight,
|
||||
hidden_states,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
true_labels=labels,
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=None,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def apply_kernel_to_qwen2():
|
||||
from transformers.models.qwen2 import modeling_qwen2
|
||||
|
||||
modeling_qwen2.Qwen2ForCausalLM.forward = kldiv_forward_qwen2
|
||||
@@ -18,7 +18,7 @@ KD trainer
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
from .topk_logprob.forward_kl import ChunkedTopKKDLoss
|
||||
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
||||
|
||||
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
|
||||
@@ -29,8 +29,11 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model_accepts_loss_kwargs = True
|
||||
self.kd_loss_fn = ChunkedTopKKDLoss(
|
||||
num_output_chunks=8, kd_temperature=self.args.kd_temperature
|
||||
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
|
||||
self.args.kd_ce_alpha, # hard label loss
|
||||
self.args.kd_alpha, # kd loss
|
||||
self.args.kd_temperature,
|
||||
compute_ce_loss=bool(self.args.kd_ce_alpha),
|
||||
)
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
@@ -59,11 +62,11 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
|
||||
target_logprobs = inputs.pop("target_logprobs")
|
||||
target_token_ids = inputs.pop("target_token_ids")
|
||||
target_mask = inputs.pop("target_mask")
|
||||
# target_logprobs = inputs.pop("target_logprobs")
|
||||
# target_token_ids = inputs.pop("target_token_ids")
|
||||
# target_mask = inputs.pop("target_mask")
|
||||
|
||||
seq_len = target_token_ids.shape[1]
|
||||
# seq_len = target_token_ids.shape[1]
|
||||
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
@@ -71,36 +74,37 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
outputs = model(**inputs)
|
||||
return outputs[0]
|
||||
#
|
||||
# # FIXME: account for tokenizer.padding_side
|
||||
# student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
|
||||
#
|
||||
# shift_logits = student_logits.contiguous()
|
||||
# target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||
# target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||
# target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||
#
|
||||
# loss_kd = self.kd_loss_fn(
|
||||
# shift_logits,
|
||||
# target_token_ids_for_loss,
|
||||
# target_logprobs_for_loss,
|
||||
# target_mask_for_loss,
|
||||
# num_items_in_batch=num_items_in_batch,
|
||||
# )
|
||||
#
|
||||
# if self.args.kd_ce_alpha > 0:
|
||||
# kd_alpha = self.args.kd_alpha
|
||||
# loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
||||
# else:
|
||||
# loss = loss_kd
|
||||
# # Save past state if it exists
|
||||
# # TODO: this needs to be fixed and made cleaner later.
|
||||
# if self.args.past_index >= 0:
|
||||
# self._past = outputs[ # pylint: disable=attribute-defined-outside-init
|
||||
# self.args.past_index
|
||||
# ]
|
||||
#
|
||||
# if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||
# loss *= self.accelerator.num_processes
|
||||
|
||||
# FIXME: account for tokenizer.padding_side
|
||||
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
|
||||
|
||||
shift_logits = student_logits.contiguous()
|
||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||
|
||||
loss_kd = self.kd_loss_fn(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
if self.args.kd_ce_alpha > 0:
|
||||
kd_alpha = self.args.kd_alpha
|
||||
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
||||
else:
|
||||
loss = loss_kd
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
|
||||
self.args.past_index
|
||||
]
|
||||
|
||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||
loss *= self.accelerator.num_processes
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
|
||||
Reference in New Issue
Block a user