diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index c1aeb7803..47170cc6f 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -320,10 +320,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
         if self.cfg.kd_temperature is not None:
             training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
-        if self.cfg.kd_zscore_base_temp is not None:
-            training_arguments_kwargs["kd_zscore_base_temp"] = (
-                self.cfg.kd_zscore_base_temp
-            )
 
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
index 8d841c9bb..03cad93e7 100644
--- a/src/axolotl/core/training_args.py
+++ b/src/axolotl/core/training_args.py
@@ -194,13 +194,6 @@ class AxolotlTrainingMixins:
         },
     )
 
-    kd_zscore_base_temp: Optional[float] = field(
-        default=None,
-        metadata={
-            "help": "the base temperature parameter for KL divergence with z-score when using KD"
-        },
-    )
-
     adam_beta3: Optional[float] = field(
         default=None,
         metadata={
diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
index 43fe0e6db..0eede1ada 100644
--- a/src/axolotl/integrations/kd/args.py
+++ b/src/axolotl/integrations/kd/args.py
@@ -31,4 +31,3 @@ class KDArgs(BaseModel):
     )
     kd_alpha: Optional[float] = None  # loss coefficient for KD loss
    kd_temperature: Optional[float] = None  # temperature for sampling during KD
-    kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
index 1ef5ba3df..4b7251295 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
@@ -19,40 +19,6 @@ import torch
 from torch import nn
 
 
-def zscore_standardize(
-    logits: torch.Tensor,
-    mask: torch.Tensor = None,
-    base_temperature: float = 1.0,
-    eps: float = 1e-9,
-):
-    """
-    Z-score standardize along the last dimension of `logits`.
-    i.e., for each [B, seq_len] row, across K entries:
-        z = (logits - mean) / std,
-    then scale by 1 / base_temperature if desired.
-
-    mask can be broadcastable or None. If None, we standardize all elements.
-    """
-    if mask is None:
-        # shape: [B, seq_len, K]
-        # Mean and std over dim=-1
-        mean = logits.mean(dim=-1, keepdim=True)
-        var = logits.var(dim=-1, unbiased=False, keepdim=True)
-    else:
-        # If you have to exclude some tokens, multiply by mask, etc.
-        float_mask = mask.to(logits.dtype)
-        count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
-        mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
-        var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
-
-    std = torch.sqrt(var.clamp_min(eps))
-    z = (logits - mean) / std
-
-    # Scale by 1 / base_temperature
-    z = z / base_temperature
-    return z
-
-
 @torch.jit.script
 def loss(
     student_logits: torch.Tensor,
@@ -134,85 +100,6 @@ def loss(
     return kd_loss
 
 
-def topk_kd_loss_with_zscore(
-    student_logits: torch.Tensor,  # [B, seq_len, vocab_size]
-    target_token_ids: torch.Tensor,  # [B, seq_len, K]
-    target_logprobs: torch.Tensor,  # [B, seq_len, K], sums to 1.0 in prob space
-    target_mask: torch.Tensor,  # [B, seq_len, K] or [B, seq_len]
-    kd_temperature: float = 1.0,  # classic KD temperature
-    zscore_base_temp: float = 1.0,  # from the paper
-    num_items_in_batch: int = -1,
-):
-    """
-    A variant of top_k KL divergence with Z-score scaling
-    from "Logit Standardization in Knowledge Distillation".
-    """
-
-    target_logprobs = target_logprobs.float()
-
-    B, teacher_seq_len, K = target_logprobs.shape  # pylint: disable=invalid-name
-    # 1) Gather the student's top-k logits to match teacher
-    student_logits_for_kd = student_logits[
-        :, :teacher_seq_len, :
-    ]  # [B, seq_len, vocab]
-    student_topk_logits = torch.gather(
-        student_logits_for_kd, dim=-1, index=target_token_ids
-    )  # [B, seq_len, K]
-
-    student_topk_logits = student_topk_logits.float()
-
-    # 2) If you want to keep the "classical" T scaling, apply it first
-    if kd_temperature != 1.0:
-        student_topk_logits = student_topk_logits / kd_temperature
-
-    # 3) Convert teacher logprobs -> treat them as "logits" for z-score
-    #    (They differ by +some_constant from real logits, but in z-score
-    #    that constant is subtracted out anyway.)
-    teacher_logits_for_zscore = target_logprobs  # rename variable for clarity
-
-    # 4) Z-score teacher and student
-    # If target_mask is 2D, expand to 3D for the K dimension
-    if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
-        target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
-
-    teacher_z = zscore_standardize(
-        teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
-    )
-    student_z = zscore_standardize(
-        student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
-    )
-
-    # 5) Convert to log-probs for KL
-    teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
-    student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
-
-    # 6) Restrict to valid tokens if needed
-    valid_mask = target_mask.bool()  # shape [B, seq_len, K]
-    teacher_probs_z = teacher_logprobs_z.exp()
-    teacher_probs_z = teacher_probs_z[valid_mask]
-    teacher_logprobs_z = teacher_logprobs_z[valid_mask]
-    student_logprobs_z = student_logprobs_z[valid_mask]
-
-    # 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
-    kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
-    kd_loss = kd_loss_per_token.sum()
-
-    # 8) If using classical KD scaling by T^2
-    if kd_temperature != 1.0:
-        kd_loss = kd_loss * (kd_temperature**2)
-
-    # Optionally scale by zscore_base_temp**2 if you want (paper might differ).
-    # kd_loss = kd_loss * (zscore_base_temp**2)
-
-    # 9) Normalize
-    if num_items_in_batch is not None and num_items_in_batch > 0:
-        kd_loss = kd_loss / float(num_items_in_batch)
-    else:
-        kd_loss = kd_loss / float(kd_loss_per_token.size(0))
-
-    return kd_loss
-
-
 class ChunkedTopKKDLoss(nn.Module):
     """
     A wrapper that chunks (splits) the student and teacher outputs along the time dimension
diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py
index ea15b9c1d..9b6316784 100644
--- a/src/axolotl/integrations/kd/trainer.py
+++ b/src/axolotl/integrations/kd/trainer.py
@@ -18,7 +18,7 @@ KD trainer
 
 from axolotl.core.trainers.base import AxolotlTrainer
 
-from .topk_logprob.forward_kl import ChunkedTopKKDLoss, topk_kd_loss_with_zscore
+from .topk_logprob.forward_kl import ChunkedTopKKDLoss
 
 
 class AxolotlKDTrainer(AxolotlTrainer):
@@ -80,24 +80,13 @@ class AxolotlKDTrainer(AxolotlTrainer):
             target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
             target_mask_for_loss = target_mask[..., 1:, :].contiguous()
 
-            if self.args.kd_zscore_base_temp:
-                loss_kd = topk_kd_loss_with_zscore(
-                    shift_logits,
-                    target_token_ids_for_loss,
-                    target_logprobs_for_loss,
-                    target_mask_for_loss,
-                    kd_temperature=self.args.kd_temperature,
-                    zscore_base_temp=self.args.kd_zscore_base_temp,
-                    num_items_in_batch=num_items_in_batch,
-                )
-            else:
-                loss_kd = self.kd_loss_fn(
-                    shift_logits,
-                    target_token_ids_for_loss,
-                    target_logprobs_for_loss,
-                    target_mask_for_loss,
-                    num_items_in_batch=num_items_in_batch,
-                )
+            loss_kd = self.kd_loss_fn(
+                shift_logits,
+                target_token_ids_for_loss,
+                target_logprobs_for_loss,
+                target_mask_for_loss,
+                num_items_in_batch=num_items_in_batch,
+            )
 
         if self.args.kd_ce_alpha > 0:
             kd_alpha = self.args.kd_alpha
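Note for reviewers: the deleted `topk_kd_loss_with_zscore` implemented the logit-standardization trick from "Logit Standardization in Knowledge Distillation": z-score-normalize teacher and student top-k logits per token before the softmax, so the KL compares the shape of the distributions rather than their absolute logit scale. For reference, a minimal sketch of that idea (unmasked case only; the function and argument names below are illustrative, not from the codebase):

```python
import torch
import torch.nn.functional as F


def zscore_kd_loss_sketch(
    student_topk_logits: torch.Tensor,  # [B, seq_len, K], student logits gathered at the teacher's top-k ids
    teacher_topk_logprobs: torch.Tensor,  # [B, seq_len, K], teacher log-probs over the same top-k ids
    base_temperature: float = 1.0,
    eps: float = 1e-9,
) -> torch.Tensor:
    def standardize(x: torch.Tensor) -> torch.Tensor:
        # z-score across the K candidates at each position, then divide by the base temperature
        mean = x.mean(dim=-1, keepdim=True)
        std = x.var(dim=-1, unbiased=False, keepdim=True).clamp_min(eps).sqrt()
        return (x - mean) / std / base_temperature

    # Teacher log-probs differ from teacher logits only by a per-position constant
    # (the log partition function), which the mean subtraction cancels out, so
    # they can be standardized directly as if they were logits.
    teacher_logprobs = F.log_softmax(standardize(teacher_topk_logprobs), dim=-1)
    student_logprobs = F.log_softmax(standardize(student_topk_logits), dim=-1)

    # Forward KL over the top-k support: sum p_t * (log p_t - log p_s), mean over positions
    kl = (teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)).sum(dim=-1)
    return kl.mean()
```

After this change, every KD path goes through `self.kd_loss_fn` (the `ChunkedTopKKDLoss` wrapper around the remaining top-k forward-KL `loss` kernel), and `kd_temperature` is the only temperature knob the KD integration exposes.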