fix collator setup

This commit is contained in:
Wing Lian
2025-05-20 07:33:20 -07:00
parent 0e46367e01
commit 6fafe46562
3 changed files with 95 additions and 15 deletions

View File

@@ -439,6 +439,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator_args = [self.tokenizer]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import (
DataCollatorForKD,
KDBatchSamplerDataCollatorForSeq2Seq,
)
if self.cfg.sample_packing and use_batch_sampler_collator:
collator = KDBatchSamplerDataCollatorForSeq2Seq
else:
collator = DataCollatorForKD
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama
@@ -468,16 +478,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import (
DataCollatorForKD,
KDBatchSamplerDataCollatorForSeq2Seq,
)
if self.cfg.sample_packing:
collator = KDBatchSamplerDataCollatorForSeq2Seq
else:
collator = DataCollatorForKD
else:
collator = DataCollatorForSeq2Seq

View File

@@ -16,6 +16,7 @@
loss for top_k KL divergence
"""
import torch
from torch import nn
def zscore_standardize(
@@ -235,3 +236,76 @@ def topk_kd_loss_with_zscore(
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss
class ChunkedTopKKDLoss(nn.Module):
"""
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to your top-K teacher logprobs.
"""
def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
super().__init__()
self.num_output_chunks = num_output_chunks
self.kd_temperature = kd_temperature
def forward(
self,
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
target_token_ids: torch.Tensor, # [B, seq_len, K]
target_logprobs: torch.Tensor, # [B, seq_len, K]
target_mask: torch.Tensor, # [B, seq_len, K]
num_items_in_batch: int = -1, # optional batch size for normalization
) -> torch.Tensor:
# 1. Split along the "token" dimension (dim=1).
student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1)
token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1)
logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1)
mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1)
# We'll accumulate a global "sum of losses" and "sum of valid tokens"
# so that our final average is consistent with the entire sequence/batch.
total_loss = 0.0
total_valid_tokens = 0
# 2. Loop over each chunk and compute a chunk-specific loss.
for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip(
student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks
):
# We pass num_items_in_batch=-1 so that the kd_loss
# will average over *this chunk's* valid tokens only.
chunk_loss = loss(
student_logits=st_chunk,
target_token_ids=tid_chunk,
target_logprobs=lp_chunk,
target_mask=msk_chunk,
num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens
kd_temperature=self.kd_temperature,
)
# kd_loss returns an average over the chunk's valid tokens.
# We want a global average in the end, so we need to reweight
# by the number of valid tokens in this chunk and keep track of the total.
chunk_valid_mask = msk_chunk.to(torch.bool)
chunk_valid_count = chunk_valid_mask.sum() # scalar tensor
# Re-scale "chunk average" back to "chunk sum"
chunk_loss_sum = chunk_loss * chunk_valid_count
total_loss += chunk_loss_sum
total_valid_tokens += chunk_valid_count
# 3. Normalize *once* at the end.
if num_items_in_batch > 0:
# If the user gave us a manual denominator (e.g. total items in batch),
# we divide by it. Typically used if each item is of different length.
final_loss = total_loss / float(num_items_in_batch)
else:
# Otherwise, divide by total valid tokens across all chunks.
# to get the same result as a non-chunked approach.
final_loss = total_loss / float(total_valid_tokens)
return final_loss

View File

@@ -18,8 +18,7 @@ KD trainer
from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
from .topk_logprob.forward_kl import ChunkedTopKKDLoss, topk_kd_loss_with_zscore
class AxolotlKDTrainer(AxolotlTrainer):
@@ -27,6 +26,13 @@ class AxolotlKDTrainer(AxolotlTrainer):
Custom trainer subclass for Knowledge Distillation (KD)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_accepts_loss_kwargs = True
self.kd_loss_fn = ChunkedTopKKDLoss(
num_output_chunks=8, kd_temperature=self.args.kd_temperature
)
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
columns_to_add = []
@@ -85,14 +91,14 @@ class AxolotlKDTrainer(AxolotlTrainer):
num_items_in_batch=num_items_in_batch,
)
else:
loss_kd = topk_kd_loss(
shift_logits,
loss_kd = self.kd_loss_fn(
# shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
# top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
)
if self.args.kd_ce_alpha > 0: