drop top_k before softmax

Wing Lian
2025-05-20 12:43:23 -07:00
parent a2248673d8
commit 22b50d6619
6 changed files with 19 additions and 57 deletions

View File

@@ -324,10 +324,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 training_arguments_kwargs["kd_zscore_base_temp"] = (
                     self.cfg.kd_zscore_base_temp
                 )
-            if self.cfg.kd_top_k_before_softmax is not None:
-                training_arguments_kwargs["kd_top_k_before_softmax"] = (
-                    self.cfg.kd_top_k_before_softmax
-                )
             if self.cfg.reward_model:
                 training_args_cls = AxolotlRewardConfig

View File

@@ -201,13 +201,6 @@ class AxolotlTrainingMixins:
         },
     )
-    kd_top_k_before_softmax: Optional[bool] = field(
-        default=False,
-        metadata={
-            "help": "Whether to apply top_k_before_softmax to the logits when using KD"
-        },
-    )
     adam_beta3: Optional[float] = field(
         default=None,
         metadata={

View File

@@ -32,6 +32,3 @@ class KDArgs(BaseModel):
     kd_alpha: Optional[float] = None  # loss coefficient for KD loss
     kd_temperature: Optional[float] = None  # temperature for sampling during KD
     kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
-    kd_top_k_before_softmax: Optional[bool] = (
-        None  # whether to sample top k before softmax during KD
-    )
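
Note for readers: the fields that survive here are the ones the builder still forwards into the trainer kwargs, one if-block at a time, as in the first hunk above. A rough illustration of that plumbing pattern, using a stand-in namespace rather than axolotl's real validated config object:

    from types import SimpleNamespace

    # Stand-in for self.cfg; the real object is axolotl's parsed config.
    cfg = SimpleNamespace(kd_alpha=0.5, kd_temperature=2.0, kd_zscore_base_temp=None)

    training_arguments_kwargs = {}
    for key in ("kd_alpha", "kd_temperature", "kd_zscore_base_temp"):
        value = getattr(cfg, key)
        if value is not None:  # only forward options the user actually set
            training_arguments_kwargs[key] = value

    print(training_arguments_kwargs)  # {'kd_alpha': 0.5, 'kd_temperature': 2.0}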

View File

@@ -47,6 +47,10 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
     position_pad_token_id: int = 0
     return_tensors: str = "pt"

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
     def __call__(self, features, return_tensors=None):
         if return_tensors is None:
             return_tensors = self.return_tensors
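
The added __init__ is a warning-suppression trick: PreTrainedTokenizerBase keeps a deprecation_warnings dict and only logs the slow-pad notice for fast tokenizers when the "Asking-to-pad-a-fast-tokenizer" key is unset, flipping it to True after the first log, so presetting the key silences the message for every subsequent pad() call. A standalone sketch of the same trick (the model name is just an example):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

    # Pretend the warning was already emitted so pad() stays quiet.
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

    batch = tokenizer.pad({"input_ids": [[1, 2, 3], [4, 5]]}, return_tensors="pt")
    print(batch["input_ids"].shape)  # torch.Size([2, 3])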

View File

@@ -61,7 +61,6 @@ def loss(
     target_mask: torch.Tensor,
     num_items_in_batch: int = -1,  # Use -1 to indicate "None"
     kd_temperature: float = 1.0,
-    top_k_before_softmax: int = 0,
 ) -> torch.Tensor:
     """
     A KD loss function that is TorchScript-friendly.
@@ -78,8 +77,6 @@ def loss(
         num_items_in_batch (int, optional): The number of items in the batch.
         kd_temperature (float, optional): The temperature for KD.
             Default: 1.0
-        top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits
-            Default: 0
     """
     target_logprobs = target_logprobs.float()
@@ -89,46 +89,24 @@ def loss(
     # student_logits shape: [B, student_seq_len, vocab_size]
     teacher_seq_len = target_token_ids.shape[1]

-    if top_k_before_softmax:
-        # Slice student logits to match teacher-provided sequence length
-        student_logits_for_kd = student_logits[
-            :, :teacher_seq_len, :
-        ]  # [B, teacher_seq_len, vocab_size]
-
-        # Gather student logits for teacher's top-K tokens
-        student_logits_topk = torch.gather(
-            student_logits_for_kd, dim=-1, index=target_token_ids
-        )  # [B, teacher_seq_len, K]
-
-        # keep in full precision for numerical stability of loss
-        student_logits_topk = student_logits_topk.float()
-
-        # Apply KD temperature to students logits
-        if kd_temperature != 1.0:
-            student_logits_topk = student_logits_topk / kd_temperature
-
-        # Convert student top-k logits to logprobs
-        student_logprobs_topk = student_logits_topk - torch.logsumexp(
-            student_logits_topk, dim=-1, keepdim=True
-        )  # [B, teacher_seq_len, K]
-    else:
-        # Slice student logits to match teacher-provided sequence length
-        student_logits_for_kd = (
-            student_logits[:, :teacher_seq_len, :] / kd_temperature
-        )  # [B, teacher_seq_len, vocab_size]
-
-        # keep in full precision for numerical stability of loss
-        student_logits_for_kd = student_logits_for_kd.float()
-
-        # Gather student logits for teacher's top-K tokens
-        student_logits_topk = torch.gather(
-            student_logits_for_kd, dim=-1, index=target_token_ids
-        )  # [B, teacher_seq_len, K]
-
-        # Compute logsumexp across full vocabulary
-        student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
-
-        # Convert just the top-k logits to logprobs
-        student_logprobs_topk = student_logits_topk - student_lse
+    # Slice student logits to match teacher-provided sequence length
+    student_logits_for_kd = (
+        student_logits[:, :teacher_seq_len, :] / kd_temperature
+    )  # [B, teacher_seq_len, vocab_size]
+
+    # keep in full precision for numerical stability of loss
+    student_logits_for_kd = student_logits_for_kd.float()
+
+    # Gather student logits for teacher's top-K tokens
+    student_logits_topk = torch.gather(
+        student_logits_for_kd, dim=-1, index=target_token_ids
+    )  # [B, teacher_seq_len, K]
+
+    # Compute logsumexp across full vocabulary
+    student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
+
+    # Convert just the top-k logits to logprobs
+    student_logprobs_topk = student_logits_topk - student_lse

     # Convert teacher_mask to boolean for indexing
     # In TorchScript, .bool() is sometimes unsupported, so we do:
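
With the flag gone, only one code path remains: temperature-scale the full logit row, gather the teacher's top-K columns, then normalize against a logsumexp over the whole vocabulary, so the gathered values are true log-probabilities rather than a softmax over just the K entries (the renormalization the deleted branch performed, which is presumably why it was dropped). A self-contained sketch of the surviving path with random stand-in tensors:

    import torch

    B, T_len, V, K = 2, 5, 11, 3  # batch, teacher seq len, vocab, top-k
    kd_temperature = 2.0

    student_logits = torch.randn(B, T_len, V, dtype=torch.bfloat16)
    target_token_ids = torch.randint(0, V, (B, T_len, K))  # teacher's top-k ids

    # Temperature-scale first, then upcast for numerical stability.
    logits = (student_logits[:, :T_len, :] / kd_temperature).float()
    topk = torch.gather(logits, dim=-1, index=target_token_ids)  # [B, T_len, K]
    lse = torch.logsumexp(logits, dim=-1, keepdim=True)          # [B, T_len, 1]
    student_logprobs_topk = topk - lse  # log-probs under the full-vocab softmax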
@@ -243,7 +218,7 @@ class ChunkedTopKKDLoss(nn.Module):
     A wrapper that chunks (splits) the student and teacher outputs along the time dimension
     to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
-    Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to your top-K teacher logprobs.
+    Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.
     """

     def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
@@ -258,7 +233,6 @@ class ChunkedTopKKDLoss(nn.Module):
         target_logprobs: torch.Tensor,  # [B, seq_len, K]
         target_mask: torch.Tensor,  # [B, seq_len, K]
         num_items_in_batch: int = -1,  # optional batch size for normalization
-        top_k_before_softmax: int = 0,  # optional top-k before softmax for teacher logits
     ) -> torch.Tensor:
         # 1. Split along the "token" dimension (dim=1).
@@ -285,7 +259,6 @@ class ChunkedTopKKDLoss(nn.Module):
                 target_mask=msk_chunk,
                 num_items_in_batch=-1,  # ensure per-chunk averaging by valid tokens
                 kd_temperature=self.kd_temperature,
-                top_k_before_softmax=top_k_before_softmax,
             )

         # kd_loss returns an average over the chunk's valid tokens.
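
For orientation, the wrapper's control flow reduces to splitting every tensor along dim=1 and averaging the per-chunk losses, so the fp32 upcast inside loss() only ever sees seq_len / num_output_chunks tokens at a time. A simplified sketch (loss_fn stands in for the module-level loss above; the real forward may combine chunks slightly differently):

    import torch

    def chunked_kd_loss(loss_fn, student_logits, target_token_ids,
                        target_logprobs, target_mask,
                        num_output_chunks: int = 8, kd_temperature: float = 1.0):
        per_chunk = [
            loss_fn(s, ids, lp, msk, num_items_in_batch=-1,
                    kd_temperature=kd_temperature)
            for s, ids, lp, msk in zip(
                torch.chunk(student_logits, num_output_chunks, dim=1),
                torch.chunk(target_token_ids, num_output_chunks, dim=1),
                torch.chunk(target_logprobs, num_output_chunks, dim=1),
                torch.chunk(target_mask, num_output_chunks, dim=1),
            )
        ]
        # each call already averages over its own chunk's valid tokens
        return torch.stack(per_chunk).mean()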

View File

@@ -97,7 +97,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
                 target_logprobs_for_loss,
                 target_mask_for_loss,
                 num_items_in_batch=num_items_in_batch,
-                # top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
             )

         if self.args.kd_ce_alpha > 0:
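
Context for the call site above: kd_ce_alpha gates an additional plain cross-entropy term on top of the KD objective. A hedged sketch of what such a blend typically looks like (the coefficient names come from this commit's args; the trainer's actual compute_loss may weight the terms differently):

    def blend_losses(kd_loss_value, ce_loss_value, kd_alpha, kd_ce_alpha):
        # Hypothetical weighting; shown only to make the alphas concrete.
        total = kd_alpha * kd_loss_value
        if kd_ce_alpha > 0:
            total = total + kd_ce_alpha * ce_loss_value
        return total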