simplify and remove zscore
This commit is contained in:
@@ -320,10 +320,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
||||||
if self.cfg.kd_temperature is not None:
|
if self.cfg.kd_temperature is not None:
|
||||||
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
||||||
if self.cfg.kd_zscore_base_temp is not None:
|
|
||||||
training_arguments_kwargs["kd_zscore_base_temp"] = (
|
|
||||||
self.cfg.kd_zscore_base_temp
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
|
|||||||
@@ -194,13 +194,6 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
kd_zscore_base_temp: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "the base temperature parameter for KL divergence with z-score when using KD"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
adam_beta3: Optional[float] = field(
|
adam_beta3: Optional[float] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
|
|||||||
@@ -31,4 +31,3 @@ class KDArgs(BaseModel):
|
|||||||
)
|
)
|
||||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
|
||||||
|
|||||||
@@ -19,40 +19,6 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
def zscore_standardize(
    logits: torch.Tensor,
    mask: torch.Tensor = None,
    base_temperature: float = 1.0,
    eps: float = 1e-9,
):
    """
    Z-score standardize `logits` along the last dimension.

    For every slice along the last dimension this computes
    ``z = (logits - mean) / std`` with the population (biased) variance and
    then divides the result by `base_temperature`.

    Args:
        logits: tensor to standardize; statistics are taken over ``dim=-1``.
        mask: optional tensor that broadcasts against `logits`. Non-zero
            entries mark the positions that contribute to the mean and the
            variance. Masked-out positions are still transformed with the
            statistics of their row, so callers have to discard them. When
            `None`, every element contributes.
        base_temperature: strictly positive divisor applied after the
            standardization.
        eps: lower bound applied to the variance before the square root, so
            a constant row does not divide by zero.

    Returns:
        The standardized tensor, shaped like the broadcast of `logits` and
        `mask` (i.e. like `logits` when `mask` is `None`).

    Raises:
        ValueError: if `base_temperature` is not strictly positive.
    """
    if base_temperature <= 0:
        raise ValueError(
            f"base_temperature must be > 0, got {base_temperature!r}"
        )

    if mask is None:
        mean = logits.mean(dim=-1, keepdim=True)
        var = logits.var(dim=-1, unbiased=False, keepdim=True)
    else:
        # Materialize the broadcast *before* counting: summing a mask whose
        # last dimension is 1 would otherwise report a single valid entry per
        # row and turn the "mean" into a plain sum.
        float_mask, logits = torch.broadcast_tensors(mask.to(logits.dtype), logits)
        # clamp_min keeps fully masked rows from dividing by zero
        count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
        mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
        var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count

    std = torch.sqrt(var.clamp_min(eps))
    z = (logits - mean) / std

    # Scale by 1 / base_temperature
    return z / base_temperature
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def loss(
|
def loss(
|
||||||
student_logits: torch.Tensor,
|
student_logits: torch.Tensor,
|
||||||
@@ -134,85 +100,6 @@ def loss(
|
|||||||
return kd_loss
|
return kd_loss
|
||||||
|
|
||||||
|
|
||||||
def topk_kd_loss_with_zscore(
    student_logits: torch.Tensor,  # [B, seq_len, vocab_size]
    target_token_ids: torch.Tensor,  # [B, seq_len, K]
    target_logprobs: torch.Tensor,  # [B, seq_len, K], sums to 1.0 in prob space
    target_mask: torch.Tensor,  # [B, seq_len, K] or [B, seq_len]
    kd_temperature: float = 1.0,  # classic KD temperature
    zscore_base_temp: float = 1.0,  # from the paper
    num_items_in_batch: int = -1,
):
    """
    A variant of top_k KL divergence with Z-score scaling
    from "Logit Standardization in Knowledge Distillation".

    The student's logits are gathered at the teacher's top-k token ids, both
    sides are z-score standardized over the K dimension, turned into
    log-probabilities over the valid entries, and compared with a forward KL
    divergence.

    Args:
        student_logits: student logits; only the first `teacher_seq_len`
            positions along dim 1 are used.
        target_token_ids: vocabulary indices of the teacher's top-k tokens.
        target_logprobs: teacher log-probabilities for those tokens.
        target_mask: non-zero where an entry takes part in the loss. A 2D
            mask matching ``[B, seq_len]`` is expanded over K.
        kd_temperature: classic KD temperature; the summed loss is multiplied
            by ``kd_temperature ** 2`` when it differs from 1.0.
        zscore_base_temp: base temperature forwarded to `zscore_standardize`.
        num_items_in_batch: when positive, the summed loss is divided by it;
            otherwise by the number of valid mask entries.

    Returns:
        Scalar tensor with the normalized KD loss (0 when nothing is valid
        and `num_items_in_batch` is not positive).
    """
    target_logprobs = target_logprobs.float()

    B, teacher_seq_len, K = target_logprobs.shape  # pylint: disable=invalid-name

    # 1) Gather the student's top-k logits to match teacher
    student_logits_for_kd = student_logits[:, :teacher_seq_len, :]  # [B, seq_len, vocab]
    student_topk_logits = torch.gather(
        student_logits_for_kd, dim=-1, index=target_token_ids
    ).float()  # [B, seq_len, K]

    # 2) Classic temperature scaling.
    # NOTE(review): z-scoring is invariant to a positive rescaling of its
    # input (up to the eps clamp), so this division has practically no effect
    # on student_z; kd_temperature effectively only enters through the T^2
    # factor in step 8. Confirm this is intended.
    if kd_temperature != 1.0:
        student_topk_logits = student_topk_logits / kd_temperature

    # 3) Treat teacher logprobs as "logits" for the z-score: they differ from
    #    real logits by a per-row constant, which the mean subtraction removes.
    teacher_logits_for_zscore = target_logprobs

    # 4) Z-score teacher and student
    # If target_mask is 2D, expand to 3D for the K dimension
    if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
        target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)

    teacher_z = zscore_standardize(
        teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
    )
    student_z = zscore_standardize(
        student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
    )

    # 5) Convert to log-probs for KL.
    # Masked-out entries are pushed to the most negative finite value first so
    # they contribute nothing to the logsumexp normalizer; the z-score
    # statistics already exclude them, and the distributions should too.
    # A finite fill (rather than -inf) keeps fully masked rows free of NaNs.
    valid_mask = target_mask.bool()  # shape [B, seq_len, K]
    teacher_z = teacher_z.masked_fill(~valid_mask, torch.finfo(teacher_z.dtype).min)
    student_z = student_z.masked_fill(~valid_mask, torch.finfo(student_z.dtype).min)

    teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
    student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)

    # 6) Restrict to valid entries; boolean indexing flattens to 1D
    teacher_logprobs_z = teacher_logprobs_z[valid_mask]
    student_logprobs_z = student_logprobs_z[valid_mask]
    teacher_probs_z = teacher_logprobs_z.exp()

    # 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
    kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
    kd_loss = kd_loss_per_token.sum()

    # 8) If using classical KD scaling by T^2
    if kd_temperature != 1.0:
        kd_loss = kd_loss * (kd_temperature**2)

    # 9) Normalize
    if num_items_in_batch is not None and num_items_in_batch > 0:
        kd_loss = kd_loss / float(num_items_in_batch)
    else:
        # max(..., 1) avoids 0 / 0 = NaN when the mask selects nothing
        kd_loss = kd_loss / float(max(kd_loss_per_token.size(0), 1))

    return kd_loss
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkedTopKKDLoss(nn.Module):
|
class ChunkedTopKKDLoss(nn.Module):
|
||||||
"""
|
"""
|
||||||
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
|
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ KD trainer
|
|||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
from .topk_logprob.forward_kl import ChunkedTopKKDLoss, topk_kd_loss_with_zscore
|
from .topk_logprob.forward_kl import ChunkedTopKKDLoss
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKDTrainer(AxolotlTrainer):
|
class AxolotlKDTrainer(AxolotlTrainer):
|
||||||
@@ -80,24 +80,13 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||||
|
|
||||||
if self.args.kd_zscore_base_temp:
|
loss_kd = self.kd_loss_fn(
|
||||||
loss_kd = topk_kd_loss_with_zscore(
|
shift_logits,
|
||||||
shift_logits,
|
target_token_ids_for_loss,
|
||||||
target_token_ids_for_loss,
|
target_logprobs_for_loss,
|
||||||
target_logprobs_for_loss,
|
target_mask_for_loss,
|
||||||
target_mask_for_loss,
|
num_items_in_batch=num_items_in_batch,
|
||||||
kd_temperature=self.args.kd_temperature,
|
)
|
||||||
zscore_base_temp=self.args.kd_zscore_base_temp,
|
|
||||||
num_items_in_batch=num_items_in_batch,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
loss_kd = self.kd_loss_fn(
|
|
||||||
shift_logits,
|
|
||||||
target_token_ids_for_loss,
|
|
||||||
target_logprobs_for_loss,
|
|
||||||
target_mask_for_loss,
|
|
||||||
num_items_in_batch=num_items_in_batch,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.kd_ce_alpha > 0:
|
if self.args.kd_ce_alpha > 0:
|
||||||
kd_alpha = self.args.kd_alpha
|
kd_alpha = self.args.kd_alpha
|
||||||
|
|||||||
Reference in New Issue
Block a user