From 5ccfd225cb02d350f19e8875873b7c6d9e596c76 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 22 May 2025 10:13:44 -0400
Subject: [PATCH] collator cls for plugins

---
 src/axolotl/core/builders/causal.py          | 19 ++--
 src/axolotl/integrations/base.py             | 92 ++++++++++++++++++++
 src/axolotl/integrations/kd/__init__.py      | 16 ++++
 src/axolotl/integrations/kd/kernels/liger.py |  2 -
 src/axolotl/utils/collators/batching.py      |  4 +-
 5 files changed, 118 insertions(+), 15 deletions(-)

diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index 47170cc6f..e8838014f 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -429,18 +429,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ]
         ]
         collator_args = [self.tokenizer]
-        if self.cfg.reward_model:
-            collator = RewardDataCollatorWithPadding
-        elif self.cfg.kd_trainer:
-            from axolotl.integrations.kd.collator import (
-                DataCollatorForKD,
-                KDBatchSamplerDataCollatorForSeq2Seq,
-            )
-            if self.cfg.sample_packing and use_batch_sampler_collator:
-                collator = KDBatchSamplerDataCollatorForSeq2Seq
-            else:
-                collator = DataCollatorForKD
+        collator_cls = None
+        if self.cfg.plugins:
+            plugin_manager = PluginManager.get_instance()
+            collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval)
+
+        if collator_cls:
+            pass
+        elif self.cfg.reward_model:
+            collator = RewardDataCollatorWithPadding
         elif use_batch_sampler_collator:
             # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
             # supported multipack models, or non-flash-attention llama
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index 0edc9fdea..2b8eaa6e6 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -167,7 +167,82 @@ class BasePlugin:
             trainer: The trainer object for training.

         Returns:
+            None
+        """
+
+    def pre_lora_load(self, cfg, model):  # pylint: disable=unused-argument
+        """
+        Performs actions before LoRA weights are loaded.
+
+        Args:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.
+
+        Returns:
+            None
+        """
+
+    def post_lora_load(self, cfg, model):  # pylint: disable=unused-argument
+        """
+        Performs actions after LoRA weights are loaded.
+
+        Args:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.
+
+        Returns:
+            None
+        """
+
+    def get_trainer_cls(self, cfg):  # pylint: disable=unused-argument
+        """
+        Returns a custom class for the trainer.
+
+        Args:
+            cfg (dict): The global axolotl configuration.
+
+        Returns:
+            class: The class for the trainer.
+        """
+
+    def get_collator_cls(self, cfg, is_eval=False):  # pylint: disable=unused-argument
+        """
+        Returns a custom class for the collator.
+
+        Args:
+            cfg (dict): The global axolotl configuration.
+            is_eval (bool): Whether this is an eval split.
+
+        Returns:
+            class: The class for the collator.
+        """
+
+    def post_trainer_create(self, cfg, trainer):  # pylint: disable=unused-argument
+        """
+        Performs actions after the trainer is created.
+
+        Args:
+            cfg (dict): The configuration for the plugin.
+            trainer (object): The trainer object for training.
+
+        Returns:
+            None
+        """
+
+    def create_optimizer(self, cfg, trainer):  # pylint: disable=unused-argument
+        """
+        Creates and returns an optimizer for training.
+
+        Args:
+            cfg (dict): The configuration for the plugin.
+            trainer (object): The trainer object for training.
+
+        Returns:
+            object: The created optimizer.
         """
         # pylint: disable=unused-argument
@@ -442,6 +517,23 @@ class PluginManager:
                 return trainer_cls
         return None

+    def get_collator_cls(self, cfg, is_eval=False):
+        """
+        Calls the get_collator_cls method of all registered plugins and returns the first non-None collator class.
+
+        Parameters:
+            cfg (dict): The configuration for the plugins.
+            is_eval (bool): Whether this is an eval split.
+
+        Returns:
+            object: The collator class, or None if none was found.
+        """
+        for plugin in self.plugins.values():
+            collator_cls = plugin.get_collator_cls(cfg, is_eval=is_eval)
+            if collator_cls is not None:
+                return collator_cls
+        return None
+
     def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
         """Calls the `post_trainer_create` method of all registered plugins.
diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py
index 00d73e036..cce646bcb 100644
--- a/src/axolotl/integrations/kd/__init__.py
+++ b/src/axolotl/integrations/kd/__init__.py
@@ -35,6 +35,22 @@ class KDPlugin(BasePlugin):
             return AxolotlKDTrainer
         return None

+    def get_collator_cls(self, cfg, is_eval=False):
+        if not cfg.kd_trainer:
+            return None
+
+        from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
+
+        use_batch_sampler_collator = False
+        if not is_eval and cfg.sample_packing:
+            use_batch_sampler_collator = True
+        if is_eval and cfg.eval_sample_packing:
+            use_batch_sampler_collator = True
+
+        if use_batch_sampler_collator:
+            return KDBatchSamplerDataCollatorForSeq2Seq
+        return DataCollatorForKD
+
     def pre_model_load(self, cfg):
         from .kernels.models import apply_kernel
diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py
index e4e80df14..9fa0b663d 100644
--- a/src/axolotl/integrations/kd/kernels/liger.py
+++ b/src/axolotl/integrations/kd/kernels/liger.py
@@ -305,8 +305,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             grad_inputs_list.append(grad_input_chunk)

         grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
-        print("grad_inputs_combined")
-        print(grad_inputs_combined.shape)

         ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
         # For matching None returns in backward for non-tensor/non-grad_requiring inputs
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 45facf832..6cb26ec6b 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -1,7 +1,7 @@
 """Data collators for axolotl to pad labels and position_ids for packed sequences"""

 from dataclasses import dataclass
-from typing import Any
+from typing import Any, List

 import numpy as np
 from transformers import PreTrainedTokenizerBase
@@ -161,7 +161,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

     def __call__(self, features, return_tensors=None):
         if not isinstance(features[0], list):
-            features = [features]
+            features: List[List[dict]] = [features]
         out_features = [{} for _ in features]
         for i, features_ in enumerate(features):
             for feature in features_[0].keys():
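-- 
Usage note, not part of the diff above: a plugin adopts the new hook by
overriding get_collator_cls on its BasePlugin subclass and returning a
collator class, or None to defer to other registered plugins. A minimal
sketch, where the plugin name and the my_custom_collator config flag are
hypothetical:

    from axolotl.integrations.base import BasePlugin
    from axolotl.utils.collators.batching import (
        V2BatchSamplerDataCollatorForSeq2Seq,
    )

    class MyCollatorPlugin(BasePlugin):
        def get_collator_cls(self, cfg, is_eval=False):
            # hypothetical config flag; returning None lets PluginManager
            # fall through to the next registered plugin
            if not cfg.my_custom_collator:
                return None
            return V2BatchSamplerDataCollatorForSeq2Seq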