diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py index 8a6e3eda1..017b9533d 100644 --- a/src/axolotl/integrations/kd/__init__.py +++ b/src/axolotl/integrations/kd/__init__.py @@ -34,3 +34,8 @@ class KDPlugin(BasePlugin): return AxolotlKDTrainer return None + + def pre_model_load(self, cfg): + from .kernels.liger import apply_kernel_to_qwen2 + + apply_kernel_to_qwen2() diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index 43a90216e..1d696540d 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -183,13 +183,124 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): return tokenized_prompt +class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD): + """ + Strat for datasets with complete structured KD logprob data + """ + + def transform_logprobs(self, sample): + """ + Transform logprobs to target format for KD training + """ + # pylint: disable=duplicate-code + + logprobs = sample.pop(self.logprobs_field) + target_seq_len = len(logprobs) + input_seq_len = len(sample["input_ids"]) + input_padding_len = input_seq_len - target_seq_len + # get non-zero top-k (prune None logprobs from vllm data step) + top_k_vals = [ + len(logprobs[i]) + for i in range(len(logprobs)) + if logprobs[i] is not None and len(logprobs[i]) + ] + max_top_k = max(set(top_k_vals), key=top_k_vals.count) + min_top_k = min(set(top_k_vals), key=top_k_vals.count) + top_k = min(max_top_k, min_top_k) + if top_k == 0: + raise ValueError("No non-zero top-k logprobs found.") + + target_logprobs = [] + target_token_ids = [] + target_mask = [] + + if input_padding_len < 0: + # logprobs is longer than target_seq_len, + # so we need to slice from the left/beginning of logprobs + logprobs = logprobs[:-input_seq_len] + input_padding_len = 0 + # target_seq_len = input_seq_len + + # truncate the second dimension of the logprobs to top_k + logprobs = [row[:top_k] for row in logprobs] + + # fill with -inf for padding_len tokens for top_k tokens + # extend target_logprobs with a padding_len x top_k 2D list filled with -inf + + # we shift for causal models in the trainer, so start the range from 0 + for _ in range(0, input_padding_len): + target_logprobs.append([-float("inf")] * top_k) + target_token_ids.append(list(range(top_k))) + target_mask.append([0] * top_k) + + for position in range(input_padding_len, input_seq_len): + if sample["labels"][position] == -100: + target_mask.append([0] * top_k) + else: + target_mask.append([1] * top_k) + + for token_pos_logprobs, pos_target_token_ids in zip( + logprobs, sample["target_token_ids"] + ): + # Convert to a tensor for easier manipulation + position_logprobs_tensor = torch.tensor( + token_pos_logprobs, dtype=torch.float + ) + + # Now we have distribution at T1 in log form, i.e. log p_{T1}(k). + # Next, re-scale to T2 = self.kd_temperature via exponent-based trick + # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z + # + # Convert from log to probability + teacher_probs_t1 = position_logprobs_tensor.exp() + # normalize probabilities to sum to 1 in case they aren't already + teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True) + if teacher_probs_t1_sum > 1e-9: + teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum + if self.kd_temperature != self.gen_temperature: + # Exponentiate by factor (T1 / T2) + exponent = self.gen_temperature / self.kd_temperature + teacher_probs_t2 = teacher_probs_t1**exponent + else: + teacher_probs_t2 = teacher_probs_t1 + # Re-normalize + teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( + dim=0, keepdim=True + ) + # Convert back to log + position_logprobs_tensor = torch.log(teacher_probs_t2) + + # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor + position_logprobs_scaled = position_logprobs_tensor.tolist() + + target_logprobs.append(position_logprobs_scaled) + target_token_ids.append(pos_target_token_ids) + + # Update sample with transformed logprobs + sample["target_logprobs"] = target_logprobs + sample["target_token_ids"] = target_token_ids + sample["target_mask"] = target_mask + + return sample + + def _tokenize_single_prompt(self, prompt): + logprobs = prompt.pop(self.logprobs_field) + target_token_ids = prompt.pop("target_token_ids") + tokenized_prompt = super()._tokenize_single_prompt(prompt) + tokenized_prompt[self.logprobs_field] = logprobs + tokenized_prompt["target_token_ids"] = target_token_ids + tokenized_prompt = self.transform_logprobs(tokenized_prompt) + + return tokenized_prompt + + class KDStrategyLoader(StrategyLoader): """ Load ChatTemplateStrategy with KD support using StrategyLoader. """ def _get_strategy_cls(self): - return ChatTemplateStrategyWithKD + return ChatTemplateStrategyWithKDv2 def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]): strategy_params = super()._get_strategy_params(cfg, ds_cfg) diff --git a/src/axolotl/integrations/kd/collator.py b/src/axolotl/integrations/kd/collator.py index 8089b5b25..0cc745b78 100644 --- a/src/axolotl/integrations/kd/collator.py +++ b/src/axolotl/integrations/kd/collator.py @@ -56,6 +56,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): return_tensors = self.return_tensors padding_side = self.tokenizer.padding_side + max_len = 0 # Pad labels and position_ids first for feature_name, pad_token_id in [ @@ -106,7 +107,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): target_mask_list.append(f.pop("target_mask")) # Determine max lengths - max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list) + max_teacher_seq_len = max_len or max( + len(seq) for seq in target_logprobs_list + ) max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq) padded_target_logprobs = [] diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py new file mode 100644 index 000000000..d42716432 --- /dev/null +++ b/src/axolotl/integrations/kd/kernels/liger.py @@ -0,0 +1,576 @@ +""" +Liger Kernels for Chunked Top-K Log-Prob Distillation +""" + +from typing import Optional, Union, Unpack + +import torch +import torch.nn.functional as F + +# Assuming LigerFusedLinearDistillationBase is in this path and can be imported +# If not, its structure would need to be replicated or specific utilities copied. +from liger_kernel.chunked_loss.fused_linear_distillation import ( + LigerFusedLinearDistillationBase, +) +from transformers import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.qwen2.modeling_qwen2 import KwargsForCausalLM + + +class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): + + @staticmethod + def distillation_loss_fn( + student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled + target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k] + target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs + target_mask_chunk: torch.Tensor, # [chunk_size, top_k] + temperature: float, + ) -> torch.Tensor: + """ + Compute Top-K KL divergence loss for a chunk. + Args: + student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V). + target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K). + target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K). + target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K). + temperature: Temperature used for scaling. + Returns: + Sum of KL divergence losses for the chunk. + """ + student_logits_temp_scaled = student_logits_temp_scaled.float() + target_logprobs_chunk = target_logprobs_chunk.float() + + # Gather student logits for the top-k teacher token IDs + # student_logits_temp_scaled: [chunk_size, vocab_size] + # target_token_ids_chunk: [chunk_size, top_k] + # student_logits_topk_temp_scaled: [chunk_size, top_k] + student_logits_topk_temp_scaled = torch.gather( + student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk + ) + + # Student log-probabilities for the gathered top-k tokens + student_lse = torch.logsumexp( + student_logits_temp_scaled, dim=-1, keepdim=True + ) # [chunk_size, 1] + student_logprobs_topk_temp_scaled = ( + student_logits_topk_temp_scaled - student_lse + ) + + valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k] + + student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask] + target_logprobs_valid = target_logprobs_chunk[valid_mask] + + # Teacher probabilities P(y|x_teacher) from logprobs + # target_logprobs_valid are already normalized (log(softmax(teacher_logits/T))) + teacher_probs_valid = target_logprobs_valid.exp() + + # KL divergence: sum(P_teacher * (log P_teacher - log P_student)) + # = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student) + # The distillation loss is often formulated as -sum(P_teacher * log P_student) + # or as sum(P_teacher * (log_softmax_teacher - log_softmax_student)) + # Here, target_logprobs_valid are log_softmax_teacher. + # student_logprobs_topk_valid are log_softmax_student (for the selected K indices). + kd_loss_per_token = teacher_probs_valid * ( + target_logprobs_valid - student_logprobs_topk_valid + ) + kd_loss = kd_loss_per_token.sum() + + if temperature != 1.0: + kd_loss = kd_loss * (temperature**2) + + return kd_loss + + @staticmethod + def _compute_loss_kl_topk( + student_input_chunk: torch.Tensor, + student_weight: torch.Tensor, + # Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value + # or through `partial`. Let's make them explicit here for clarity. + target_token_ids_chunk: torch.Tensor, + target_logprobs_chunk: torch.Tensor, + target_mask_chunk: torch.Tensor, + target_chunk: torch.Tensor, # For hard loss (true labels) + student_bias: torch.Tensor = None, # This will be one of the grad targets + # Other params passed via `partial` from `forward` + distillation_loss_fn=None, + ignore_index: int = -100, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + compute_ce_loss: bool = True, + temperature: float = 1.0, + ): + # Compute student logits for the chunk from hidden states and LM head + # student_input_chunk: [chunk_size, hidden_dim] + # student_lm_head_weight: [vocab_size, hidden_dim] + # student_logits_chunk: [chunk_size, vocab_size] + student_logits_chunk = F.linear( + student_input_chunk, student_weight, student_bias + ) + + ce_loss = torch.tensor( + 0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype + ) + if compute_ce_loss and weight_hard_loss > 0.0: + ce_loss = F.cross_entropy( + student_logits_chunk.view(-1, student_logits_chunk.shape[-1]), + target_chunk.view(-1), + reduction="sum", + ignore_index=ignore_index, + ) + + soft_loss = torch.tensor( + 0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype + ) + if weight_soft_loss > 0.0: + student_logits_chunk_temp_scaled = student_logits_chunk / temperature + + # Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max() + # No explicit padding here; user must ensure vocab alignment or pre-pad student_weight. + + soft_loss = distillation_loss_fn( + student_logits_chunk_temp_scaled, + target_token_ids_chunk, + target_logprobs_chunk, + target_mask_chunk, + temperature=temperature, + ) + + loss = weight_hard_loss * ce_loss + weight_soft_loss * soft_loss + + # return loss, (soft_loss, ce_loss, student_logits_chunk) # Aux outputs + return loss, (soft_loss, ce_loss) # Aux outputs + + @classmethod + def forward( + cls, + ctx, + student_input: torch.Tensor, # [batch_size, seq_len, dim] + student_lm_head_weight: torch.Tensor, # [dim, vocab_size] + target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k] + target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k] + target_mask: torch.Tensor, # [batch_size, seq_len, top_k] + true_labels: torch.Tensor, # [batch_size, seq_len] + student_lm_head_bias: torch.Tensor = None, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + compiled: bool = False, + chunk_size: int = 1024, + compute_ce_loss: bool = True, + ): + CHUNK_SIZE = chunk_size + grad_weight_acc = torch.zeros_like(student_lm_head_weight) + grad_inputs_list = [] + grad_bias_acc = ( + torch.zeros_like(student_lm_head_bias) + if student_lm_head_bias is not None + else None + ) + loss_acc = torch.zeros( + (), device=student_input.device, dtype=student_input.dtype + ) + + # This function will be what torch.func.grad_and_value differentiates. + # It takes student_input_chunk, student_weight (full), student_bias (full) as primals. + # Other necessary data (target_*, etc.) are passed as non-differentiable arguments. + def loss_fn_for_grad( + _student_input_chunk, + _student_lm_head_weight, # full weight + _student_lm_head_bias, # full bias + # Fixed arguments for a given chunk, not differentiated: + _target_token_ids_chunk, + _target_logprobs_chunk, + _target_mask_chunk, + _true_labels_chunk, + ): + return cls._compute_loss_kl_topk( + student_input_chunk=_student_input_chunk, + student_weight=_student_lm_head_weight, + target_token_ids_chunk=_target_token_ids_chunk, + target_logprobs_chunk=_target_logprobs_chunk, + target_mask_chunk=_target_mask_chunk, + target_chunk=_true_labels_chunk, + student_bias=_student_lm_head_bias, + distillation_loss_fn=cls.distillation_loss_fn, + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + compute_ce_loss=compute_ce_loss, + temperature=temperature, + ) + + def accumulate_chunk_grads( + student_input_chunk_ac, + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ): + # student_weight and student_bias are closed over from the outer scope (full tensors) + if student_lm_head_bias is not None: + ( + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), + (chunk_loss, _aux_outputs), + ) = torch.func.grad_and_value( + loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True + )( + student_input_chunk_ac, + student_lm_head_weight, + student_lm_head_bias, # primals + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ) # non-primals + grad_bias_acc.add_(chunk_grad_bias) + else: + argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight + ( + (chunk_grad_input, chunk_grad_weight), # No grad for bias + (chunk_loss, _aux_outputs), + ) = torch.func.grad_and_value( + loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True + )( + student_input_chunk_ac, + student_lm_head_weight, + None, # Pass None for student_bias primal + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ) + + grad_weight_acc.add_(chunk_grad_weight) + loss_acc.add_(chunk_loss) + return chunk_grad_input + + if compiled: + accumulate_chunk_grads_compiled = torch.compile( + accumulate_chunk_grads, dynamic=True, backend="inductor" + ) # dynamic=True often helpful + else: + accumulate_chunk_grads_compiled = accumulate_chunk_grads + + # Use the same chunking logic as LigerFusedLinearDistillationBase.forward + B, N, D = student_input.shape + K = target_token_ids.shape[-1] + + print("student_input shape: " + str(student_input.shape)) + print("target_token_ids shape: " + str(target_token_ids.shape)) + print("target_logprobs shape: " + str(target_logprobs.shape)) + print("target_mask shape: " + str(target_mask.shape)) + print("true_labels shape: " + str(true_labels.shape)) + student_input_flat = student_input.reshape(-1, student_input.shape[-1]) + target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1]) + target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1]) + target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1]) + true_labels_flat = true_labels.reshape(-1) + print("student_input_flat shape: " + str(student_input_flat.shape)) + print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape)) + print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape)) + print("target_mask_flat shape: " + str(target_mask_flat.shape)) + print("true_labels_flat shape: " + str(true_labels_flat.shape)) + num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE) + + _student_input_chunks = torch.chunk( + student_input_flat, chunks=num_chunks, dim=0 + ) + _target_token_ids_chunks = torch.chunk( + target_token_ids_flat, chunks=num_chunks, dim=0 + ) + _target_logprobs_chunks = torch.chunk( + target_logprobs_flat, chunks=num_chunks, dim=0 + ) + _target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0) + _true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0) + + for i in range(num_chunks): + grad_input_chunk = accumulate_chunk_grads_compiled( + _student_input_chunks[i], + _target_token_ids_chunks[i], + _target_logprobs_chunks[i], + _target_mask_chunks[i], + _true_labels_chunks[i], + ) + grad_inputs_list.append(grad_input_chunk) + + grad_inputs_combined = torch.cat(grad_inputs_list, dim=0) + ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc) + + # For matching None returns in backward for non-tensor/non-grad_requiring inputs + ctx.hyperparams_count = 7 # Corresponds to number of hyperparams after main tensors in fwd signature + ctx.bias_was_none = student_lm_head_bias is None + ctx.orig_dims = (B, N, D, K) + + return loss_acc / (true_labels_flat != ignore_index).sum() + + @staticmethod + def backward(ctx, grad_output): + grad_input_flat, grad_weight, grad_bias_maybe = ( + ctx.saved_tensors + ) # grad_input_flat is (B*N, D) + + # Scale gradients by grad_output if it's not 1.0 + if not torch.equal( + grad_output, + torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype), + ): + grad_input_flat = grad_input_flat * grad_output + grad_weight = grad_weight * grad_output + if grad_bias_maybe is not None: + grad_bias_maybe = grad_bias_maybe * grad_output + + # Reshape grad_input_flat to match original student_input shape (B, N, D) + # ctx.orig_dims stores (B, N, D, K) + # We need the first three dimensions for student_input's shape. + # Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors + if ( + ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0 + and grad_input_flat.numel() == 0 + ): + # If original input was empty, gradient should also be empty with correct shape + grad_input_reshaped = torch.zeros( + ctx.orig_dims[0], + ctx.orig_dims[1], + ctx.orig_dims[2], + dtype=grad_input_flat.dtype, + device=grad_input_flat.device, + ) + elif grad_input_flat.numel() == 0 and not ( + ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0 + ): + # This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad) + # but as a safeguard: + grad_input_reshaped = torch.zeros( + ctx.orig_dims[0], + ctx.orig_dims[1], + ctx.orig_dims[2], + dtype=grad_input_flat.dtype, + device=grad_input_flat.device, + ) + else: + grad_input_reshaped = grad_input_flat.view( + ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2] + ) + + nones_for_hyperparams = [None] * ctx.hyperparams_count + grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None + + return ( + grad_input_reshaped, # Gradient for student_input (reshaped) + grad_weight, # Gradient for student_lm_head_weight + None, # Gradient for target_token_ids + None, # Gradient for target_logprobs + None, # Gradient for target_mask + None, # Gradient for true_labels + grad_bias_return, # Gradient for student_lm_head_bias + *nones_for_hyperparams, # Grads for weight_hard_loss, ..., compute_ce_loss + ) + + +class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module): + def __init__( + self, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + temperature: float = 1.0, # This is the kd_temperature + ignore_index: int = -100, + compiled: bool = True, + chunk_size: int = 1024, + compute_ce_loss: bool = True, + ): + super().__init__() + if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0): + raise ValueError("Loss weights must be between 0.0 and 1.0.") + if temperature <= 0: + raise ValueError("Temperature must be positive.") + + self.weight_hard_loss = weight_hard_loss + self.weight_soft_loss = weight_soft_loss + self.temperature = temperature + self.ignore_index = ignore_index + self.compiled = compiled + self.chunk_size = chunk_size + self.compute_ce_loss = compute_ce_loss + + if not self.compute_ce_loss and self.weight_hard_loss > 0.0: + print( + f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero." + ) + # self.weight_hard_loss = 0.0 # Or let user manage this + if self.weight_soft_loss == 0.0: + print( + "Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed." + ) + + def forward( + self, + lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head + student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head + target_token_ids: torch.Tensor, + target_logprobs: torch.Tensor, + target_mask: torch.Tensor, + true_labels: torch.Tensor, + student_bias: torch.Tensor = None, + ) -> torch.Tensor: + print(student_hidden_states.shape) + # Input validation + # if student_hidden_states.ndim != 2 or lm_head_weight.ndim != 2: + # raise ValueError("student_input and student_weight must be 2D tensors.") + # if student_hidden_states.shape[1] != lm_head_weight.shape[1]: + # raise ValueError("Hidden dimension mismatch between student_input and student_weight.") + # if student_bias is not None and (student_bias.ndim != 1 or student_bias.shape[0] != lm_head_weight.shape[0]): + # raise ValueError("student_bias shape mismatch.") + + if self.weight_soft_loss > 0.0: + expected_len = student_hidden_states.shape[0] + for name, tensor in [ + ("target_token_ids", target_token_ids), + ("target_logprobs", target_logprobs), + ("target_mask", target_mask), + ]: + # if tensor.ndim != 2: + # raise ValueError(f"{name} must be a 2D tensor.") + if tensor.shape[0] != expected_len: + raise ValueError( + f"Length mismatch: student_input ({expected_len}) vs {name} ({tensor.shape[0]})." + ) + if not ( + target_token_ids.shape[1] + == target_logprobs.shape[1] + == target_mask.shape[1] + ): + raise ValueError( + "Top-k dimension mismatch among target_token_ids, target_logprobs, target_mask." + ) + if target_token_ids.max() >= lm_head_weight.shape[0]: + raise ValueError( + f"target_token_ids contains indices ({target_token_ids.max().item()}) " + f"out of bounds for student vocabulary size ({lm_head_weight.shape[0]})." + ) + + # if self.compute_ce_loss and self.weight_hard_loss > 0.0: + # if true_labels.ndim != 1 or true_labels.shape[0] != student_hidden_states.shape[0]: + # raise ValueError("true_labels shape mismatch or incorrect dimensions.") + + return LigerFusedLinearKLTopKLogprobFunction.apply( + student_hidden_states, + lm_head_weight, + target_token_ids, + target_logprobs, + target_mask, + true_labels, + student_bias, + self.weight_hard_loss, + self.weight_soft_loss, + self.ignore_index, + self.temperature, + self.compiled, + self.chunk_size, + self.compute_ce_loss, + ) + + +# class LigerFusedTopKKDLossFunction(torch.autograd.Function): +# def forward( +# self, +# student_input: torch.Tensor, +# student_weight: torch.Tensor, +# teacher_input: torch.Tensor, # teacher logprobs +# teacher_token_ids: torch.Tensor, +# teacher_mask: torch.Tensor, +# hard_labels: torch.LongTensor, +# ): +# return LigerFusedLinearKLTopKLogprobFunction.apply( +# student_input, +# student_weight, +# teacher_token_ids, +# teacher_input, # teacher logprobs +# teacher_mask, +# hard_labels, +# None, +# self.weight_hard_loss, +# self.weight_soft_loss, +# self.ignore_index, +# self.temperature, +# self.compiled, +# self.chunk_size, +# self.compute_ce_loss, +# ) + + +def kldiv_forward_qwen2( + self, + input_ids: Optional[torch.LongTensor] = None, + target_logprobs: Optional[torch.Tensor] = None, + target_token_ids: Optional[torch.LongTensor] = None, + target_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], +) -> CausalLMOutputWithPast: + # pylint: disable=duplicate-code + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100 + # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss + + loss = self.loss_function( + self.lm_head.weight, + hidden_states, + target_token_ids, + target_logprobs, + target_mask, + true_labels=labels, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=None, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def apply_kernel_to_qwen2(): + from transformers.models.qwen2 import modeling_qwen2 + + modeling_qwen2.Qwen2ForCausalLM.forward = kldiv_forward_qwen2 diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 9b6316784..42d4c1d6b 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -18,7 +18,7 @@ KD trainer from axolotl.core.trainers.base import AxolotlTrainer -from .topk_logprob.forward_kl import ChunkedTopKKDLoss +from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss class AxolotlKDTrainer(AxolotlTrainer): @@ -29,8 +29,11 @@ class AxolotlKDTrainer(AxolotlTrainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.model_accepts_loss_kwargs = True - self.kd_loss_fn = ChunkedTopKKDLoss( - num_output_chunks=8, kd_temperature=self.args.kd_temperature + self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss( + self.args.kd_ce_alpha, # hard label loss + self.args.kd_alpha, # kd loss + self.args.kd_temperature, + compute_ce_loss=bool(self.args.kd_ce_alpha), ) def _set_signature_columns_if_needed(self): @@ -59,11 +62,11 @@ class AxolotlKDTrainer(AxolotlTrainer): Subclass and override for custom behavior. """ - target_logprobs = inputs.pop("target_logprobs") - target_token_ids = inputs.pop("target_token_ids") - target_mask = inputs.pop("target_mask") + # target_logprobs = inputs.pop("target_logprobs") + # target_token_ids = inputs.pop("target_token_ids") + # target_mask = inputs.pop("target_mask") - seq_len = target_token_ids.shape[1] + # seq_len = target_token_ids.shape[1] if self.model_accepts_loss_kwargs: loss_kwargs = {} @@ -71,36 +74,37 @@ class AxolotlKDTrainer(AxolotlTrainer): loss_kwargs["num_items_in_batch"] = num_items_in_batch inputs = {**inputs, **loss_kwargs} outputs = model(**inputs) + return outputs[0] + # + # # FIXME: account for tokenizer.padding_side + # student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous() + # + # shift_logits = student_logits.contiguous() + # target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() + # target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() + # target_mask_for_loss = target_mask[..., 1:, :].contiguous() + # + # loss_kd = self.kd_loss_fn( + # shift_logits, + # target_token_ids_for_loss, + # target_logprobs_for_loss, + # target_mask_for_loss, + # num_items_in_batch=num_items_in_batch, + # ) + # + # if self.args.kd_ce_alpha > 0: + # kd_alpha = self.args.kd_alpha + # loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd + # else: + # loss = loss_kd + # # Save past state if it exists + # # TODO: this needs to be fixed and made cleaner later. + # if self.args.past_index >= 0: + # self._past = outputs[ # pylint: disable=attribute-defined-outside-init + # self.args.past_index + # ] + # + # if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + # loss *= self.accelerator.num_processes - # FIXME: account for tokenizer.padding_side - student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous() - - shift_logits = student_logits.contiguous() - target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() - target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() - target_mask_for_loss = target_mask[..., 1:, :].contiguous() - - loss_kd = self.kd_loss_fn( - shift_logits, - target_token_ids_for_loss, - target_logprobs_for_loss, - target_mask_for_loss, - num_items_in_batch=num_items_in_batch, - ) - - if self.args.kd_ce_alpha > 0: - kd_alpha = self.args.kd_alpha - loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd - else: - loss = loss_kd - # Save past state if it exists - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[ # pylint: disable=attribute-defined-outside-init - self.args.past_index - ] - - if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: - loss *= self.accelerator.num_processes - - return (loss, outputs) if return_outputs else loss + # return (loss, outputs) if return_outputs else loss