From 2c9dfbed2eb3e3a36cdb78e25f71d807f3db4673 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 27 Jan 2025 14:27:35 -0500
Subject: [PATCH] apply z-score scaling to kd

---
 src/axolotl/core/trainer_builder.py    |   6 +
 src/axolotl/core/training_args.py      |   7 ++
 src/axolotl/integrations/kd/args.py    |   1 +
 .../kd/topk_logprob/forward_kl.py      | 109 ++++++++++++++++++
 src/axolotl/integrations/kd/trainer.py |  44 +++----
 5 files changed, 147 insertions(+), 20 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 8be180c95..088186f86 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -697,6 +697,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
         if self.cfg.kd_alpha is not None:
             training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
+        if self.cfg.kd_temperature is not None:
+            training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
+        if self.cfg.kd_zscore_base_temp is not None:
+            training_arguments_kwargs[
+                "kd_zscore_base_temp"
+            ] = self.cfg.kd_zscore_base_temp
 
         training_args_cls = (
             AxolotlTrainingArguments
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
index 36c43025d..2a015330e 100644
--- a/src/axolotl/core/training_args.py
+++ b/src/axolotl/core/training_args.py
@@ -188,6 +188,13 @@ class AxolotlTrainingMixins:
         },
     )
 
+    kd_zscore_base_temp: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "the base temperature parameter for KL divergence with z-score when using KD"
+        },
+    )
+
 
 @dataclass
 class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
index d5dec42ea..16a11cd70 100644
--- a/src/axolotl/integrations/kd/args.py
+++ b/src/axolotl/integrations/kd/args.py
@@ -31,3 +31,4 @@ class KDArgs(BaseModel):
     ] = None  # loss coefficient for cross-entropy loss during KD
     kd_alpha: Optional[float] = None  # loss coefficient for KD loss
     kd_temperature: Optional[float] = None  # temperature for sampling during KD
+    kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
index 545223825..6a1c80411 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
@@ -16,6 +16,40 @@ loss for top_k KL divergence
 import torch
 
 
+def zscore_standardize(
+    logits: torch.Tensor,
+    mask: torch.Tensor = None,
+    base_temperature: float = 1.0,
+    eps: float = 1e-9,
+):
+    """
+    Z-score standardize along the last dimension of `logits`.
+    i.e., for each [B, seq_len] row, across K entries:
+        z = (logits - mean) / std,
+    then scale by 1 / base_temperature if desired.
+
+    mask can be broadcastable or None. If None, we standardize all elements.
+    """
+    if mask is None:
+        # shape: [B, seq_len, K]
+        # Mean and std over dim=-1
+        mean = logits.mean(dim=-1, keepdim=True)
+        var = logits.var(dim=-1, unbiased=False, keepdim=True)
+    else:
+        # If you have to exclude some tokens, multiply by mask, etc.
+        float_mask = mask.to(logits.dtype)
+        count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
+        mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
+        var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
+
+    std = torch.sqrt(var.clamp_min(eps))
+    z = (logits - mean) / std
+
+    # Scale by 1 / base_temperature
+    z = z / base_temperature
+    return z
+
+
 @torch.jit.script
 def loss(
     student_logits: torch.Tensor,
@@ -80,3 +114,78 @@ def loss(
     kd_loss = kd_loss / float(kd_loss_per_token.size(0))
 
     return kd_loss
+
+
+def topk_kd_loss_with_zscore(
+    student_logits: torch.Tensor,  # [B, seq_len, vocab_size]
+    teacher_topk_ids: torch.Tensor,  # [B, seq_len, K]
+    teacher_topk_logprobs: torch.Tensor,  # [B, seq_len, K], sums to 1.0 in prob space
+    teacher_mask: torch.Tensor,  # [B, seq_len, K] or [B, seq_len]
+    kd_temperature: float = 1.0,  # classic KD temperature
+    zscore_base_temp: float = 1.0,  # from the paper
+    num_items_in_batch: int = -1,
+):
+    """
+    A variant of top_k KL divergence with Z-score scaling
+    from "Logit Standardization in Knowledge Distillation".
+    """
+
+    B, teacher_seq_len, K = teacher_topk_logprobs.shape  # pylint: disable=invalid-name
+    # 1) Gather the student's top-k logits to match teacher
+    student_logits_for_kd = student_logits[
+        :, :teacher_seq_len, :
+    ]  # [B, seq_len, vocab]
+    student_topk_logits = torch.gather(
+        student_logits_for_kd, dim=-1, index=teacher_topk_ids
+    )  # [B, seq_len, K]
+
+    # 2) If you want to keep the "classical" T scaling, apply it first
+    if kd_temperature != 1.0:
+        student_topk_logits = student_topk_logits / kd_temperature
+
+    # 3) Convert teacher logprobs -> treat them as “logits” for z-score
+    #    (They differ by +some_constant from real logits, but in z-score
+    #     that constant is subtracted out anyway.)
+    teacher_logits_for_zscore = teacher_topk_logprobs  # rename variable for clarity
+
+    # 4) Z-score teacher and student
+    #    If teacher_mask is 2D, expand to 3D for the K dimension
+    if teacher_mask.dim() == 2 and teacher_mask.shape[:2] == (B, teacher_seq_len):
+        teacher_mask = teacher_mask.unsqueeze(-1).expand(-1, -1, K)
+
+    teacher_z = zscore_standardize(
+        teacher_logits_for_zscore, mask=teacher_mask, base_temperature=zscore_base_temp
+    )
+    student_z = zscore_standardize(
+        student_topk_logits, mask=teacher_mask, base_temperature=zscore_base_temp
+    )
+
+    # 5) Convert to log-probs for KL
+    teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
+    student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
+
+    # 6) Restrict to valid tokens if needed
+    valid_mask = teacher_mask.bool()  # shape [B, seq_len, K]
+    teacher_probs_z = teacher_logprobs_z.exp()
+    teacher_probs_z = teacher_probs_z[valid_mask]
+    teacher_logprobs_z = teacher_logprobs_z[valid_mask]
+    student_logprobs_z = student_logprobs_z[valid_mask]
+
+    # 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
+    kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
+    kd_loss = kd_loss_per_token.sum()
+
+    # 8) If using classical KD scaling by T^2
+    if kd_temperature != 1.0:
+        kd_loss = kd_loss * (kd_temperature**2)
+
+    # Optionally scale by zscore_base_temp**2 if you want (paper might differ).
+ # kd_loss = kd_loss * (zscore_base_temp**2) + + # 9) Normalize + if num_items_in_batch is not None and num_items_in_batch > 0: + kd_loss = kd_loss / float(num_items_in_batch) + else: + kd_loss = kd_loss / float(kd_loss_per_token.size(0)) + + return kd_loss diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 9eac4cc1d..9d99a7e1d 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -19,6 +19,7 @@ KD trainer from axolotl.core.trainers.base import AxolotlTrainer from .topk_logprob.forward_kl import loss as topk_kd_loss +from .topk_logprob.forward_kl import topk_kd_loss_with_zscore class AxolotlKDTrainer(AxolotlTrainer): @@ -45,7 +46,6 @@ class AxolotlKDTrainer(AxolotlTrainer): inputs, return_outputs=False, num_items_in_batch=None, - shift_targets=True, ): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. @@ -69,26 +69,30 @@ class AxolotlKDTrainer(AxolotlTrainer): # FIXME: account for tokenizer.padding_side student_logits = outputs["logits"][:, :seq_len, :].contiguous() - if shift_targets: - # shift_logits = student_logits[..., :-1, :].contiguous() - shift_logits = student_logits.contiguous() - target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() - target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() - target_mask_for_loss = target_mask[..., 1:, :].contiguous() - else: - shift_logits = student_logits.contiguous() - target_logprobs_for_loss = target_logprobs.contiguous() - target_token_ids_for_loss = target_token_ids.contiguous() - target_mask_for_loss = target_mask.contiguous() + shift_logits = student_logits.contiguous() + target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() + target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() + target_mask_for_loss = target_mask[..., 1:, :].contiguous() - loss_kd = topk_kd_loss( - shift_logits, - target_token_ids_for_loss, - target_logprobs_for_loss, - target_mask_for_loss, - num_items_in_batch=num_items_in_batch, - kd_temperature=self.args.kd_temperature, - ) + if self.args.kd_zscore_base_temp: + loss_kd = topk_kd_loss_with_zscore( + student_logits=shift_logits, + teacher_topk_ids=target_token_ids_for_loss, + teacher_topk_logprobs=target_logprobs_for_loss, + teacher_mask=target_mask_for_loss, + kd_temperature=self.args.kd_temperature, + zscore_base_temp=self.args.kd_zscore_base_temp, + num_items_in_batch=num_items_in_batch, + ) + else: + loss_kd = topk_kd_loss( + shift_logits, + target_token_ids_for_loss, + target_logprobs_for_loss, + target_mask_for_loss, + num_items_in_batch=num_items_in_batch, + kd_temperature=self.args.kd_temperature, + ) if self.args.kd_ce_alpha > 0: kd_alpha = self.args.kd_alpha
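For reference, here is a minimal standalone sketch of the z-score standardization this patch applies, run on toy tensors. It is not part of the diff above: the shapes, random values, and the local zscore helper are illustrative assumptions that mirror zscore_standardize / topk_kd_loss_with_zscore rather than importing them from axolotl.

import torch

# Toy shapes, for illustration only: batch=2, seq_len=3, top-k=4, vocab=16.
B, S, K, V = 2, 3, 4, 16
torch.manual_seed(0)

student_logits = torch.randn(B, S, V)
teacher_topk_ids = torch.randint(0, V, (B, S, K))
# Stand-in teacher log-probs for the top-k ids. Any per-token constant offset
# between these and the "true" teacher logits is removed by the z-score anyway.
teacher_topk_logprobs = torch.log_softmax(torch.randn(B, S, K), dim=-1)


def zscore(x: torch.Tensor, base_temperature: float = 1.0, eps: float = 1e-9):
    # Standardize along the top-k dimension (the unmasked path of
    # zscore_standardize above), then divide by the base temperature.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var.clamp_min(eps)) / base_temperature


# Gather the student's logits at the teacher's top-k token ids.
student_topk = torch.gather(student_logits, dim=-1, index=teacher_topk_ids)

teacher_z = zscore(teacher_topk_logprobs, base_temperature=2.0)
student_z = zscore(student_topk, base_temperature=2.0)

# Re-normalize over the k entries and take the forward KL, averaged over tokens.
teacher_lp = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
student_lp = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
kl = (teacher_lp.exp() * (teacher_lp - student_lp)).sum(-1).mean()
print(float(kl))

# The point of the standardization: a per-token shift and rescale of the
# student's top-k logits no longer changes the loss, because z-scoring
# removes per-token shift and scale before the softmax.
shifted_z = zscore(3.0 * student_topk + 5.0, base_temperature=2.0)
assert torch.allclose(student_z, shifted_z, atol=1e-5)

The final assertion shows the motivation for the trick: only the shape of the distribution over the k candidate tokens matters, not its offset or scale. On the config side, the patch reads the new knob from cfg.kd_zscore_base_temp, so it would presumably be enabled by setting kd_zscore_base_temp (a float) in the axolotl YAML alongside the existing kd_temperature and kd_alpha options; when it is unset, the trainer falls back to the original topk_kd_loss path.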