diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index d32c87bb4..c1aeb7803 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -324,10 +324,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["kd_zscore_base_temp"] = (
                 self.cfg.kd_zscore_base_temp
             )
-        if self.cfg.kd_top_k_before_softmax is not None:
-            training_arguments_kwargs["kd_top_k_before_softmax"] = (
-                self.cfg.kd_top_k_before_softmax
-            )
 
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
index 0cba7cf94..8d841c9bb 100644
--- a/src/axolotl/core/training_args.py
+++ b/src/axolotl/core/training_args.py
@@ -201,13 +201,6 @@ class AxolotlTrainingMixins:
         },
     )
 
-    kd_top_k_before_softmax: Optional[bool] = field(
-        default=False,
-        metadata={
-            "help": "Whether to apply top_k_before_softmax to the logits when using KD"
-        },
-    )
-
     adam_beta3: Optional[float] = field(
         default=None,
         metadata={
diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
index 2fbba2c6a..43fe0e6db 100644
--- a/src/axolotl/integrations/kd/args.py
+++ b/src/axolotl/integrations/kd/args.py
@@ -32,6 +32,3 @@ class KDArgs(BaseModel):
     kd_alpha: Optional[float] = None  # loss coefficient for KD loss
     kd_temperature: Optional[float] = None  # temperature for sampling during KD
     kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
-    kd_top_k_before_softmax: Optional[bool] = (
-        None  # whether to sample top k before softmax during KD
-    )
diff --git a/src/axolotl/integrations/kd/collator.py b/src/axolotl/integrations/kd/collator.py
index 167b8cdbd..8089b5b25 100644
--- a/src/axolotl/integrations/kd/collator.py
+++ b/src/axolotl/integrations/kd/collator.py
@@ -47,6 +47,10 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
     position_pad_token_id: int = 0
     return_tensors: str = "pt"
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
     def __call__(self, features, return_tensors=None):
         if return_tensors is None:
             return_tensors = self.return_tensors
diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
index 823d28ac2..1ef5ba3df 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
@@ -61,7 +61,6 @@ def loss(
     target_mask: torch.Tensor,
     num_items_in_batch: int = -1,  # Use -1 to indicate "None"
     kd_temperature: float = 1.0,
-    top_k_before_softmax: int = 0,
 ) -> torch.Tensor:
     """
     A KD loss function that is TorchScript-friendly.
@@ -78,8 +77,6 @@ def loss(
         num_items_in_batch (int, optional): The number of items in the batch.
         kd_temperature (float, optional): The temperature for KD.
            Default: 1.0
-        top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits
-            Default: 0
     """
 
     target_logprobs = target_logprobs.float()
@@ -89,46 +86,24 @@ def loss(
     # student_logits shape: [B, student_seq_len, vocab_size]
     teacher_seq_len = target_token_ids.shape[1]
 
-    if top_k_before_softmax:
-        # Slice student logits to match teacher-provided sequence length
-        student_logits_for_kd = student_logits[
-            :, :teacher_seq_len, :
-        ]  # [B, teacher_seq_len, vocab_size]
+    # Slice student logits to match teacher-provided sequence length
+    student_logits_for_kd = (
+        student_logits[:, :teacher_seq_len, :] / kd_temperature
+    )  # [B, teacher_seq_len, vocab_size]
 
-        # Gather student logits for teacher's top-K tokens
-        student_logits_topk = torch.gather(
-            student_logits_for_kd, dim=-1, index=target_token_ids
-        )  # [B, teacher_seq_len, K]
+    # keep in full precision for numerical stability of loss
+    student_logits_for_kd = student_logits_for_kd.float()
 
-        student_logits_topk = student_logits_topk.float()
+    # Gather student logits for teacher's top-K tokens
+    student_logits_topk = torch.gather(
+        student_logits_for_kd, dim=-1, index=target_token_ids
+    )  # [B, teacher_seq_len, K]
 
-        # Apply KD temperature to student’s logits
-        if kd_temperature != 1.0:
-            student_logits_topk = student_logits_topk / kd_temperature
+    # Compute logsumexp across full vocabulary
+    student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
 
-        # Convert student top-k logits to logprobs
-        student_logprobs_topk = student_logits_topk - torch.logsumexp(
-            student_logits_topk, dim=-1, keepdim=True
-        )  # [B, teacher_seq_len, K]
-    else:
-        # Slice student logits to match teacher-provided sequence length
-        student_logits_for_kd = (
-            student_logits[:, :teacher_seq_len, :] / kd_temperature
-        )  # [B, teacher_seq_len, vocab_size]
-
-        # keep in full precision for numerical stability of loss
-        student_logits_for_kd = student_logits_for_kd.float()
-
-        # Gather student logits for teacher's top-K tokens
-        student_logits_topk = torch.gather(
-            student_logits_for_kd, dim=-1, index=target_token_ids
-        )  # [B, teacher_seq_len, K]
-
-        # Compute logsumexp across full vocabulary
-        student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
-
-        # Convert just the top-k logits to logprobs
-        student_logprobs_topk = student_logits_topk - student_lse
+    # Convert just the top-k logits to logprobs
+    student_logprobs_topk = student_logits_topk - student_lse
 
     # Convert teacher_mask to boolean for indexing
     # In TorchScript, .bool() is sometimes unsupported, so we do:
@@ -243,7 +218,7 @@ class ChunkedTopKKDLoss(nn.Module):
     A wrapper that chunks (splits) the student and teacher outputs
    along the time dimension to reduce peak memory usage when upcasting
    from bf16 to fp32, especially for large vocabularies.
-    Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to your top-K teacher logprobs.
+    Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.
     """
 
     def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
@@ -258,7 +233,6 @@ class ChunkedTopKKDLoss(nn.Module):
         target_logprobs: torch.Tensor,  # [B, seq_len, K]
         target_mask: torch.Tensor,  # [B, seq_len, K]
         num_items_in_batch: int = -1,  # optional batch size for normalization
-        top_k_before_softmax: int = 0,  # optional top-k before softmax for teacher logits
     ) -> torch.Tensor:
         # 1. Split along the "token" dimension (dim=1).
@@ -285,7 +259,6 @@ class ChunkedTopKKDLoss(nn.Module):
                 target_mask=msk_chunk,
                 num_items_in_batch=-1,  # ensure per-chunk averaging by valid tokens
                 kd_temperature=self.kd_temperature,
-                top_k_before_softmax=top_k_before_softmax,
             )
 
         # kd_loss returns an average over the chunk's valid tokens.
diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py
index da02eac27..ea15b9c1d 100644
--- a/src/axolotl/integrations/kd/trainer.py
+++ b/src/axolotl/integrations/kd/trainer.py
@@ -97,7 +97,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
             target_logprobs_for_loss,
             target_mask_for_loss,
             num_items_in_batch=num_items_in_batch,
-            # top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
        )
 
         if self.args.kd_ce_alpha > 0:
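
Note: for review, here is a minimal standalone sketch of the single code path this diff keeps in forward_kl.py: temperature-scale the full student logits, upcast to fp32, gather the teacher's top-K token ids, and normalize against the full-vocabulary logsumexp. The function name and the final per-valid-entry normalization are illustrative assumptions; the library normalizes via num_items_in_batch, which is not reproduced here.

    import torch

    def topk_forward_kl_sketch(
        student_logits: torch.Tensor,    # [B, student_seq_len, vocab_size]
        target_token_ids: torch.Tensor,  # [B, teacher_seq_len, K] teacher top-K ids
        target_logprobs: torch.Tensor,   # [B, teacher_seq_len, K] teacher logprobs
        target_mask: torch.Tensor,       # [B, teacher_seq_len, K], 1 = valid entry
        kd_temperature: float = 1.0,
    ) -> torch.Tensor:
        teacher_seq_len = target_token_ids.shape[1]
        # Temperature-scale the full vocabulary *before* normalizing, and upcast
        # to fp32 for a numerically stable logsumexp (the behavior the diff keeps).
        logits = student_logits[:, :teacher_seq_len, :].float() / kd_temperature
        # Gather student logits at the teacher's top-K token ids.
        logits_topk = torch.gather(logits, dim=-1, index=target_token_ids)
        # Normalize against the full vocabulary, not just the K gathered logits,
        # so the student log-probabilities are properly normalized; the removed
        # top_k_before_softmax branch instead took logsumexp over only K logits.
        logprobs_topk = logits_topk - torch.logsumexp(logits, dim=-1, keepdim=True)
        valid = target_mask.to(torch.bool)
        # Forward KL against a fixed teacher reduces to teacher-probability-
        # weighted cross-entropy over the valid top-K entries.
        kd_loss = -(target_logprobs[valid].exp() * logprobs_topk[valid]).sum()
        return kd_loss / valid.sum().clamp_min(1)

With only one normalization strategy left, the temperature divide can happen once on the full logits before the logsumexp, which is why the kd_temperature != 1.0 special case disappears along with the branch.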