more fixes and liger-type chunked loss

Wing Lian
2025-05-22 07:58:59 -04:00
parent 5cfaac3767
commit 28eb8632a1
3 changed files with 161 additions and 167 deletions

View File

@@ -36,6 +36,6 @@ class KDPlugin(BasePlugin):
return None
def pre_model_load(self, cfg):
from .kernels.liger import apply_kernel_to_qwen2
from .kernels.models import apply_kernel
apply_kernel_to_qwen2()
apply_kernel(cfg.model_config_type)
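For context, a minimal usage sketch of this hook; the config stub below is hypothetical and not part of the commit, and assumes KDPlugin is importable:

from types import SimpleNamespace

# Hypothetical illustration: patch the configured architecture before the model loads.
cfg = SimpleNamespace(model_config_type="llama")  # e.g. "qwen2", "llama"
KDPlugin().pre_model_load(cfg)  # apply_kernel("llama") swaps in the fused-loss forward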

View File

@@ -2,22 +2,17 @@
Liger Kernels for Chunked Top-K Log-Prob Distillation
"""
from typing import Optional, Union, Unpack
import torch
import torch.nn.functional as F
# Assuming LigerFusedLinearDistillationBase is in this path and can be imported
# If not, its structure would need to be replicated or specific utilities copied.
from liger_kernel.chunked_loss.fused_linear_distillation import (
LigerFusedLinearDistillationBase,
)
from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen2.modeling_qwen2 import KwargsForCausalLM
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
"""
Chunked kl-div loss for top-k logprobs
"""
@staticmethod
def distillation_loss_fn(
@@ -161,7 +156,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
chunk_size: int = 1024,
compute_ce_loss: bool = True,
):
CHUNK_SIZE = chunk_size
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
grad_inputs_list = []
grad_bias_acc = (
@@ -255,24 +250,36 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
accumulate_chunk_grads_compiled = accumulate_chunk_grads
# Use the same chunking logic as LigerFusedLinearDistillationBase.forward
B, N, D = student_input.shape
K = target_token_ids.shape[-1]
B, N, D = student_input.shape # pylint: disable=invalid-name
K = target_token_ids.shape[-1] # pylint: disable=invalid-name
# print("student_input shape: " + str(student_input.shape))
# print("target_token_ids shape: " + str(target_token_ids.shape))
# print("target_logprobs shape: " + str(target_logprobs.shape))
# print("target_mask shape: " + str(target_mask.shape))
# print("true_labels shape: " + str(true_labels.shape))
print("student_input shape: " + str(student_input.shape))
print("target_token_ids shape: " + str(target_token_ids.shape))
print("target_logprobs shape: " + str(target_logprobs.shape))
print("target_mask shape: " + str(target_mask.shape))
print("true_labels shape: " + str(true_labels.shape))
student_input_flat = student_input.reshape(-1, student_input.shape[-1])
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
true_labels_flat = true_labels.reshape(-1)
print("student_input_flat shape: " + str(student_input_flat.shape))
print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape))
print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape))
print("target_mask_flat shape: " + str(target_mask_flat.shape))
print("true_labels_flat shape: " + str(true_labels_flat.shape))
# pad and shift for cross entropy loss
true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
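# Only the labels are shifted: position t's hidden state is scored against token t+1,
# while student_input and the teacher top-k tensors keep their original alignment
# (a fully shifted variant is kept commented out below).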
# true_labels_flat = true_labels.reshape(-1)
# student_input_flat = student_input[:, :-1, :].contiguous().view(-1, student_input.shape[-1])
# target_token_ids_flat = target_token_ids[:, 1:, :].contiguous().view(-1, target_token_ids.shape[-1])
# target_logprobs_flat = target_logprobs[:, 1:, :].contiguous().view(-1, target_logprobs.shape[-1])
# target_mask_flat = target_mask[:, 1:, :].contiguous().view(-1, target_mask.shape[-1])
# true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
# N = N - 1
# print("student_input_flat shape: " + str(student_input_flat.shape))
# print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape))
# print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape))
# print("target_mask_flat shape: " + str(target_mask_flat.shape))
# print("true_labels_flat shape: " + str(true_labels_flat.shape))
num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
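# torch.chunk splits the flattened tokens into num_chunks pieces of roughly CHUNK_SIZE
# rows each (ceil(total / num_chunks) per chunk, with a possibly smaller final chunk).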
_student_input_chunks = torch.chunk(
@@ -298,6 +305,8 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
grad_inputs_list.append(grad_input_chunk)
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
print("grad_inputs_combined")
print(grad_inputs_combined.shape)
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
# Record metadata so backward can return None for the non-tensor / non-grad-requiring inputs
@@ -305,13 +314,27 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
ctx.bias_was_none = student_lm_head_bias is None
ctx.orig_dims = (B, N, D, K)
return loss_acc / (true_labels_flat != ignore_index).sum()
num_valid_tokens_scalar: float = (true_labels_flat != ignore_index).sum().item()
ctx.num_valid_tokens_scalar = num_valid_tokens_scalar
final_loss = loss_acc # / ctx.num_valid_tokens_scalar
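# Per-token normalization is skipped here; the patched forward in the models module
# divides the summed loss by num_items_in_batch instead.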
return final_loss
@staticmethod
def backward(ctx, grad_output):
grad_input_flat, grad_weight, grad_bias_maybe = (
ctx.saved_tensors
) # grad_input_flat is (B*N, D)
print("grad_input_flat")
print(grad_input_flat.shape)
# num_valid_tokens_scalar = ctx.num_valid_tokens_scalar
# normalizer = float(num_valid_tokens_scalar) if num_valid_tokens_scalar > 0 else 1.0
# grad_input_flat = grad_input_flat / normalizer
# grad_weight = grad_weight / normalizer
# if grad_bias_maybe is not None:
# grad_bias_maybe = grad_bias_maybe / normalizer
# Scale gradients by grad_output if it's not 1.0
if not torch.equal(
@@ -372,6 +395,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
"""
wrapper for chunked top-k logprob kl-div
"""
def __init__(
self,
weight_hard_loss: float = 0.5,
@@ -416,46 +443,6 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
true_labels: torch.Tensor,
student_bias: torch.Tensor = None,
) -> torch.Tensor:
print(student_hidden_states.shape)
# Input validation
# if student_hidden_states.ndim != 2 or lm_head_weight.ndim != 2:
# raise ValueError("student_input and student_weight must be 2D tensors.")
# if student_hidden_states.shape[1] != lm_head_weight.shape[1]:
# raise ValueError("Hidden dimension mismatch between student_input and student_weight.")
# if student_bias is not None and (student_bias.ndim != 1 or student_bias.shape[0] != lm_head_weight.shape[0]):
# raise ValueError("student_bias shape mismatch.")
if self.weight_soft_loss > 0.0:
expected_len = student_hidden_states.shape[0]
for name, tensor in [
("target_token_ids", target_token_ids),
("target_logprobs", target_logprobs),
("target_mask", target_mask),
]:
# if tensor.ndim != 2:
# raise ValueError(f"{name} must be a 2D tensor.")
if tensor.shape[0] != expected_len:
raise ValueError(
f"Length mismatch: student_input ({expected_len}) vs {name} ({tensor.shape[0]})."
)
if not (
target_token_ids.shape[1]
== target_logprobs.shape[1]
== target_mask.shape[1]
):
raise ValueError(
"Top-k dimension mismatch among target_token_ids, target_logprobs, target_mask."
)
if target_token_ids.max() >= lm_head_weight.shape[0]:
raise ValueError(
f"target_token_ids contains indices ({target_token_ids.max().item()}) "
f"out of bounds for student vocabulary size ({lm_head_weight.shape[0]})."
)
# if self.compute_ce_loss and self.weight_hard_loss > 0.0:
# if true_labels.ndim != 1 or true_labels.shape[0] != student_hidden_states.shape[0]:
# raise ValueError("true_labels shape mismatch or incorrect dimensions.")
return LigerFusedLinearKLTopKLogprobFunction.apply(
student_hidden_states,
lm_head_weight,
@@ -472,105 +459,3 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
self.chunk_size,
self.compute_ce_loss,
)
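A minimal usage sketch of the loss module follows; the shapes and any constructor arguments beyond weight_hard_loss are inferred from the attributes referenced above, not spelled out in this diff:

# Hypothetical example. B = batch, N = sequence length, D = hidden size, K = top-k, V = vocab.
loss_fn = LigerFusedLinearKLTopKLogprobLoss(weight_hard_loss=0.5)
loss = loss_fn(
    model.lm_head.weight,  # (V, D) student LM head weight
    hidden_states,         # (B, N, D) student hidden states
    target_token_ids,      # (B, N, K) teacher top-k token ids
    target_logprobs,       # (B, N, K) teacher top-k log-probs
    target_mask,           # (B, N, K) validity mask over the top-k entries
    true_labels=labels,    # (B, N) hard labels for the CE term
)
loss.backward()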
# class LigerFusedTopKKDLossFunction(torch.autograd.Function):
# def forward(
# self,
# student_input: torch.Tensor,
# student_weight: torch.Tensor,
# teacher_input: torch.Tensor, # teacher logprobs
# teacher_token_ids: torch.Tensor,
# teacher_mask: torch.Tensor,
# hard_labels: torch.LongTensor,
# ):
# return LigerFusedLinearKLTopKLogprobFunction.apply(
# student_input,
# student_weight,
# teacher_token_ids,
# teacher_input, # teacher logprobs
# teacher_mask,
# hard_labels,
# None,
# self.weight_hard_loss,
# self.weight_soft_loss,
# self.ignore_index,
# self.temperature,
# self.compiled,
# self.chunk_size,
# self.compute_ce_loss,
# )
def kldiv_forward_qwen2(
self,
input_ids: Optional[torch.LongTensor] = None,
target_logprobs: Optional[torch.Tensor] = None,
target_token_ids: Optional[torch.LongTensor] = None,
target_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
loss = self.loss_function(
self.lm_head.weight,
hidden_states,
target_token_ids,
target_logprobs,
target_mask,
true_labels=labels,
)
return CausalLMOutputWithPast(
loss=loss,
logits=None,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_kernel_to_qwen2():
from transformers.models.qwen2 import modeling_qwen2
modeling_qwen2.Qwen2ForCausalLM.forward = kldiv_forward_qwen2

View File

@@ -0,0 +1,109 @@
"""
model patcher for chunked top-k kl-div
"""
from typing import Optional, Union, Unpack
import torch
from transformers import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import LossKwargs
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
"""
placeholder kwargs for hf model classes
"""
def kldiv_forward_llama_like(
self,
input_ids: Optional[torch.LongTensor] = None,
target_logprobs: Optional[torch.Tensor] = None,
target_token_ids: Optional[torch.LongTensor] = None,
target_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
**kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc]
) -> CausalLMOutputWithPast:
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Logits are never materialized here; the fused loss consumes hidden_states and lm_head.weight directly
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
loss = self.loss_function(
self.lm_head.weight,
hidden_states,
target_token_ids,
target_logprobs,
target_mask,
true_labels=labels,
)
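# num_items_in_batch is typically supplied by the HF Trainer (e.g. under gradient
# accumulation); since the fused loss above returns a sum, divide by it when provided.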
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
if num_items_in_batch is not None:
loss = loss / num_items_in_batch
return CausalLMOutputWithPast(
loss=loss,
logits=None,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_kernel_to_qwen2():
from transformers.models.qwen2 import modeling_qwen2
modeling_qwen2.Qwen2ForCausalLM.forward = kldiv_forward_llama_like
def apply_kernel_to_llama():
from transformers.models.llama import modeling_llama
modeling_llama.LlamaForCausalLM.forward = kldiv_forward_llama_like
def apply_kernel(model_type):
# Dynamically import the modeling module and resolve its ForCausalLM class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = kldiv_forward_llama_like
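The class name is derived purely from the model type string. A small illustration of that mapping (the helper below is hypothetical, not part of the commit):

def _cls_prefix(model_type: str) -> str:
    return "".join(part.capitalize() for part in model_type.split("_"))

print(_cls_prefix("qwen2"))     # Qwen2    -> Qwen2ForCausalLM
print(_cls_prefix("llama"))     # Llama    -> LlamaForCausalLM
print(_cls_prefix("gpt_neox"))  # GptNeox  -> GptNeoxForCausalLM, which does not match
# transformers' GPTNeoXForCausalLM, so model types with irregular casing would still
# need a dedicated apply_kernel_to_* style patch.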