apply z-score scaling to kd

This commit is contained in:
Wing Lian
2025-01-27 14:27:35 -05:00
parent 4e4a16cd8a
commit 2c9dfbed2e
5 changed files with 147 additions and 20 deletions

View File

@@ -697,6 +697,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None: if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs[
"kd_zscore_base_temp"
] = self.cfg.kd_zscore_base_temp
training_args_cls = ( training_args_cls = (
AxolotlTrainingArguments AxolotlTrainingArguments

View File

@@ -188,6 +188,13 @@ class AxolotlTrainingMixins:
}, },
) )
kd_zscore_base_temp: Optional[float] = field(
default=None,
metadata={
"help": "the base temperature parameter for KL divergence with z-score when using KD"
},
)
@dataclass @dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):

View File

@@ -31,3 +31,4 @@ class KDArgs(BaseModel):
] = None # loss coefficient for cross-entropy loss during KD ] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling

View File

@@ -16,6 +16,40 @@ loss for top_k KL divergence
import torch import torch
def zscore_standardize(
logits: torch.Tensor,
mask: torch.Tensor = None,
base_temperature: float = 1.0,
eps: float = 1e-9,
):
"""
Z-score standardize along the last dimension of `logits`.
i.e., for each [B, seq_len] row, across K entries:
z = (logits - mean) / std,
then scale by 1 / base_temperature if desired.
mask can be broadcastable or None. If None, we standardize all elements.
"""
if mask is None:
# shape: [B, seq_len, K]
# Mean and std over dim=-1
mean = logits.mean(dim=-1, keepdim=True)
var = logits.var(dim=-1, unbiased=False, keepdim=True)
else:
# If you have to exclude some tokens, multiply by mask, etc.
float_mask = mask.to(logits.dtype)
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
std = torch.sqrt(var.clamp_min(eps))
z = (logits - mean) / std
# Scale by 1 / base_temperature
z = z / base_temperature
return z
@torch.jit.script @torch.jit.script
def loss( def loss(
student_logits: torch.Tensor, student_logits: torch.Tensor,
@@ -80,3 +114,78 @@ def loss(
kd_loss = kd_loss / float(kd_loss_per_token.size(0)) kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss return kd_loss
def topk_kd_loss_with_zscore(
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
teacher_topk_ids: torch.Tensor, # [B, seq_len, K]
teacher_topk_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
teacher_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
kd_temperature: float = 1.0, # classic KD temperature
zscore_base_temp: float = 1.0, # from the paper
num_items_in_batch: int = -1,
):
"""
A variant of top_k KL divergence with Z-score scaling
from "Logit Standardization in Knowledge Distillation".
"""
B, teacher_seq_len, K = teacher_topk_logprobs.shape # pylint: disable=invalid-name
# 1) Gather the student's top-k logits to match teacher
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, seq_len, vocab]
student_topk_logits = torch.gather(
student_logits_for_kd, dim=-1, index=teacher_topk_ids
) # [B, seq_len, K]
# 2) If you want to keep the "classical" T scaling, apply it first
if kd_temperature != 1.0:
student_topk_logits = student_topk_logits / kd_temperature
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
# (They differ by +some_constant from real logits, but in z-score
# that constant is subtracted out anyway.)
teacher_logits_for_zscore = teacher_topk_logprobs # rename variable for clarity
# 4) Z-score teacher and student
# If teacher_mask is 2D, expand to 3D for the K dimension
if teacher_mask.dim() == 2 and teacher_mask.shape[:2] == (B, teacher_seq_len):
teacher_mask = teacher_mask.unsqueeze(-1).expand(-1, -1, K)
teacher_z = zscore_standardize(
teacher_logits_for_zscore, mask=teacher_mask, base_temperature=zscore_base_temp
)
student_z = zscore_standardize(
student_topk_logits, mask=teacher_mask, base_temperature=zscore_base_temp
)
# 5) Convert to log-probs for KL
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
# 6) Restrict to valid tokens if needed
valid_mask = teacher_mask.bool() # shape [B, seq_len, K]
teacher_probs_z = teacher_logprobs_z.exp()
teacher_probs_z = teacher_probs_z[valid_mask]
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
student_logprobs_z = student_logprobs_z[valid_mask]
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
kd_loss = kd_loss_per_token.sum()
# 8) If using classical KD scaling by T^2
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
# kd_loss = kd_loss * (zscore_base_temp**2)
# 9) Normalize
if num_items_in_batch is not None and num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss

View File

@@ -19,6 +19,7 @@ KD trainer
from axolotl.core.trainers.base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainer(AxolotlTrainer): class AxolotlKDTrainer(AxolotlTrainer):
@@ -45,7 +46,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
inputs, inputs,
return_outputs=False, return_outputs=False,
num_items_in_batch=None, num_items_in_batch=None,
shift_targets=True,
): ):
""" """
How the loss is computed by Trainer. By default, all models return the loss in the first element. How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -69,26 +69,30 @@ class AxolotlKDTrainer(AxolotlTrainer):
# FIXME: account for tokenizer.padding_side # FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, :seq_len, :].contiguous() student_logits = outputs["logits"][:, :seq_len, :].contiguous()
if shift_targets: shift_logits = student_logits.contiguous()
# shift_logits = student_logits[..., :-1, :].contiguous() target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
shift_logits = student_logits.contiguous() target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() target_mask_for_loss = target_mask[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
else:
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs.contiguous()
target_token_ids_for_loss = target_token_ids.contiguous()
target_mask_for_loss = target_mask.contiguous()
loss_kd = topk_kd_loss( if self.args.kd_zscore_base_temp:
shift_logits, loss_kd = topk_kd_loss_with_zscore(
target_token_ids_for_loss, student_logits=shift_logits,
target_logprobs_for_loss, teacher_topk_ids=target_token_ids_for_loss,
target_mask_for_loss, teacher_topk_logprobs=target_logprobs_for_loss,
num_items_in_batch=num_items_in_batch, teacher_mask=target_mask_for_loss,
kd_temperature=self.args.kd_temperature, kd_temperature=self.args.kd_temperature,
) zscore_base_temp=self.args.kd_zscore_base_temp,
num_items_in_batch=num_items_in_batch,
)
else:
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
)
if self.args.kd_ce_alpha > 0: if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha kd_alpha = self.args.kd_alpha