additional plugin collator kwargs, don't scale up kd loss by t^2

Wing Lian
2025-05-23 08:35:44 -04:00
parent 7263845207
commit 49e2fa825d
7 changed files with 62 additions and 29 deletions

View File

@@ -0,0 +1,31 @@
+{
+  "compile": {
+    "disable": false,
+    "backend": "inductor"
+  },
+  "zero_optimization": {
+    "stage": 2,
+    "offload_optimizer": {
+      "device": "cpu"
+    },
+    "contiguous_gradients": true,
+    "overlap_comm": true
+  },
+  "bf16": {
+    "enabled": "auto"
+  },
+  "fp16": {
+    "enabled": "auto",
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 32,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
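This new file is a ZeRO-2 DeepSpeed config with CPU optimizer offload, overlapped communication, and torch compile enabled through the inductor backend. The file path is not shown in this view; assuming it lands under the repo's deepspeed_configs/ directory, it would be selected from an axolotl YAML config by pointing the top-level deepspeed key at the JSON path.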

View File

@@ -432,10 +432,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.plugins:
             plugin_manager = PluginManager.get_instance()
-            collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval)
+            collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
+                self.cfg, is_eval=is_eval
+            )
-            if collator_cls:
-                collator = collator_cls
+            if collator_cls_and_kwargs:
+                collator = collator_cls_and_kwargs[0]
+                if kwargs and isinstance(kwargs, dict):
+                    kwargs.update(collator_cls_and_kwargs[1])
         elif self.cfg.reward_model:
             collator = RewardDataCollatorWithPadding
         elif use_batch_sampler_collator:
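Downstream of this hunk the builder consumes a (class, kwargs) pair instead of a bare collator class. A minimal sketch of the intended flow; the tuple contract and the update() merge mirror the hunk above, while tokenizer and the base kwargs dict are assumptions for illustration:

    # Sketch: merging plugin-supplied collator kwargs into construction.
    cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(cfg, is_eval=False)
    if cls_and_kwargs is not None:
        collator_cls, collator_kwargs = cls_and_kwargs
        kwargs = {"padding": True}      # base kwargs chosen by the builder (assumed)
        kwargs.update(collator_kwargs)  # plugin kwargs win on key collisions
        data_collator = collator_cls(tokenizer, **kwargs)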

View File

@@ -208,7 +208,9 @@ class BasePlugin:
             class: The class for the trainer.
         """
 
-    def get_collator_cls(self, cfg, is_eval=False):  # pylint: disable=unused-argument):
+    def get_collator_cls_and_kwargs(
+        self, cfg, is_eval=False
+    ):  # pylint: disable=unused-argument):
         """
         Returns a custom class for the collator.
@@ -517,9 +519,9 @@ class PluginManager:
                 return trainer_cls
         return None
 
-    def get_collator_cls(self, cfg, is_eval=False):
+    def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
         """
-        Calls the get_collator_cls method of all registered plugins and returns the first non-None collator class.
+        Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
 
         Parameters:
             cfg (dict): The configuration for the plugins.
@@ -529,9 +531,11 @@ class PluginManager:
             object: The collator class, or None if none was found.
         """
         for plugin in self.plugins.values():
-            collator_cls = plugin.get_collator_cls(cfg, is_eval=is_eval)
+            collator_cls, collator_kwargs = plugin.get_collator_cls_and_kwargs(
+                cfg, is_eval=is_eval
+            )
             if collator_cls is not None:
-                return collator_cls
+                return collator_cls, collator_kwargs
         return None
 
     def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
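For plugin authors, the hook now returns a pair and should return (None, None) when it opts out, as KDPlugin does below. A sketch of a conforming plugin; the plugin name, config flag, and collator import are invented, while the base class, signature, and tuple shape come from the hunks above:

    from axolotl.integrations.base import BasePlugin

    class MyCollatorPlugin(BasePlugin):
        """Hypothetical plugin illustrating the tuple-returning hook."""

        def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
            if not cfg.my_collator:  # made-up config gate
                return None, None
            from my_pkg.collators import MyCollator  # hypothetical collator

            # The dict is merged into the trainer builder's collator kwargs.
            return MyCollator, {"max_length": cfg.sequence_len}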

View File

@@ -35,9 +35,9 @@ class KDPlugin(BasePlugin):
             return AxolotlKDTrainer
         return None
 
-    def get_collator_cls(self, cfg, is_eval=False):
+    def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
         if not cfg.kd_trainer:
-            return None
+            return None, None
 
         from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
@@ -48,8 +48,8 @@ class KDPlugin(BasePlugin):
             use_batch_sampler_collator = True
 
         if use_batch_sampler_collator:
-            return KDBatchSamplerDataCollatorForSeq2Seq
-        return DataCollatorForKD
+            return KDBatchSamplerDataCollatorForSeq2Seq, {}
+        return DataCollatorForKD, {}
 
     def pre_model_load(self, cfg):
         from .kernels.models import apply_kernel

View File

@@ -15,8 +15,6 @@
"""
Plugin args for KD support.
"""
from typing import Optional
from pydantic import BaseModel
@@ -25,9 +23,13 @@ class KDArgs(BaseModel):
     Input args for knowledge distillation.
     """
 
-    kd_trainer: Optional[bool] = None  # whether to use KD trainer
-    kd_ce_alpha: Optional[float] = (
+    kd_trainer: bool | None = None  # whether to use KD trainer
+    kd_ce_alpha: float | None = (
         None  # loss coefficient for cross-entropy loss during KD
     )
-    kd_alpha: Optional[float] = None  # loss coefficient for KD loss
-    kd_temperature: Optional[float] = None  # temperature for sampling during KD
+    kd_alpha: float | None = None  # loss coefficient for KD loss
+    kd_temperature: float | None = None  # temperature for sampling during KD
+
+    # TODO online kd
+    # kd_online_server_base_url: str | None = None
+    # kd_online_topk: int | None = None
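Dropping typing.Optional for PEP 604 unions means pydantic must evaluate X | None at class-definition time, which requires Python 3.10 or newer. A standalone sanity check of the trimmed model; field values are arbitrary:

    from pydantic import BaseModel

    class KDArgs(BaseModel):
        """Trimmed copy of the model above."""

        kd_trainer: bool | None = None
        kd_ce_alpha: float | None = None
        kd_alpha: float | None = None
        kd_temperature: float | None = None

    args = KDArgs(kd_trainer=True, kd_ce_alpha=0.2, kd_alpha=0.8, kd_temperature=2.0)
    print(args.model_dump())
    # {'kd_trainer': True, 'kd_ce_alpha': 0.2, 'kd_alpha': 0.8, 'kd_temperature': 2.0}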

View File

@@ -20,7 +20,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         target_token_ids_chunk: torch.Tensor,  # [chunk_size, top_k]
         target_logprobs_chunk: torch.Tensor,  # [chunk_size, top_k], already temp-scaled and normalized logprobs
         target_mask_chunk: torch.Tensor,  # [chunk_size, top_k]
-        temperature: float,
     ) -> torch.Tensor:
         """
         Compute Top-K KL divergence loss for a chunk.
@@ -29,15 +28,15 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
             target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
             target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
-            temperature: Temperature used for scaling.
 
         Returns:
             Sum of KL divergence losses for the chunk.
         """
-        student_logits_temp_scaled = student_logits_temp_scaled.float()
+        student_logits_temp_scaled = (  # [chunk_size, vocab_size]
+            student_logits_temp_scaled.float()
+        )
         target_logprobs_chunk = target_logprobs_chunk.float()
 
         # Gather student logits for the top-k teacher token IDs
-        # student_logits_temp_scaled: [chunk_size, vocab_size]
         # target_token_ids_chunk: [chunk_size, top_k]
         # student_logits_topk_temp_scaled: [chunk_size, top_k]
         student_logits_topk_temp_scaled = torch.gather(
@@ -72,9 +71,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         )
         kd_loss = kd_loss_per_token.sum()
 
-        if temperature != 1.0:
-            kd_loss = kd_loss * (temperature**2)
-
         return kd_loss
 
     @staticmethod
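Context for the deletion: classical distillation (Hinton et al., 2015) multiplies the soft-target loss by T^2 so its gradients stay on the same scale as the hard-label loss when logits are divided by T. With the rescaling removed, the raw KL sum is returned, so for kd_temperature > 1 the KD term shrinks by a factor of T^2 and kd_alpha may need retuning. A self-contained sketch of the numeric difference; shapes and values are arbitrary:

    import torch

    T = 2.0
    teacher_logprobs = torch.log_softmax(torch.randn(4, 8) / T, dim=-1)
    student_logprobs = torch.log_softmax(torch.randn(4, 8) / T, dim=-1)

    # Summed forward KL, matching the shape of the loss in the next file
    kd = (teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)).sum()
    print(kd)         # value returned after this commit
    print(kd * T**2)  # the pre-commit behavior this change removes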

View File

@@ -86,10 +86,6 @@ def loss(
     kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
     kd_loss = kd_loss_per_token.sum()
 
-    # Multiply by T^2 (classical KD scaling)
-    if kd_temperature != 1.0:
-        kd_loss = kd_loss * (kd_temperature**2)
-
     # Normalize by number of items (if provided) or by valid tokens
     if num_items_in_batch > 0:
         kd_loss = kd_loss / float(num_items_in_batch)
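The normalization below the deleted block is unchanged. Only the num_items_in_batch branch appears in the hunk; a sketch of the overall step, with the valid-token fallback paraphrased from the comment and its variable name assumed:

    def normalize_kd_loss(kd_loss, num_items_in_batch, num_valid_tokens):
        # Prefer the trainer-supplied item count; otherwise average over
        # the valid (unmasked) tokens, guarding against division by zero.
        if num_items_in_batch > 0:
            return kd_loss / float(num_items_in_batch)
        return kd_loss / max(float(num_valid_tokens), 1.0)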