Support additional plugin collator kwargs; don't scale up the KD loss by T^2 (temperature squared)

This commit is contained in:
Wing Lian
2025-05-23 08:35:44 -04:00
parent 7263845207
commit 49e2fa825d
7 changed files with 62 additions and 29 deletions

View File

@@ -0,0 +1,31 @@
{
"compile": {
"disable": false,
"backend": "inductor"
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -432,10 +432,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.plugins: if self.cfg.plugins:
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval) collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
self.cfg, is_eval=is_eval
)
if collator_cls: if collator_cls_and_kwargs:
collator = collator_cls collator = collator_cls_and_kwargs[0]
if kwargs and isinstance(kwargs, dict):
kwargs.update(collator_cls_and_kwargs[1])
elif self.cfg.reward_model: elif self.cfg.reward_model:
collator = RewardDataCollatorWithPadding collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator: elif use_batch_sampler_collator:

View File

@@ -208,7 +208,9 @@ class BasePlugin:
class: The class for the trainer. class: The class for the trainer.
""" """
def get_collator_cls(self, cfg, is_eval=False): # pylint: disable=unused-argument): def get_collator_cls_and_kwargs(
self, cfg, is_eval=False
): # pylint: disable=unused-argument):
""" """
Returns a custom class for the collator. Returns a custom class for the collator.
@@ -517,9 +519,9 @@ class PluginManager:
return trainer_cls return trainer_cls
return None return None
def get_collator_cls(self, cfg, is_eval=False): def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
""" """
Calls the get_collator_cls method of all registered plugins and returns the first non-None collator class. Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class together with its collator kwargs.
Parameters: Parameters:
cfg (dict): The configuration for the plugins. cfg (dict): The configuration for the plugins.
@@ -529,9 +531,11 @@ class PluginManager:
object: The collator class, or None if none was found. object: The collator class, or None if none was found.
""" """
for plugin in self.plugins.values(): for plugin in self.plugins.values():
collator_cls = plugin.get_collator_cls(cfg, is_eval=is_eval) collator_cls, collator_kwargs = plugin.get_collator_cls_and_kwargs(
cfg, is_eval=is_eval
)
if collator_cls is not None: if collator_cls is not None:
return collator_cls return collator_cls, collator_kwargs
return None return None
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):

View File

@@ -35,9 +35,9 @@ class KDPlugin(BasePlugin):
return AxolotlKDTrainer return AxolotlKDTrainer
return None return None
def get_collator_cls(self, cfg, is_eval=False): def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
if not cfg.kd_trainer: if not cfg.kd_trainer:
return None return None, None
from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
@@ -48,8 +48,8 @@ class KDPlugin(BasePlugin):
use_batch_sampler_collator = True use_batch_sampler_collator = True
if use_batch_sampler_collator: if use_batch_sampler_collator:
return KDBatchSamplerDataCollatorForSeq2Seq return KDBatchSamplerDataCollatorForSeq2Seq, {}
return DataCollatorForKD return DataCollatorForKD, {}
def pre_model_load(self, cfg): def pre_model_load(self, cfg):
from .kernels.models import apply_kernel from .kernels.models import apply_kernel

View File

@@ -15,8 +15,6 @@
""" """
Plugin args for KD support. Plugin args for KD support.
""" """
from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
@@ -25,9 +23,13 @@ class KDArgs(BaseModel):
Input args for knowledge distillation. Input args for knowledge distillation.
""" """
kd_trainer: Optional[bool] = None # whether to use KD trainer kd_trainer: float | None = None # whether to use KD trainer
kd_ce_alpha: Optional[float] = ( kd_ce_alpha: float | None = (
None # loss coefficient for cross-entropy loss during KD None # loss coefficient for cross-entropy loss during KD
) )
kd_alpha: Optional[float] = None # loss coefficient for KD loss kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD kd_temperature: float | None = None # temperature for sampling during KD
# TODO online kd
# kd_online_server_base_url: str | None = None
# kd_online_topk: int | None = None

View File

@@ -20,7 +20,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k] target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
target_mask_chunk: torch.Tensor, # [chunk_size, top_k] target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
temperature: float,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute Top-K KL divergence loss for a chunk. Compute Top-K KL divergence loss for a chunk.
@@ -29,15 +28,15 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K). target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K). target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K). target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
temperature: Temperature used for scaling.
Returns: Returns:
Sum of KL divergence losses for the chunk. Sum of KL divergence losses for the chunk.
""" """
student_logits_temp_scaled = student_logits_temp_scaled.float() student_logits_temp_scaled = ( # [chunk_size, vocab_size]
student_logits_temp_scaled.float()
)
target_logprobs_chunk = target_logprobs_chunk.float() target_logprobs_chunk = target_logprobs_chunk.float()
# Gather student logits for the top-k teacher token IDs # Gather student logits for the top-k teacher token IDs
# student_logits_temp_scaled: [chunk_size, vocab_size]
# target_token_ids_chunk: [chunk_size, top_k] # target_token_ids_chunk: [chunk_size, top_k]
# student_logits_topk_temp_scaled: [chunk_size, top_k] # student_logits_topk_temp_scaled: [chunk_size, top_k]
student_logits_topk_temp_scaled = torch.gather( student_logits_topk_temp_scaled = torch.gather(
@@ -72,9 +71,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
) )
kd_loss = kd_loss_per_token.sum() kd_loss = kd_loss_per_token.sum()
if temperature != 1.0:
kd_loss = kd_loss * (temperature**2)
return kd_loss return kd_loss
@staticmethod @staticmethod

View File

@@ -86,10 +86,6 @@ def loss(
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk) kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum() kd_loss = kd_loss_per_token.sum()
# Multiply by T^2 (classical KD scaling)
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items (if provided) or by valid tokens # Normalize by number of items (if provided) or by valid tokens
if num_items_in_batch > 0: if num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch) kd_loss = kd_loss / float(num_items_in_batch)