drop top_k before softmax
This commit is contained in:
@@ -324,10 +324,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["kd_zscore_base_temp"] = (
|
||||
self.cfg.kd_zscore_base_temp
|
||||
)
|
||||
if self.cfg.kd_top_k_before_softmax is not None:
|
||||
training_arguments_kwargs["kd_top_k_before_softmax"] = (
|
||||
self.cfg.kd_top_k_before_softmax
|
||||
)
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
|
||||
@@ -201,13 +201,6 @@ class AxolotlTrainingMixins:
|
||||
},
|
||||
)
|
||||
|
||||
kd_top_k_before_softmax: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
adam_beta3: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
|
||||
@@ -32,6 +32,3 @@ class KDArgs(BaseModel):
|
||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
||||
kd_top_k_before_softmax: Optional[bool] = (
|
||||
None # whether to sample top k before softmax during KD
|
||||
)
|
||||
|
||||
@@ -47,6 +47,10 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
|
||||
@@ -61,7 +61,6 @@ def loss(
|
||||
target_mask: torch.Tensor,
|
||||
num_items_in_batch: int = -1, # Use -1 to indicate "None"
|
||||
kd_temperature: float = 1.0,
|
||||
top_k_before_softmax: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
A KD loss function that is TorchScript-friendly.
|
||||
@@ -78,8 +77,6 @@ def loss(
|
||||
num_items_in_batch (int, optional): The number of items in the batch.
|
||||
kd_temperature (float, optional): The temperature for KD.
|
||||
Default: 1.0
|
||||
top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits
|
||||
Default: 0
|
||||
"""
|
||||
|
||||
target_logprobs = target_logprobs.float()
|
||||
@@ -89,46 +86,24 @@ def loss(
|
||||
# student_logits shape: [B, student_seq_len, vocab_size]
|
||||
teacher_seq_len = target_token_ids.shape[1]
|
||||
|
||||
if top_k_before_softmax:
|
||||
# Slice student logits to match teacher-provided sequence length
|
||||
student_logits_for_kd = student_logits[
|
||||
:, :teacher_seq_len, :
|
||||
] # [B, teacher_seq_len, vocab_size]
|
||||
# Slice student logits to match teacher-provided sequence length
|
||||
student_logits_for_kd = (
|
||||
student_logits[:, :teacher_seq_len, :] / kd_temperature
|
||||
) # [B, teacher_seq_len, vocab_size]
|
||||
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
# keep in full precision for numerical stability of loss
|
||||
student_logits_for_kd = student_logits_for_kd.float()
|
||||
|
||||
student_logits_topk = student_logits_topk.float()
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
|
||||
# Apply KD temperature to student’s logits
|
||||
if kd_temperature != 1.0:
|
||||
student_logits_topk = student_logits_topk / kd_temperature
|
||||
# Compute logsumexp across full vocabulary
|
||||
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
|
||||
|
||||
# Convert student top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - torch.logsumexp(
|
||||
student_logits_topk, dim=-1, keepdim=True
|
||||
) # [B, teacher_seq_len, K]
|
||||
else:
|
||||
# Slice student logits to match teacher-provided sequence length
|
||||
student_logits_for_kd = (
|
||||
student_logits[:, :teacher_seq_len, :] / kd_temperature
|
||||
) # [B, teacher_seq_len, vocab_size]
|
||||
|
||||
# keep in full precision for numerical stability of loss
|
||||
student_logits_for_kd = student_logits_for_kd.float()
|
||||
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
|
||||
# Compute logsumexp across full vocabulary
|
||||
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
|
||||
|
||||
# Convert just the top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - student_lse
|
||||
# Convert just the top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - student_lse
|
||||
|
||||
# Convert teacher_mask to boolean for indexing
|
||||
# In TorchScript, .bool() is sometimes unsupported, so we do:
|
||||
@@ -243,7 +218,7 @@ class ChunkedTopKKDLoss(nn.Module):
|
||||
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
|
||||
to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
|
||||
|
||||
Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to your top-K teacher logprobs.
|
||||
Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.
|
||||
"""
|
||||
|
||||
def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
|
||||
@@ -258,7 +233,6 @@ class ChunkedTopKKDLoss(nn.Module):
|
||||
target_logprobs: torch.Tensor, # [B, seq_len, K]
|
||||
target_mask: torch.Tensor, # [B, seq_len, K]
|
||||
num_items_in_batch: int = -1, # optional batch size for normalization
|
||||
top_k_before_softmax: int = 0, # optional top-k before softmax for teacher logits
|
||||
) -> torch.Tensor:
|
||||
|
||||
# 1. Split along the "token" dimension (dim=1).
|
||||
@@ -285,7 +259,6 @@ class ChunkedTopKKDLoss(nn.Module):
|
||||
target_mask=msk_chunk,
|
||||
num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens
|
||||
kd_temperature=self.kd_temperature,
|
||||
top_k_before_softmax=top_k_before_softmax,
|
||||
)
|
||||
|
||||
# kd_loss returns an average over the chunk's valid tokens.
|
||||
|
||||
@@ -97,7 +97,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
# top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
||||
)
|
||||
|
||||
if self.args.kd_ce_alpha > 0:
|
||||
|
||||
Reference in New Issue
Block a user