additional plugin collator kwargs, don't scale up kd loss by t^2
This commit is contained in:
31
deepspeed_configs/zero2_torch_compile.json
Normal file
31
deepspeed_configs/zero2_torch_compile.json
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"compile": {
|
||||||
|
"disable": false,
|
||||||
|
"backend": "inductor"
|
||||||
|
},
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 2,
|
||||||
|
"offload_optimizer": {
|
||||||
|
"device": "cpu"
|
||||||
|
},
|
||||||
|
"contiguous_gradients": true,
|
||||||
|
"overlap_comm": true
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": "auto"
|
||||||
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 32,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"gradient_clipping": "auto",
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"wall_clock_breakdown": false
|
||||||
|
}
|
||||||
@@ -432,10 +432,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.plugins:
|
if self.cfg.plugins:
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval)
|
collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
|
||||||
|
self.cfg, is_eval=is_eval
|
||||||
|
)
|
||||||
|
|
||||||
if collator_cls:
|
if collator_cls_and_kwargs:
|
||||||
collator = collator_cls
|
collator = collator_cls_and_kwargs[0]
|
||||||
|
if kwargs and isinstance(kwargs, dict):
|
||||||
|
kwargs.update(collator_cls_and_kwargs[1])
|
||||||
elif self.cfg.reward_model:
|
elif self.cfg.reward_model:
|
||||||
collator = RewardDataCollatorWithPadding
|
collator = RewardDataCollatorWithPadding
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
|
|||||||
@@ -208,7 +208,9 @@ class BasePlugin:
|
|||||||
class: The class for the trainer.
|
class: The class for the trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_collator_cls(self, cfg, is_eval=False): # pylint: disable=unused-argument):
|
def get_collator_cls_and_kwargs(
|
||||||
|
self, cfg, is_eval=False
|
||||||
|
): # pylint: disable=unused-argument):
|
||||||
"""
|
"""
|
||||||
Returns a custom class for the collator.
|
Returns a custom class for the collator.
|
||||||
|
|
||||||
@@ -517,9 +519,9 @@ class PluginManager:
|
|||||||
return trainer_cls
|
return trainer_cls
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_collator_cls(self, cfg, is_eval=False):
|
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
||||||
"""
|
"""
|
||||||
Calls the get_collator_cls method of all registered plugins and returns the first non-None collator class.
|
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
cfg (dict): The configuration for the plugins.
|
cfg (dict): The configuration for the plugins.
|
||||||
@@ -529,9 +531,11 @@ class PluginManager:
|
|||||||
object: The collator class, or None if none was found.
|
object: The collator class, or None if none was found.
|
||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
collator_cls = plugin.get_collator_cls(cfg, is_eval=is_eval)
|
collator_cls, collator_kwargs = plugin.get_collator_cls_and_kwargs(
|
||||||
|
cfg, is_eval=is_eval
|
||||||
|
)
|
||||||
if collator_cls is not None:
|
if collator_cls is not None:
|
||||||
return collator_cls
|
return collator_cls, collator_kwargs
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
||||||
|
|||||||
@@ -35,9 +35,9 @@ class KDPlugin(BasePlugin):
|
|||||||
return AxolotlKDTrainer
|
return AxolotlKDTrainer
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_collator_cls(self, cfg, is_eval=False):
|
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
||||||
if not cfg.kd_trainer:
|
if not cfg.kd_trainer:
|
||||||
return None
|
return None, None
|
||||||
|
|
||||||
from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
|
from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
|
||||||
@@ -48,8 +48,8 @@ class KDPlugin(BasePlugin):
|
|||||||
use_batch_sampler_collator = True
|
use_batch_sampler_collator = True
|
||||||
|
|
||||||
if use_batch_sampler_collator:
|
if use_batch_sampler_collator:
|
||||||
return KDBatchSamplerDataCollatorForSeq2Seq
|
return KDBatchSamplerDataCollatorForSeq2Seq, {}
|
||||||
return DataCollatorForKD
|
return DataCollatorForKD, {}
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
from .kernels.models import apply_kernel
|
from .kernels.models import apply_kernel
|
||||||
|
|||||||
@@ -15,8 +15,6 @@
|
|||||||
"""
|
"""
|
||||||
Plugin args for KD support.
|
Plugin args for KD support.
|
||||||
"""
|
"""
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@@ -25,9 +23,13 @@ class KDArgs(BaseModel):
|
|||||||
Input args for knowledge distillation.
|
Input args for knowledge distillation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kd_trainer: Optional[bool] = None # whether to use KD trainer
|
kd_trainer: float | None = None # whether to use KD trainer
|
||||||
kd_ce_alpha: Optional[float] = (
|
kd_ce_alpha: float | None = (
|
||||||
None # loss coefficient for cross-entropy loss during KD
|
None # loss coefficient for cross-entropy loss during KD
|
||||||
)
|
)
|
||||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
kd_alpha: float | None = None # loss coefficient for KD loss
|
||||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
kd_temperature: float | None = None # temperature for sampling during KD
|
||||||
|
|
||||||
|
# TODO online kd
|
||||||
|
# kd_online_server_base_url: str | None = None
|
||||||
|
# kd_online_topk: int | None = None
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
|
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||||
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
|
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
|
||||||
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
|
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||||
temperature: float,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute Top-K KL divergence loss for a chunk.
|
Compute Top-K KL divergence loss for a chunk.
|
||||||
@@ -29,15 +28,15 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
|
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
|
||||||
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
|
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
|
||||||
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
|
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
|
||||||
temperature: Temperature used for scaling.
|
|
||||||
Returns:
|
Returns:
|
||||||
Sum of KL divergence losses for the chunk.
|
Sum of KL divergence losses for the chunk.
|
||||||
"""
|
"""
|
||||||
student_logits_temp_scaled = student_logits_temp_scaled.float()
|
student_logits_temp_scaled = ( # [chunk_size, vocab_size]
|
||||||
|
student_logits_temp_scaled.float()
|
||||||
|
)
|
||||||
target_logprobs_chunk = target_logprobs_chunk.float()
|
target_logprobs_chunk = target_logprobs_chunk.float()
|
||||||
|
|
||||||
# Gather student logits for the top-k teacher token IDs
|
# Gather student logits for the top-k teacher token IDs
|
||||||
# student_logits_temp_scaled: [chunk_size, vocab_size]
|
|
||||||
# target_token_ids_chunk: [chunk_size, top_k]
|
# target_token_ids_chunk: [chunk_size, top_k]
|
||||||
# student_logits_topk_temp_scaled: [chunk_size, top_k]
|
# student_logits_topk_temp_scaled: [chunk_size, top_k]
|
||||||
student_logits_topk_temp_scaled = torch.gather(
|
student_logits_topk_temp_scaled = torch.gather(
|
||||||
@@ -72,9 +71,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
)
|
)
|
||||||
kd_loss = kd_loss_per_token.sum()
|
kd_loss = kd_loss_per_token.sum()
|
||||||
|
|
||||||
if temperature != 1.0:
|
|
||||||
kd_loss = kd_loss * (temperature**2)
|
|
||||||
|
|
||||||
return kd_loss
|
return kd_loss
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -86,10 +86,6 @@ def loss(
|
|||||||
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
|
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
|
||||||
kd_loss = kd_loss_per_token.sum()
|
kd_loss = kd_loss_per_token.sum()
|
||||||
|
|
||||||
# Multiply by T^2 (classical KD scaling)
|
|
||||||
if kd_temperature != 1.0:
|
|
||||||
kd_loss = kd_loss * (kd_temperature**2)
|
|
||||||
|
|
||||||
# Normalize by number of items (if provided) or by valid tokens
|
# Normalize by number of items (if provided) or by valid tokens
|
||||||
if num_items_in_batch > 0:
|
if num_items_in_batch > 0:
|
||||||
kd_loss = kd_loss / float(num_items_in_batch)
|
kd_loss = kd_loss / float(num_items_in_batch)
|
||||||
|
|||||||
Reference in New Issue
Block a user