From 49e2fa825df08a8d663504b1a3e4a48859b83289 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 23 May 2025 08:35:44 -0400
Subject: [PATCH] additional plugin collator kwargs, don't scale up kd loss by t^2

---
 deepspeed_configs/zero2_torch_compile.json   | 31 +++++++++++++++++++
 src/axolotl/core/builders/causal.py          | 10 ++++--
 src/axolotl/integrations/base.py             | 14 ++++++---
 src/axolotl/integrations/kd/__init__.py      |  8 ++---
 src/axolotl/integrations/kd/args.py          | 14 +++++----
 src/axolotl/integrations/kd/kernels/liger.py | 10 ++----
 .../kd/topk_logprob/forward_kl.py            |  4 ---
 7 files changed, 62 insertions(+), 29 deletions(-)
 create mode 100644 deepspeed_configs/zero2_torch_compile.json

diff --git a/deepspeed_configs/zero2_torch_compile.json b/deepspeed_configs/zero2_torch_compile.json
new file mode 100644
index 000000000..c3bcf98cf
--- /dev/null
+++ b/deepspeed_configs/zero2_torch_compile.json
@@ -0,0 +1,31 @@
+{
+  "compile": {
+    "disable": false,
+    "backend": "inductor"
+  },
+  "zero_optimization": {
+    "stage": 2,
+    "offload_optimizer": {
+      "device": "cpu"
+    },
+    "contiguous_gradients": true,
+    "overlap_comm": true
+  },
+  "bf16": {
+    "enabled": "auto"
+  },
+  "fp16": {
+    "enabled": "auto",
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 32,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index fa6a9ec37..8babf6a65 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -432,10 +432,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         if self.cfg.plugins:
             plugin_manager = PluginManager.get_instance()
-            collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval)
+            collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
+                self.cfg, is_eval=is_eval
+            )
 
-            if collator_cls:
-                collator = collator_cls
+            if collator_cls_and_kwargs:
+                collator = collator_cls_and_kwargs[0]
+                if kwargs and isinstance(kwargs, dict):
+                    kwargs.update(collator_cls_and_kwargs[1])
         elif self.cfg.reward_model:
             collator = RewardDataCollatorWithPadding
         elif use_batch_sampler_collator:
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index 2b8eaa6e6..7c2c773b1 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -208,7 +208,9 @@ class BasePlugin:
             class: The class for the trainer.
         """
 
-    def get_collator_cls(self, cfg, is_eval=False):  # pylint: disable=unused-argument):
+    def get_collator_cls_and_kwargs(
+        self, cfg, is_eval=False
+    ):  # pylint: disable=unused-argument):
         """
         Returns a custom class for the collator.
 
@@ -517,9 +519,9 @@ class PluginManager:
                 return trainer_cls
         return None
 
-    def get_collator_cls(self, cfg, is_eval=False):
+    def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
         """
-        Calls the get_collator_cls method of all registered plugins and returns the first non-None collator class.
+        Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
 
         Parameters:
         cfg (dict): The configuration for the plugins.
@@ -529,9 +531,11 @@
         object: The collator class, or None if none was found.
         """
         for plugin in self.plugins.values():
-            collator_cls = plugin.get_collator_cls(cfg, is_eval=is_eval)
+            collator_cls, collator_kwargs = plugin.get_collator_cls_and_kwargs(
+                cfg, is_eval=is_eval
+            )
             if collator_cls is not None:
-                return collator_cls
+                return collator_cls, collator_kwargs
         return None
 
     def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py
index cce646bcb..3214ced35 100644
--- a/src/axolotl/integrations/kd/__init__.py
+++ b/src/axolotl/integrations/kd/__init__.py
@@ -35,9 +35,9 @@ class KDPlugin(BasePlugin):
             return AxolotlKDTrainer
         return None
 
-    def get_collator_cls(self, cfg, is_eval=False):
+    def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
         if not cfg.kd_trainer:
-            return None
+            return None, None
 
         from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
 
@@ -48,8 +48,8 @@ class KDPlugin(BasePlugin):
             use_batch_sampler_collator = True
 
         if use_batch_sampler_collator:
-            return KDBatchSamplerDataCollatorForSeq2Seq
-        return DataCollatorForKD
+            return KDBatchSamplerDataCollatorForSeq2Seq, {}
+        return DataCollatorForKD, {}
 
     def pre_model_load(self, cfg):
         from .kernels.models import apply_kernel
diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
index 0eede1ada..35cf7a114 100644
--- a/src/axolotl/integrations/kd/args.py
+++ b/src/axolotl/integrations/kd/args.py
@@ -15,8 +15,6 @@
 """
 Plugin args for KD support.
 """
-from typing import Optional
-
 from pydantic import BaseModel
 
 
@@ -25,9 +23,13 @@ class KDArgs(BaseModel):
     Input args for knowledge distillation.
     """
 
-    kd_trainer: Optional[bool] = None  # whether to use KD trainer
-    kd_ce_alpha: Optional[float] = (
+    kd_trainer: float | None = None  # whether to use KD trainer
+    kd_ce_alpha: float | None = (
         None  # loss coefficient for cross-entropy loss during KD
     )
-    kd_alpha: Optional[float] = None  # loss coefficient for KD loss
-    kd_temperature: Optional[float] = None  # temperature for sampling during KD
+    kd_alpha: float | None = None  # loss coefficient for KD loss
+    kd_temperature: float | None = None  # temperature for sampling during KD
+
+    # TODO online kd
+    # kd_online_server_base_url: str | None = None
+    # kd_online_topk: int | None = None
diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py
index 35f82855f..75fe718f6 100644
--- a/src/axolotl/integrations/kd/kernels/liger.py
+++ b/src/axolotl/integrations/kd/kernels/liger.py
@@ -20,7 +20,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         target_token_ids_chunk: torch.Tensor,  # [chunk_size, top_k]
         target_logprobs_chunk: torch.Tensor,  # [chunk_size, top_k], already temp-scaled and normalized logprobs
         target_mask_chunk: torch.Tensor,  # [chunk_size, top_k]
-        temperature: float,
     ) -> torch.Tensor:
         """
         Compute Top-K KL divergence loss for a chunk.
@@ -29,15 +28,15 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
             target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
             target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
-            temperature: Temperature used for scaling.
 
         Returns:
            Sum of KL divergence losses for the chunk.
         """
-        student_logits_temp_scaled = student_logits_temp_scaled.float()
+        student_logits_temp_scaled = (  # [chunk_size, vocab_size]
+            student_logits_temp_scaled.float()
+        )
         target_logprobs_chunk = target_logprobs_chunk.float()
         # Gather student logits for the top-k teacher token IDs
-        # student_logits_temp_scaled: [chunk_size, vocab_size]
         # target_token_ids_chunk: [chunk_size, top_k]
         # student_logits_topk_temp_scaled: [chunk_size, top_k]
         student_logits_topk_temp_scaled = torch.gather(
@@ -72,9 +71,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
         )
         kd_loss = kd_loss_per_token.sum()
 
-        if temperature != 1.0:
-            kd_loss = kd_loss * (temperature**2)
-
         return kd_loss
 
     @staticmethod
diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
index 4b7251295..74184455f 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
@@ -86,10 +86,6 @@ def loss(
     kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
     kd_loss = kd_loss_per_token.sum()
 
-    # Multiply by T^2 (classical KD scaling)
-    if kd_temperature != 1.0:
-        kd_loss = kd_loss * (kd_temperature**2)
-
     # Normalize by number of items (if provided) or by valid tokens
     if num_items_in_batch > 0:
         kd_loss = kd_loss / float(num_items_in_batch)