collator cls for plugins

This commit is contained in:
Wing Lian
2025-05-22 10:13:44 -04:00
parent 28eb8632a1
commit 5ccfd225cb
5 changed files with 118 additions and 15 deletions

View File

@@ -429,18 +429,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
]
collator_args = [self.tokenizer]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import (
DataCollatorForKD,
KDBatchSamplerDataCollatorForSeq2Seq,
)
if self.cfg.sample_packing and use_batch_sampler_collator:
collator = KDBatchSamplerDataCollatorForSeq2Seq
else:
collator = DataCollatorForKD
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval)
if collator_cls:
pass
elif self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama

View File

@@ -167,7 +167,82 @@ class BasePlugin:
trainer: The trainer object for training.
Returns:
<<<<<<< HEAD
The created optimizer.
=======
None
"""
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Args:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
"""
def get_collator_cls(self, cfg, is_eval=False): # pylint: disable=unused-argument):
"""
Returns a custom class for the collator.
Args:
cfg (dict): The global axolotl configuration.
is_eval (bool): Whether this is an eval split.
Returns:
class: The class for the collator.
"""
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
"""
Performs actions after the trainer is created.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
None
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer.
>>>>>>> f8df1563d (collator cls for plugins)
"""
# pylint: disable=unused-argument
@@ -442,6 +517,23 @@ class PluginManager:
return trainer_cls
return None
def get_collator_cls(self, cfg, is_eval=False):
"""
Calls the get_collator_cls method of all registered plugins and returns the first non-None collator class.
Parameters:
cfg (dict): The configuration for the plugins.
is_eval (bool): Whether this is an eval split.
Returns:
object: The collator class, or None if none was found.
"""
for plugin in self.plugins.values():
collator_cls = plugin.get_collator_cls(cfg, is_eval=is_eval)
if collator_cls is not None:
return collator_cls
return None
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Calls the `post_trainer_create` method of all registered plugins.

View File

@@ -35,6 +35,22 @@ class KDPlugin(BasePlugin):
return AxolotlKDTrainer
return None
def get_collator_cls(self, cfg, is_eval=False):
if not cfg.kd_trainer:
return None
from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
use_batch_sampler_collator = False
if is_eval is False and cfg.sample_packing:
use_batch_sampler_collator = True
if cfg.eval_sample_packing and is_eval:
use_batch_sampler_collator = True
if use_batch_sampler_collator:
return KDBatchSamplerDataCollatorForSeq2Seq
return DataCollatorForKD
def pre_model_load(self, cfg):
from .kernels.models import apply_kernel

View File

@@ -305,8 +305,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
grad_inputs_list.append(grad_input_chunk)
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
print("grad_inputs_combined")
print(grad_inputs_combined.shape)
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
# For matching None returns in backward for non-tensor/non-grad_requiring inputs

View File

@@ -1,7 +1,7 @@
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
from dataclasses import dataclass
from typing import Any
from typing import Any, List
import numpy as np
from transformers import PreTrainedTokenizerBase
@@ -161,7 +161,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features = [features]
features: List[List[dict]] = [features]
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():