diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 7a81616ba..d32c87bb4 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -439,6 +439,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator_args = [self.tokenizer] if self.cfg.reward_model: collator = RewardDataCollatorWithPadding + elif self.cfg.kd_trainer: + from axolotl.integrations.kd.collator import ( + DataCollatorForKD, + KDBatchSamplerDataCollatorForSeq2Seq, + ) + + if self.cfg.sample_packing and use_batch_sampler_collator: + collator = KDBatchSamplerDataCollatorForSeq2Seq + else: + collator = DataCollatorForKD elif use_batch_sampler_collator: # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, # supported multipack models, or non-flash-attention llama @@ -468,16 +478,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator_args.pop(0) kwargs.pop("pad_to_multiple_of", None) kwargs.pop("padding", None) - elif self.cfg.kd_trainer: - from axolotl.integrations.kd.collator import ( - DataCollatorForKD, - KDBatchSamplerDataCollatorForSeq2Seq, - ) - - if self.cfg.sample_packing: - collator = KDBatchSamplerDataCollatorForSeq2Seq - else: - collator = DataCollatorForKD else: collator = DataCollatorForSeq2Seq diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py index 3c9515091..06bce6971 100644 --- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py +++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py @@ -16,6 +16,7 @@ loss for top_k KL divergence """ import torch +from torch import nn def zscore_standardize( @@ -235,3 +236,76 @@ def topk_kd_loss_with_zscore( kd_loss = kd_loss / float(kd_loss_per_token.size(0)) return kd_loss + + +class ChunkedTopKKDLoss(nn.Module): + """ + A wrapper that chunks (splits) the student and teacher outputs along the time dimension + to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies. + + Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to your top-K teacher logprobs. + """ + + def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0): + super().__init__() + self.num_output_chunks = num_output_chunks + self.kd_temperature = kd_temperature + + def forward( + self, + student_logits: torch.Tensor, # [B, seq_len, vocab_size] + target_token_ids: torch.Tensor, # [B, seq_len, K] + target_logprobs: torch.Tensor, # [B, seq_len, K] + target_mask: torch.Tensor, # [B, seq_len, K] + num_items_in_batch: int = -1, # optional batch size for normalization + ) -> torch.Tensor: + + # 1. Split along the "token" dimension (dim=1). + student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1) + token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1) + logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1) + mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1) + + # We'll accumulate a global "sum of losses" and "sum of valid tokens" + # so that our final average is consistent with the entire sequence/batch. + total_loss = 0.0 + total_valid_tokens = 0 + + # 2. Loop over each chunk and compute a chunk-specific loss. + for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip( + student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks + ): + # We pass num_items_in_batch=-1 so that the kd_loss + # will average over *this chunk's* valid tokens only. + chunk_loss = loss( + student_logits=st_chunk, + target_token_ids=tid_chunk, + target_logprobs=lp_chunk, + target_mask=msk_chunk, + num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens + kd_temperature=self.kd_temperature, + ) + + # kd_loss returns an average over the chunk's valid tokens. + # We want a global average in the end, so we need to re‐weight + # by the number of valid tokens in this chunk and keep track of the total. + chunk_valid_mask = msk_chunk.to(torch.bool) + chunk_valid_count = chunk_valid_mask.sum() # scalar tensor + + # Re-scale "chunk average" back to "chunk sum" + chunk_loss_sum = chunk_loss * chunk_valid_count + + total_loss += chunk_loss_sum + total_valid_tokens += chunk_valid_count + + # 3. Normalize *once* at the end. + if num_items_in_batch > 0: + # If the user gave us a manual denominator (e.g. total items in batch), + # we divide by it. Typically used if each item is of different length. + final_loss = total_loss / float(num_items_in_batch) + else: + # Otherwise, divide by total valid tokens across all chunks. + # to get the same result as a non-chunked approach. + final_loss = total_loss / float(total_valid_tokens) + + return final_loss diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index f99f2ca28..0e2996504 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -18,8 +18,7 @@ KD trainer from axolotl.core.trainers.base import AxolotlTrainer -from .topk_logprob.forward_kl import loss as topk_kd_loss -from .topk_logprob.forward_kl import topk_kd_loss_with_zscore +from .topk_logprob.forward_kl import ChunkedTopKKDLoss, topk_kd_loss_with_zscore class AxolotlKDTrainer(AxolotlTrainer): @@ -27,6 +26,13 @@ class AxolotlKDTrainer(AxolotlTrainer): Custom trainer subclass for Knowledge Distillation (KD) """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_accepts_loss_kwargs = True + self.kd_loss_fn = ChunkedTopKKDLoss( + num_output_chunks=8, kd_temperature=self.args.kd_temperature + ) + def _set_signature_columns_if_needed(self): super()._set_signature_columns_if_needed() columns_to_add = [] @@ -85,14 +91,14 @@ class AxolotlKDTrainer(AxolotlTrainer): num_items_in_batch=num_items_in_batch, ) else: - loss_kd = topk_kd_loss( - shift_logits, + loss_kd = self.kd_loss_fn( + # shift_logits, target_token_ids_for_loss, target_logprobs_for_loss, target_mask_for_loss, num_items_in_batch=num_items_in_batch, kd_temperature=self.args.kd_temperature, - top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0, + # top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0, ) if self.args.kd_ce_alpha > 0: