drop top_k before softmax
This commit is contained in:
@@ -324,10 +324,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["kd_zscore_base_temp"] = (
|
training_arguments_kwargs["kd_zscore_base_temp"] = (
|
||||||
self.cfg.kd_zscore_base_temp
|
self.cfg.kd_zscore_base_temp
|
||||||
)
|
)
|
||||||
if self.cfg.kd_top_k_before_softmax is not None:
|
|
||||||
training_arguments_kwargs["kd_top_k_before_softmax"] = (
|
|
||||||
self.cfg.kd_top_k_before_softmax
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
|
|||||||
@@ -201,13 +201,6 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
kd_top_k_before_softmax: Optional[bool] = field(
|
|
||||||
default=False,
|
|
||||||
metadata={
|
|
||||||
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
adam_beta3: Optional[float] = field(
|
adam_beta3: Optional[float] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
|
|||||||
@@ -32,6 +32,3 @@ class KDArgs(BaseModel):
|
|||||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
||||||
kd_top_k_before_softmax: Optional[bool] = (
|
|
||||||
None # whether to sample top k before softmax during KD
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -47,6 +47,10 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
|||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
if return_tensors is None:
|
if return_tensors is None:
|
||||||
return_tensors = self.return_tensors
|
return_tensors = self.return_tensors
|
||||||
|
|||||||
@@ -61,7 +61,6 @@ def loss(
|
|||||||
target_mask: torch.Tensor,
|
target_mask: torch.Tensor,
|
||||||
num_items_in_batch: int = -1, # Use -1 to indicate "None"
|
num_items_in_batch: int = -1, # Use -1 to indicate "None"
|
||||||
kd_temperature: float = 1.0,
|
kd_temperature: float = 1.0,
|
||||||
top_k_before_softmax: int = 0,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
A KD loss function that is TorchScript-friendly.
|
A KD loss function that is TorchScript-friendly.
|
||||||
@@ -78,8 +77,6 @@ def loss(
|
|||||||
num_items_in_batch (int, optional): The number of items in the batch.
|
num_items_in_batch (int, optional): The number of items in the batch.
|
||||||
kd_temperature (float, optional): The temperature for KD.
|
kd_temperature (float, optional): The temperature for KD.
|
||||||
Default: 1.0
|
Default: 1.0
|
||||||
top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits
|
|
||||||
Default: 0
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
target_logprobs = target_logprobs.float()
|
target_logprobs = target_logprobs.float()
|
||||||
@@ -89,46 +86,24 @@ def loss(
|
|||||||
# student_logits shape: [B, student_seq_len, vocab_size]
|
# student_logits shape: [B, student_seq_len, vocab_size]
|
||||||
teacher_seq_len = target_token_ids.shape[1]
|
teacher_seq_len = target_token_ids.shape[1]
|
||||||
|
|
||||||
if top_k_before_softmax:
|
# Slice student logits to match teacher-provided sequence length
|
||||||
# Slice student logits to match teacher-provided sequence length
|
student_logits_for_kd = (
|
||||||
student_logits_for_kd = student_logits[
|
student_logits[:, :teacher_seq_len, :] / kd_temperature
|
||||||
:, :teacher_seq_len, :
|
) # [B, teacher_seq_len, vocab_size]
|
||||||
] # [B, teacher_seq_len, vocab_size]
|
|
||||||
|
|
||||||
# Gather student logits for teacher's top-K tokens
|
# keep in full precision for numerical stability of loss
|
||||||
student_logits_topk = torch.gather(
|
student_logits_for_kd = student_logits_for_kd.float()
|
||||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
|
||||||
) # [B, teacher_seq_len, K]
|
|
||||||
|
|
||||||
student_logits_topk = student_logits_topk.float()
|
# Gather student logits for teacher's top-K tokens
|
||||||
|
student_logits_topk = torch.gather(
|
||||||
|
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||||
|
) # [B, teacher_seq_len, K]
|
||||||
|
|
||||||
# Apply KD temperature to student’s logits
|
# Compute logsumexp across full vocabulary
|
||||||
if kd_temperature != 1.0:
|
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
|
||||||
student_logits_topk = student_logits_topk / kd_temperature
|
|
||||||
|
|
||||||
# Convert student top-k logits to logprobs
|
# Convert just the top-k logits to logprobs
|
||||||
student_logprobs_topk = student_logits_topk - torch.logsumexp(
|
student_logprobs_topk = student_logits_topk - student_lse
|
||||||
student_logits_topk, dim=-1, keepdim=True
|
|
||||||
) # [B, teacher_seq_len, K]
|
|
||||||
else:
|
|
||||||
# Slice student logits to match teacher-provided sequence length
|
|
||||||
student_logits_for_kd = (
|
|
||||||
student_logits[:, :teacher_seq_len, :] / kd_temperature
|
|
||||||
) # [B, teacher_seq_len, vocab_size]
|
|
||||||
|
|
||||||
# keep in full precision for numerical stability of loss
|
|
||||||
student_logits_for_kd = student_logits_for_kd.float()
|
|
||||||
|
|
||||||
# Gather student logits for teacher's top-K tokens
|
|
||||||
student_logits_topk = torch.gather(
|
|
||||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
|
||||||
) # [B, teacher_seq_len, K]
|
|
||||||
|
|
||||||
# Compute logsumexp across full vocabulary
|
|
||||||
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
# Convert just the top-k logits to logprobs
|
|
||||||
student_logprobs_topk = student_logits_topk - student_lse
|
|
||||||
|
|
||||||
# Convert teacher_mask to boolean for indexing
|
# Convert teacher_mask to boolean for indexing
|
||||||
# In TorchScript, .bool() is sometimes unsupported, so we do:
|
# In TorchScript, .bool() is sometimes unsupported, so we do:
|
||||||
@@ -243,7 +218,7 @@ class ChunkedTopKKDLoss(nn.Module):
|
|||||||
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
|
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
|
||||||
to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
|
to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
|
||||||
|
|
||||||
Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to your top-K teacher logprobs.
|
Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
|
def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
|
||||||
@@ -258,7 +233,6 @@ class ChunkedTopKKDLoss(nn.Module):
|
|||||||
target_logprobs: torch.Tensor, # [B, seq_len, K]
|
target_logprobs: torch.Tensor, # [B, seq_len, K]
|
||||||
target_mask: torch.Tensor, # [B, seq_len, K]
|
target_mask: torch.Tensor, # [B, seq_len, K]
|
||||||
num_items_in_batch: int = -1, # optional batch size for normalization
|
num_items_in_batch: int = -1, # optional batch size for normalization
|
||||||
top_k_before_softmax: int = 0, # optional top-k before softmax for teacher logits
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
# 1. Split along the "token" dimension (dim=1).
|
# 1. Split along the "token" dimension (dim=1).
|
||||||
@@ -285,7 +259,6 @@ class ChunkedTopKKDLoss(nn.Module):
|
|||||||
target_mask=msk_chunk,
|
target_mask=msk_chunk,
|
||||||
num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens
|
num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens
|
||||||
kd_temperature=self.kd_temperature,
|
kd_temperature=self.kd_temperature,
|
||||||
top_k_before_softmax=top_k_before_softmax,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# kd_loss returns an average over the chunk's valid tokens.
|
# kd_loss returns an average over the chunk's valid tokens.
|
||||||
|
|||||||
@@ -97,7 +97,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
target_logprobs_for_loss,
|
target_logprobs_for_loss,
|
||||||
target_mask_for_loss,
|
target_mask_for_loss,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
# top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.kd_ce_alpha > 0:
|
if self.args.kd_ce_alpha > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user