better rescaling for temperatures

This commit is contained in:
Wing Lian
2024-12-24 09:26:27 -05:00
parent d8d817eaed
commit 7366efc4ca
5 changed files with 62 additions and 17 deletions

View File

@@ -680,6 +680,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_ce_alpha is not None: if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
training_args_cls = ( training_args_cls = (
AxolotlTrainingArguments AxolotlTrainingArguments

View File

@@ -16,6 +16,7 @@ def kd_loss_function(
target_logprobs, target_logprobs,
target_mask, target_mask,
num_items_in_batch: Optional[int] = None, num_items_in_batch: Optional[int] = None,
kd_temperature: float = 1.0,
): ):
# teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding # teacher_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding
@@ -28,9 +29,15 @@ def kd_loss_function(
] # [B, teacher_seq_len, vocab_size] ] # [B, teacher_seq_len, vocab_size]
# Gather student logits for teacher's top-K tokens # Gather student logits for teacher's top-K tokens
# shape -> [B, teacher_seq_len, K]
student_logits_topk = torch.gather( student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K] )
# Apply KD temperature to students logits:
# z_s(T) = z_s / T
if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature
# Convert student top-k logits to logprobs # Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp( student_logprobs_topk = student_logits_topk - torch.logsumexp(
@@ -53,6 +60,10 @@ def kd_loss_function(
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk) kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum() kd_loss = kd_loss_per_token.sum()
# 9) Multiply by T^2 (classical KD scaling)
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items or mean over valid tokens # Normalize by number of items or mean over valid tokens
if num_items_in_batch is not None: if num_items_in_batch is not None:
# If you know how many items should be considered in the batch # If you know how many items should be considered in the batch
@@ -129,7 +140,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
# optionally combine with CE loss # optionally combine with CE loss
if self.args.kd_ce_alpha > 0: if self.args.kd_ce_alpha > 0:
loss = self.args.kd_ce_alpha * outputs["loss"] + loss_kd kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else: else:
loss = loss_kd loss = loss_kd
@@ -166,6 +178,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
target_logprobs, target_logprobs,
target_mask, target_mask,
num_items_in_batch=num_items_in_batch, num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
) )
if self.args.kd_ce_alpha > 0: if self.args.kd_ce_alpha > 0:

View File

@@ -172,7 +172,19 @@ class AxolotlTrainingMixins:
kd_ce_alpha: Optional[float] = field( kd_ce_alpha: Optional[float] = field(
default=None, default=None,
metadata={ metadata={
"help": "The alpha parameter for SFT cross entropy loss when using KD" "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
},
)
kd_alpha: Optional[float] = field(
default=1.0,
metadata={"help": "The alpha scaling parameter for KD loss"},
)
kd_temperature: Optional[float] = field(
default=1.0,
metadata={
"help": "the temperature parameter for KL divergence loss when using KD"
}, },
) )

View File

@@ -474,10 +474,12 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
roles_to_train=None, roles_to_train=None,
train_on_eos=None, train_on_eos=None,
logprobs_field="logprobs", logprobs_field="logprobs",
temperature=1.0, gen_temperature=1.0,
kd_temperature=1.0,
): ):
self.logprobs_field = logprobs_field self.logprobs_field = logprobs_field
self.temperature = temperature self.gen_temperature = gen_temperature
self.kd_temperature = kd_temperature
super().__init__( super().__init__(
prompter, prompter,
@@ -531,14 +533,25 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
position_logprobs, dtype=torch.float position_logprobs, dtype=torch.float
) )
# Apply temperature scaling at data load time if self.kd_temperature != self.gen_temperature:
# log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T) #
position_logprobs_tensor = position_logprobs_tensor / self.temperature # Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# normalize to probabilities so they sum up to 1 # Next, re-scale to T2 = self.kd_temperature via exponent-based trick
position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp( # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
position_logprobs_tensor, dim=0, keepdim=True #
) # Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist() position_logprobs_scaled = position_logprobs_tensor.tolist()
target_logprobs.append(position_logprobs_scaled) target_logprobs.append(position_logprobs_scaled)
@@ -593,13 +606,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
} }
strategy_cls = ChatTemplateStrategy strategy_cls = ChatTemplateStrategy
if logprobs_field := ds_cfg.get("logprobs_field"):
strategy_params["logprobs_field"] = logprobs_field
if temperature := ds_cfg.get("temperature"):
strategy_params["temperature"] = temperature
if cfg.trainer == "kd" or logprobs_field: if cfg.trainer == "kd":
strategy_cls = ChatTemplateStrategyWithKD strategy_cls = ChatTemplateStrategyWithKD
if logprobs_field := ds_cfg.get("logprobs_field"):
strategy_params["logprobs_field"] = logprobs_field
if gen_temperature := ds_cfg.get("temperature"):
strategy_params["gen_temperature"] = gen_temperature
if kd_temperature := cfg.get("kd_temperature"):
strategy_params["kd_temperature"] = kd_temperature
strategy = strategy_cls( strategy = strategy_cls(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params

View File

@@ -178,6 +178,7 @@ class SFTDataset(BaseModel):
message_field_training: Optional[str] = None message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None message_field_training_detail: Optional[str] = None
logprobs_field: Optional[str] = None logprobs_field: Optional[str] = None
temperature: Optional[float] = None
roles_to_train: Optional[List[str]] = None roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None roles: Optional[Dict[str, List[str]]] = None
@@ -626,6 +627,8 @@ class AxolotlInputConfig(
kd_ce_alpha: Optional[ kd_ce_alpha: Optional[
float float
] = None # loss coefficient for cross-entropy loss during KD ] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore