From 28eb8632a198a1e00727ec94a6916ef36c650544 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 22 May 2025 07:58:59 -0400
Subject: [PATCH] more fixes and liger-type chunked loss

---
 src/axolotl/integrations/kd/__init__.py       |   4 +-
 src/axolotl/integrations/kd/kernels/liger.py  | 215 ++++--------------
 src/axolotl/integrations/kd/kernels/models.py | 109 +++++++++
 3 files changed, 161 insertions(+), 167 deletions(-)
 create mode 100644 src/axolotl/integrations/kd/kernels/models.py

diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py
index 017b9533d..00d73e036 100644
--- a/src/axolotl/integrations/kd/__init__.py
+++ b/src/axolotl/integrations/kd/__init__.py
@@ -36,6 +36,6 @@ class KDPlugin(BasePlugin):
         return None
 
     def pre_model_load(self, cfg):
-        from .kernels.liger import apply_kernel_to_qwen2
+        from .kernels.models import apply_kernel
 
-        apply_kernel_to_qwen2()
+        apply_kernel(cfg.model_config_type)

diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py
index d42716432..e4e80df14 100644
--- a/src/axolotl/integrations/kd/kernels/liger.py
+++ b/src/axolotl/integrations/kd/kernels/liger.py
@@ -2,22 +2,17 @@
 Liger Kernels for Chunked Top-K Log-Prob Distillation
 """
 
-from typing import Optional, Union, Unpack
-
 import torch
 import torch.nn.functional as F
-
-# Assuming LigerFusedLinearDistillationBase is in this path and can be imported
-# If not, its structure would need to be replicated or specific utilities copied.
 from liger_kernel.chunked_loss.fused_linear_distillation import (
     LigerFusedLinearDistillationBase,
 )
-from transformers import Cache
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import KwargsForCausalLM
 
 
 class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
+    """
+    Chunked KL-divergence loss over the teacher's top-k log-probs.
+    """
 
     @staticmethod
     def distillation_loss_fn(
@@ -161,7 +156,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         chunk_size: int = 1024,
         compute_ce_loss: bool = True,
     ):
-        CHUNK_SIZE = chunk_size
+        CHUNK_SIZE = chunk_size  # pylint: disable=invalid-name
         grad_weight_acc = torch.zeros_like(student_lm_head_weight)
         grad_inputs_list = []
         grad_bias_acc = (
@@ -255,24 +250,36 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             accumulate_chunk_grads_compiled = accumulate_chunk_grads
 
         # Use the same chunking logic as LigerFusedLinearDistillationBase.forward
-        B, N, D = student_input.shape
-        K = target_token_ids.shape[-1]
-        print("student_input shape: " + str(student_input.shape))
-        print("target_token_ids shape: " + str(target_token_ids.shape))
-        print("target_logprobs shape: " + str(target_logprobs.shape))
-        print("target_mask shape: " + str(target_mask.shape))
-        print("true_labels shape: " + str(true_labels.shape))
+        B, N, D = student_input.shape  # pylint: disable=invalid-name
+        K = target_token_ids.shape[-1]  # pylint: disable=invalid-name
+
+        # print("student_input shape: " + str(student_input.shape))
+        # print("target_token_ids shape: " + str(target_token_ids.shape))
+        # print("target_logprobs shape: " + str(target_logprobs.shape))
+        # print("target_mask shape: " + str(target_mask.shape))
+        # print("true_labels shape: " + str(true_labels.shape))
 
         student_input_flat = student_input.reshape(-1, student_input.shape[-1])
         target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
         target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
         target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
-        true_labels_flat = true_labels.reshape(-1)
-        print("student_input_flat shape: " + str(student_input_flat.shape))
-        print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape))
-        print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape))
-        print("target_mask_flat shape: " + str(target_mask_flat.shape))
-        print("true_labels_flat shape: " + str(true_labels_flat.shape))
+        # pad and shift for cross entropy loss
+        true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
+        true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
+        # true_labels_flat = true_labels.reshape(-1)
+
+        # student_input_flat = student_input[:, :-1, :].contiguous().view(-1, student_input.shape[-1])
+        # target_token_ids_flat = target_token_ids[:, 1:, :].contiguous().view(-1, target_token_ids.shape[-1])
+        # target_logprobs_flat = target_logprobs[:, 1:, :].contiguous().view(-1, target_logprobs.shape[-1])
+        # target_mask_flat = target_mask[:, 1:, :].contiguous().view(-1, target_mask.shape[-1])
+        # true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
+        # N = N - 1
+
+        # print("student_input_flat shape: " + str(student_input_flat.shape))
+        # print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape))
+        # print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape))
+        # print("target_mask_flat shape: " + str(target_mask_flat.shape))
+        # print("true_labels_flat shape: " + str(true_labels_flat.shape))
 
         num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
         _student_input_chunks = torch.chunk(
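The pad-and-shift added in the hunk above aligns the hard labels with the student hidden states for next-token prediction: position t is supervised by the token at position t + 1, and the final position is masked out. A minimal standalone sketch of the same indexing with toy values (editorial illustration, not part of the patch):

    import torch

    ignore_index = -100
    labels = torch.tensor([[7, 8, 9]])  # (B=1, N=3)
    # pad on the right so every position still has a "next" entry
    padded = torch.nn.functional.pad(labels, (0, 1), value=ignore_index)  # [[7, 8, 9, -100]]
    shifted = padded[:, 1:].contiguous().view(-1)  # tensor([8, 9, -100])
    # the flattened labels keep length B*N, matching the flattened hidden
    # states, and the final position contributes nothing to the CE loss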
@@ -298,6 +305,8 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             grad_inputs_list.append(grad_input_chunk)
 
         grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
+        # print("grad_inputs_combined")
+        # print(grad_inputs_combined.shape)
 
         ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
         # For matching None returns in backward for non-tensor/non-grad_requiring inputs
@@ -305,13 +314,27 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         ctx.bias_was_none = student_lm_head_bias is None
         ctx.orig_dims = (B, N, D, K)
 
-        return loss_acc / (true_labels_flat != ignore_index).sum()
+        num_valid_tokens_scalar: float = (true_labels_flat != ignore_index).sum().item()
+        ctx.num_valid_tokens_scalar = num_valid_tokens_scalar
+        final_loss = loss_acc  # / ctx.num_valid_tokens_scalar
+
+        return final_loss
 
     @staticmethod
     def backward(ctx, grad_output):
         grad_input_flat, grad_weight, grad_bias_maybe = (
             ctx.saved_tensors
         )  # grad_input_flat is (B*N, D)
+        # print("grad_input_flat")
+        # print(grad_input_flat.shape)
+
+        # num_valid_tokens_scalar = ctx.num_valid_tokens_scalar
+        # normalizer = float(num_valid_tokens_scalar) if num_valid_tokens_scalar > 0 else 1.0
+
+        # grad_input_flat = grad_input_flat / normalizer
+        # grad_weight = grad_weight / normalizer
+        # if grad_bias_maybe is not None:
+        #     grad_bias_maybe = grad_bias_maybe / normalizer
 
         # Scale gradients by grad_output if it's not 1.0
         if not torch.equal(
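With the two hunks above, forward() no longer divides the accumulated loss by the number of valid tokens; it returns the raw sum and stashes the count on ctx, deferring normalization to the caller (the division by num_items_in_batch in models.py below). A toy sanity check of why sum-then-divide reproduces the old per-token mean (editorial illustration, not part of the patch):

    import torch

    per_token_loss = torch.tensor([0.5, 1.5, 0.0])  # 0.0 at a masked position
    valid = torch.tensor([True, True, False])
    loss_sum = per_token_loss.sum()        # what forward() now returns
    num_items_in_batch = int(valid.sum())  # what the HF Trainer supplies
    assert torch.isclose(loss_sum / num_items_in_batch, torch.tensor(1.0))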
@@ -372,6 +395,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
 
 
 class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
+    """
+    Wrapper module for the chunked top-k log-prob KL-divergence loss.
+    """
+
     def __init__(
         self,
         weight_hard_loss: float = 0.5,
@@ -416,46 +443,6 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
         true_labels: torch.Tensor,
         student_bias: torch.Tensor = None,
     ) -> torch.Tensor:
-        print(student_hidden_states.shape)
-        # Input validation
-        # if student_hidden_states.ndim != 2 or lm_head_weight.ndim != 2:
-        #     raise ValueError("student_input and student_weight must be 2D tensors.")
-        # if student_hidden_states.shape[1] != lm_head_weight.shape[1]:
-        #     raise ValueError("Hidden dimension mismatch between student_input and student_weight.")
-        # if student_bias is not None and (student_bias.ndim != 1 or student_bias.shape[0] != lm_head_weight.shape[0]):
-        #     raise ValueError("student_bias shape mismatch.")
-
-        if self.weight_soft_loss > 0.0:
-            expected_len = student_hidden_states.shape[0]
-            for name, tensor in [
-                ("target_token_ids", target_token_ids),
-                ("target_logprobs", target_logprobs),
-                ("target_mask", target_mask),
-            ]:
-                # if tensor.ndim != 2:
-                #     raise ValueError(f"{name} must be a 2D tensor.")
-                if tensor.shape[0] != expected_len:
-                    raise ValueError(
-                        f"Length mismatch: student_input ({expected_len}) vs {name} ({tensor.shape[0]})."
-                    )
-            if not (
-                target_token_ids.shape[1]
-                == target_logprobs.shape[1]
-                == target_mask.shape[1]
-            ):
-                raise ValueError(
-                    "Top-k dimension mismatch among target_token_ids, target_logprobs, target_mask."
-                )
-            if target_token_ids.max() >= lm_head_weight.shape[0]:
-                raise ValueError(
-                    f"target_token_ids contains indices ({target_token_ids.max().item()}) "
-                    f"out of bounds for student vocabulary size ({lm_head_weight.shape[0]})."
-                )
-
-        # if self.compute_ce_loss and self.weight_hard_loss > 0.0:
-        #     if true_labels.ndim != 1 or true_labels.shape[0] != student_hidden_states.shape[0]:
-        #         raise ValueError("true_labels shape mismatch or incorrect dimensions.")
-
         return LigerFusedLinearKLTopKLogprobFunction.apply(
             student_hidden_states,
             lm_head_weight,
             target_token_ids,
             target_logprobs,
             target_mask,
             true_labels,
             student_bias,
             self.weight_hard_loss,
             self.weight_soft_loss,
             self.ignore_index,
             self.temperature,
             self.compiled,
             self.chunk_size,
             self.compute_ce_loss,
         )
-
-
-# class LigerFusedTopKKDLossFunction(torch.autograd.Function):
-#     def forward(
-#         self,
-#         student_input: torch.Tensor,
-#         student_weight: torch.Tensor,
-#         teacher_input: torch.Tensor,  # teacher logprobs
-#         teacher_token_ids: torch.Tensor,
-#         teacher_mask: torch.Tensor,
-#         hard_labels: torch.LongTensor,
-#     ):
-#         return LigerFusedLinearKLTopKLogprobFunction.apply(
-#             student_input,
-#             student_weight,
-#             teacher_token_ids,
-#             teacher_input,  # teacher logprobs
-#             teacher_mask,
-#             hard_labels,
-#             None,
-#             self.weight_hard_loss,
-#             self.weight_soft_loss,
-#             self.ignore_index,
-#             self.temperature,
-#             self.compiled,
-#             self.chunk_size,
-#             self.compute_ce_loss,
-#         )
-
-
-def kldiv_forward_qwen2(
-    self,
-    input_ids: Optional[torch.LongTensor] = None,
-    target_logprobs: Optional[torch.Tensor] = None,
-    target_token_ids: Optional[torch.LongTensor] = None,
-    target_mask: Optional[torch.Tensor] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Cache] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    logits_to_keep: Union[int, torch.Tensor] = 0,
-    **kwargs: Unpack[KwargsForCausalLM],
-) -> CausalLMOutputWithPast:
-    # pylint: disable=duplicate-code
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        cache_position=cache_position,
-        **kwargs,
-    )
-
-    hidden_states = outputs.last_hidden_state
-
-    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-    # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
-    # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
-
-    loss = self.loss_function(
-        self.lm_head.weight,
-        hidden_states,
-        target_token_ids,
-        target_logprobs,
-        target_mask,
-        true_labels=labels,
-    )
-
-    return CausalLMOutputWithPast(
-        loss=loss,
-        logits=None,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
-
-
-def apply_kernel_to_qwen2():
-    from transformers.models.qwen2 import modeling_qwen2
-
-    modeling_qwen2.Qwen2ForCausalLM.forward = kldiv_forward_qwen2
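For orientation, a minimal usage sketch of the wrapper after this change; the argument order mirrors the loss_function call in models.py below, while the shapes and loss weights are hypothetical (editorial illustration, not part of the patch):

    import torch
    from axolotl.integrations.kd.kernels.liger import LigerFusedLinearKLTopKLogprobLoss

    B, N, D, V, K = 2, 8, 16, 32, 4  # toy batch/sequence/hidden/vocab/top-k sizes
    loss_fn = LigerFusedLinearKLTopKLogprobLoss(weight_hard_loss=0.5, weight_soft_loss=0.5)
    loss = loss_fn(
        torch.randn(V, D, requires_grad=True),            # lm_head_weight
        torch.randn(B, N, D, requires_grad=True),         # student hidden states
        torch.randint(0, V, (B, N, K)),                   # teacher top-k token ids
        torch.log_softmax(torch.randn(B, N, K), dim=-1),  # teacher top-k log-probs
        torch.ones(B, N, K),                              # target mask
        true_labels=torch.randint(0, V, (B, N)),
    )
    loss.backward()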
diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py
new file mode 100644
index 000000000..9006d92b0
--- /dev/null
+++ b/src/axolotl/integrations/kd/kernels/models.py
@@ -0,0 +1,109 @@
+"""
+Model patcher for chunked top-k KL-divergence distillation.
+"""
+
+from typing import Optional, Union, Unpack
+
+import torch
+from transformers import Cache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.utils import LossKwargs
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
+    """
+    Placeholder kwargs for HF CausalLM model classes.
+    """
+
+
+def kldiv_forward_llama_like(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    target_logprobs: Optional[torch.Tensor] = None,
+    target_token_ids: Optional[torch.LongTensor] = None,
+    target_mask: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,  # pylint: disable=unused-argument
+    **kwargs: Unpack[KwargsForCausalLM],  # type: ignore[misc]
+) -> CausalLMOutputWithPast:
+    # pylint: disable=duplicate-code
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+
+    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs.last_hidden_state
+
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    # TODO: we can optimize this further by filtering hidden_states on the sequence dimension using labels != -100
+    # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
+
+    loss = self.loss_function(
+        self.lm_head.weight,
+        hidden_states,
+        target_token_ids,
+        target_logprobs,
+        target_mask,
+        true_labels=labels,
+    )
+    num_items_in_batch = kwargs.pop("num_items_in_batch", None)  # trainer-supplied token count
+    if num_items_in_batch is not None:
+        loss = loss / num_items_in_batch
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=None,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
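+
+# Editorial sketch, not part of the original patch: the batch contract for the
+# patched forward above. Alongside the usual input_ids / attention_mask /
+# labels, the collator must supply the teacher's top-k tensors, named to match
+# the keyword arguments of kldiv_forward_llama_like:
+#
+#     "target_token_ids": (B, N, K) teacher top-k token ids
+#     "target_logprobs":  (B, N, K) teacher top-k log-probs
+#     "target_mask":      (B, N, K) 1.0 where a top-k entry is valid
+#
+# model(**batch) then routes these tensors into self.loss_function(...).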
+
+
+def apply_kernel_to_qwen2():
+    from transformers.models.qwen2 import modeling_qwen2
+
+    modeling_qwen2.Qwen2ForCausalLM.forward = kldiv_forward_llama_like
+
+
+def apply_kernel_to_llama():
+    from transformers.models.llama import modeling_llama
+
+    modeling_llama.LlamaForCausalLM.forward = kldiv_forward_llama_like
+
+
+def apply_kernel(model_type):
+    # Dynamically import the modeling module and resolve its CausalLM class
+    module_path = f"transformers.models.{model_type}.modeling_{model_type}"
+    model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
+    module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
+    model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
+    model_cls.forward = kldiv_forward_llama_like
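How the patch is wired in at load time: KDPlugin.pre_model_load (first hunk above) calls apply_kernel with the model type from the axolotl config. A minimal sketch of the same call outside the plugin, assuming the patched class's loss_function has been set to LigerFusedLinearKLTopKLogprobLoss as the TODO above notes (editorial illustration, not part of the patch):

    from axolotl.integrations.kd.kernels.models import apply_kernel

    apply_kernel("qwen2")  # equivalent to apply_kernel(cfg.model_config_type)
    # transformers' Qwen2ForCausalLM.forward is now kldiv_forward_llama_like,
    # so any Qwen2 model loaded afterwards returns the chunked top-k
    # KL-divergence loss (and logits=None) from its forward pass.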