more fixes and liger-type chunked loss
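For orientation, the "liger-type chunked loss" here follows the Liger-kernel pattern of never materializing the full (batch * seq_len, vocab) student logits: hidden states are pushed through the LM head one chunk at a time and the distillation loss is accumulated chunk by chunk. The sketch below is illustrative only; the function and tensor names are invented, and it omits the fused gradient accumulation and the CE term that the actual autograd Function in this commit implements.

    import torch
    import torch.nn.functional as F

    def chunked_topk_kd_loss(hidden, lm_head_weight, topk_ids, topk_logprobs, topk_mask, chunk_size=1024):
        """Illustrative sketch: per-chunk LM-head projection + top-k forward-KL accumulation."""
        hidden_flat = hidden.reshape(-1, hidden.shape[-1])                # (B*N, D)
        ids_flat = topk_ids.reshape(-1, topk_ids.shape[-1])               # (B*N, K)
        logp_t_flat = topk_logprobs.reshape(-1, topk_logprobs.shape[-1])  # teacher top-k log-probs
        mask_flat = topk_mask.reshape(-1, topk_mask.shape[-1])

        loss = hidden_flat.new_zeros(())
        for h, ids, logp_t, m in zip(
            hidden_flat.split(chunk_size),
            ids_flat.split(chunk_size),
            logp_t_flat.split(chunk_size),
            mask_flat.split(chunk_size),
        ):
            logits = h @ lm_head_weight.t()           # only a (chunk, vocab) block is ever materialized
            logp_s = F.log_softmax(logits.float(), dim=-1)
            logp_s_topk = logp_s.gather(-1, ids)      # student log-probs at the teacher's top-k token ids
            loss = loss + ((logp_t.exp() * (logp_t - logp_s_topk)) * m).sum()
        return loss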
@@ -36,6 +36,6 @@ class KDPlugin(BasePlugin):
        return None

    def pre_model_load(self, cfg):
-       from .kernels.liger import apply_kernel_to_qwen2
+       from .kernels.models import apply_kernel

-       apply_kernel_to_qwen2()
+       apply_kernel(cfg.model_config_type)
@@ -2,22 +2,17 @@
Liger Kernels for Chunked Top-K Log-Prob Distillation
"""

from typing import Optional, Union, Unpack

import torch
import torch.nn.functional as F

# Assuming LigerFusedLinearDistillationBase is in this path and can be imported
# If not, its structure would need to be replicated or specific utilities copied.
from liger_kernel.chunked_loss.fused_linear_distillation import (
    LigerFusedLinearDistillationBase,
)
from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen2.modeling_qwen2 import KwargsForCausalLM


class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
    """
    Chunked kl-div loss for top-k logprobs
    """

    @staticmethod
    def distillation_loss_fn(
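A hedged reading of the objective that distillation_loss_fn and the surrounding forward assemble, based on the weight_hard_loss / weight_soft_loss / ignore_index / temperature parameters visible elsewhere in this diff (the exact temperature placement is not shown here):

    loss_i ~= weight_hard_loss * CE(z_i, y_i)
            + weight_soft_loss * sum_k mask[i, k] * p_T[i, k] * (log p_T[i, k] - log p_S[i, k])

where z_i are the student logits for position i (hidden state times the LM-head weight), y_i the true label (positions equal to ignore_index dropped), the sum runs over the teacher's stored top-k token ids, p_T comes from the stored teacher log-probs, and p_S is the student softmax evaluated at those same ids.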
@@ -161,7 +156,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
        chunk_size: int = 1024,
        compute_ce_loss: bool = True,
    ):
-       CHUNK_SIZE = chunk_size
+       CHUNK_SIZE = chunk_size  # pylint: disable=invalid-name
        grad_weight_acc = torch.zeros_like(student_lm_head_weight)
        grad_inputs_list = []
        grad_bias_acc = (
@@ -255,24 +250,36 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
        accumulate_chunk_grads_compiled = accumulate_chunk_grads

        # Use the same chunking logic as LigerFusedLinearDistillationBase.forward
        B, N, D = student_input.shape
        K = target_token_ids.shape[-1]
        B, N, D = student_input.shape  # pylint: disable=invalid-name
        K = target_token_ids.shape[-1]  # pylint: disable=invalid-name

        # print("student_input shape: " + str(student_input.shape))
        # print("target_token_ids shape: " + str(target_token_ids.shape))
        # print("target_logprobs shape: " + str(target_logprobs.shape))
        # print("target_mask shape: " + str(target_mask.shape))
        # print("true_labels shape: " + str(true_labels.shape))

        print("student_input shape: " + str(student_input.shape))
        print("target_token_ids shape: " + str(target_token_ids.shape))
        print("target_logprobs shape: " + str(target_logprobs.shape))
        print("target_mask shape: " + str(target_mask.shape))
        print("true_labels shape: " + str(true_labels.shape))
        student_input_flat = student_input.reshape(-1, student_input.shape[-1])
        target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
        target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
        target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
        true_labels_flat = true_labels.reshape(-1)
        print("student_input_flat shape: " + str(student_input_flat.shape))
        print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape))
        print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape))
        print("target_mask_flat shape: " + str(target_mask_flat.shape))
        print("true_labels_flat shape: " + str(true_labels_flat.shape))
        # pad and shift for cross entropy loss
        true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
        true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
        # true_labels_flat = true_labels.reshape(-1)

        # student_input_flat = student_input[:, :-1, :].contiguous().view(-1, student_input.shape[-1])
        # target_token_ids_flat = target_token_ids[:, 1:, :].contiguous().view(-1, target_token_ids.shape[-1])
        # target_logprobs_flat = target_logprobs[:, 1:, :].contiguous().view(-1, target_logprobs.shape[-1])
        # target_mask_flat = target_mask[:, 1:, :].contiguous().view(-1, target_mask.shape[-1])
        # true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
        # N = N - 1

        # print("student_input_flat shape: " + str(student_input_flat.shape))
        # print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape))
        # print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape))
        # print("target_mask_flat shape: " + str(target_mask_flat.shape))
        # print("true_labels_flat shape: " + str(true_labels_flat.shape))
        num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)

        _student_input_chunks = torch.chunk(
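The pad-and-shift on true_labels keeps the flattened length at B*N while aligning each hidden state with the next token, e.g. with illustrative values:

    labels       = [[l0, l1, l2]]              # (B=1, N=3)
    padded       = [[l0, l1, l2, ignore]]      # F.pad(..., (0, 1), value=ignore_index)
    shifted_flat = [l1, l2, ignore]            # padded[:, 1:].reshape(-1), still length B*N

so position i is scored against token i+1 and the final position contributes nothing to the CE term.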
@@ -298,6 +305,8 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
            grad_inputs_list.append(grad_input_chunk)

        grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
+       print("grad_inputs_combined")
+       print(grad_inputs_combined.shape)
        ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)

        # For matching None returns in backward for non-tensor/non-grad_requiring inputs
@@ -305,13 +314,27 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
        ctx.bias_was_none = student_lm_head_bias is None
        ctx.orig_dims = (B, N, D, K)

        return loss_acc / (true_labels_flat != ignore_index).sum()
        num_valid_tokens_scalar: float = (true_labels_flat != ignore_index).sum().item()
        ctx.num_valid_tokens_scalar = num_valid_tokens_scalar
        final_loss = loss_acc  # / ctx.num_valid_tokens_scalar

        return final_loss

    @staticmethod
    def backward(ctx, grad_output):
        grad_input_flat, grad_weight, grad_bias_maybe = (
            ctx.saved_tensors
        )  # grad_input_flat is (B*N, D)
        print("grad_input_flat")
        print(grad_input_flat.shape)

        # num_valid_tokens_scalar = ctx.num_valid_tokens_scalar
        # normalizer = float(num_valid_tokens_scalar) if num_valid_tokens_scalar > 0 else 1.0

        # grad_input_flat = grad_input_flat / normalizer
        # grad_weight = grad_weight / normalizer
        # if grad_bias_maybe is not None:
        #     grad_bias_maybe = grad_bias_maybe / normalizer

        # Scale gradients by grad_output if it's not 1.0
        if not torch.equal(
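Note the normalization change in this hunk: the old return divided the accumulated loss by the number of valid (non-ignore_index) tokens, while the new code stashes that count on ctx and returns the summed loss_acc unscaled. The averaging is instead expected to happen in the patched model forward added in models.py below, roughly:

    loss = self.loss_function(...)  # summed over tokens
    num_items_in_batch = kwargs.pop("num_items_in_batch", None)
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch

presumably so a trainer-supplied num_items_in_batch controls the scaling rather than a per-rank valid-token count.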
@@ -372,6 +395,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):


class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
    """
    wrapper for chunked top-k logprob kl-d
    """

    def __init__(
        self,
        weight_hard_loss: float = 0.5,
@@ -416,46 +443,6 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
        true_labels: torch.Tensor,
        student_bias: torch.Tensor = None,
    ) -> torch.Tensor:
        print(student_hidden_states.shape)
        # Input validation
        # if student_hidden_states.ndim != 2 or lm_head_weight.ndim != 2:
        #     raise ValueError("student_input and student_weight must be 2D tensors.")
        # if student_hidden_states.shape[1] != lm_head_weight.shape[1]:
        #     raise ValueError("Hidden dimension mismatch between student_input and student_weight.")
        # if student_bias is not None and (student_bias.ndim != 1 or student_bias.shape[0] != lm_head_weight.shape[0]):
        #     raise ValueError("student_bias shape mismatch.")

        if self.weight_soft_loss > 0.0:
            expected_len = student_hidden_states.shape[0]
            for name, tensor in [
                ("target_token_ids", target_token_ids),
                ("target_logprobs", target_logprobs),
                ("target_mask", target_mask),
            ]:
                # if tensor.ndim != 2:
                #     raise ValueError(f"{name} must be a 2D tensor.")
                if tensor.shape[0] != expected_len:
                    raise ValueError(
                        f"Length mismatch: student_input ({expected_len}) vs {name} ({tensor.shape[0]})."
                    )
            if not (
                target_token_ids.shape[1]
                == target_logprobs.shape[1]
                == target_mask.shape[1]
            ):
                raise ValueError(
                    "Top-k dimension mismatch among target_token_ids, target_logprobs, target_mask."
                )
            if target_token_ids.max() >= lm_head_weight.shape[0]:
                raise ValueError(
                    f"target_token_ids contains indices ({target_token_ids.max().item()}) "
                    f"out of bounds for student vocabulary size ({lm_head_weight.shape[0]})."
                )

        # if self.compute_ce_loss and self.weight_hard_loss > 0.0:
        #     if true_labels.ndim != 1 or true_labels.shape[0] != student_hidden_states.shape[0]:
        #         raise ValueError("true_labels shape mismatch or incorrect dimensions.")

        return LigerFusedLinearKLTopKLogprobFunction.apply(
            student_hidden_states,
            lm_head_weight,
@@ -472,105 +459,3 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
            self.chunk_size,
            self.compute_ce_loss,
        )


# class LigerFusedTopKKDLossFunction(torch.autograd.Function):
#     def forward(
#         self,
#         student_input: torch.Tensor,
#         student_weight: torch.Tensor,
#         teacher_input: torch.Tensor,  # teacher logprobs
#         teacher_token_ids: torch.Tensor,
#         teacher_mask: torch.Tensor,
#         hard_labels: torch.LongTensor,
#     ):
#         return LigerFusedLinearKLTopKLogprobFunction.apply(
#             student_input,
#             student_weight,
#             teacher_token_ids,
#             teacher_input,  # teacher logprobs
#             teacher_mask,
#             hard_labels,
#             None,
#             self.weight_hard_loss,
#             self.weight_soft_loss,
#             self.ignore_index,
#             self.temperature,
#             self.compiled,
#             self.chunk_size,
#             self.compute_ce_loss,
#         )


def kldiv_forward_qwen2(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    target_logprobs: Optional[torch.Tensor] = None,
    target_token_ids: Optional[torch.LongTensor] = None,
    target_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
    # pylint: disable=duplicate-code
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
    # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss

    loss = self.loss_function(
        self.lm_head.weight,
        hidden_states,
        target_token_ids,
        target_logprobs,
        target_mask,
        true_labels=labels,
    )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=None,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def apply_kernel_to_qwen2():
    from transformers.models.qwen2 import modeling_qwen2

    modeling_qwen2.Qwen2ForCausalLM.forward = kldiv_forward_qwen2
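For reference, a hedged usage sketch of the wrapper module with dummy shapes. The positional order (LM-head weight first, then hidden states) mirrors the kldiv_forward_* call sites in this diff, and the constructor arguments are inferred from attributes referenced elsewhere in the change, so the exact signatures may differ:

    import torch

    from axolotl.integrations.kd.kernels.liger import LigerFusedLinearKLTopKLogprobLoss

    B, N, D, V, K = 2, 16, 64, 128, 8  # illustrative sizes

    loss_fn = LigerFusedLinearKLTopKLogprobLoss(weight_hard_loss=0.5)

    hidden = torch.randn(B, N, D, requires_grad=True)
    lm_head_weight = torch.randn(V, D, requires_grad=True)
    target_token_ids = torch.randint(0, V, (B, N, K))
    target_logprobs = torch.log_softmax(torch.randn(B, N, K), dim=-1)
    target_mask = torch.ones(B, N, K)
    labels = torch.randint(0, V, (B, N))

    loss = loss_fn(
        lm_head_weight,
        hidden,
        target_token_ids,
        target_logprobs,
        target_mask,
        true_labels=labels,
    )
    loss.backward()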
src/axolotl/integrations/kd/kernels/models.py (new file, 109 lines)
@@ -0,0 +1,109 @@
"""
model patcher for chunked top-k kl-div
"""

from typing import Optional, Union, Unpack

import torch
from transformers import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import LossKwargs


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
    """
    placeholder kwargs for hf model classes
    """


def kldiv_forward_llama_like(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    target_logprobs: Optional[torch.Tensor] = None,
    target_token_ids: Optional[torch.LongTensor] = None,
    target_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,  # pylint: disable=unused-argument
    **kwargs: Unpack[KwargsForCausalLM],  # type: ignore[misc]
) -> CausalLMOutputWithPast:
    # pylint: disable=duplicate-code
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
    # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss

    loss = self.loss_function(
        self.lm_head.weight,
        hidden_states,
        target_token_ids,
        target_logprobs,
        target_mask,
        true_labels=labels,
    )
    num_items_in_batch = kwargs.pop("num_items_in_batch", None)
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch

    return CausalLMOutputWithPast(
        loss=loss,
        logits=None,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def apply_kernel_to_qwen2():
    from transformers.models.qwen2 import modeling_qwen2

    modeling_qwen2.Qwen2ForCausalLM.forward = kldiv_forward_llama_like


def apply_kernel_to_llama():
    from transformers.models.llama import modeling_llama

    modeling_llama.LlamaForCausalLM.forward = kldiv_forward_llama_like


def apply_kernel(model_type):
    # Dynamically import the module and causal LM class for the given model type
    module_path = f"transformers.models.{model_type}.modeling_{model_type}"
    model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
    module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
    model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
    model_cls.forward = kldiv_forward_llama_like
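A short usage sketch of the dynamic patch; the model_type values are examples, and the mapping only works for architectures whose ForCausalLM class follows the plain capitalize-each-underscore-part naming convention:

    from axolotl.integrations.kd.kernels.models import apply_kernel

    apply_kernel("llama")  # patches transformers.models.llama.modeling_llama.LlamaForCausalLM.forward
    apply_kernel("qwen2")  # patches transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM.forward

Model types whose class names deviate from that convention would fail the getattr lookup and would need a dedicated helper such as apply_kernel_to_qwen2 / apply_kernel_to_llama above.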