Add a collator class hook for plugins

This commit is contained in:
Wing Lian
2025-05-22 10:13:44 -04:00
parent 28eb8632a1
commit 5ccfd225cb
5 changed files with 118 additions and 15 deletions

View File

@@ -429,18 +429,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] ]
] ]
collator_args = [self.tokenizer] collator_args = [self.tokenizer]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import (
DataCollatorForKD,
KDBatchSamplerDataCollatorForSeq2Seq,
)
if self.cfg.sample_packing and use_batch_sampler_collator: if self.cfg.plugins:
collator = KDBatchSamplerDataCollatorForSeq2Seq plugin_manager = PluginManager.get_instance()
else: collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval)
collator = DataCollatorForKD
if collator_cls:
pass
elif self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator: elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama # supported multipack models, or non-flash-attention llama

View File

@@ -167,7 +167,82 @@ class BasePlugin:
trainer: The trainer object for training. trainer: The trainer object for training.
Returns: Returns:
<<<<<<< HEAD
The created optimizer. The created optimizer.
=======
None
"""
def pre_lora_load(self, cfg, model):  # pylint: disable=unused-argument
    """Hook invoked before LoRA weights are loaded.

    The base implementation is a no-op; plugins override it to act on the
    model ahead of adapter loading.

    Args:
        cfg: The configuration for the plugin.
        model: The loaded model.

    Returns:
        None
    """
def post_lora_load(self, cfg, model):  # pylint: disable=unused-argument
    """Hook invoked after LoRA weights are loaded.

    The base implementation is a no-op; plugins override it to act on the
    model once adapter loading has finished.

    Args:
        cfg: The configuration for the plugin.
        model: The loaded model.

    Returns:
        None
    """
def get_trainer_cls(self, cfg):  # pylint: disable=unused-argument
    """Return a custom trainer class, or None to keep the default.

    The base implementation supplies no trainer class; plugins override this
    hook to provide one.

    Args:
        cfg: The global axolotl configuration.

    Returns:
        The trainer class to use, or None when the plugin does not supply one.
    """
    # Explicit so the "no override" contract is visible to callers.
    return None
def get_collator_cls(self, cfg, is_eval=False):  # pylint: disable=unused-argument
    """Return a custom collator class, or None to keep the default.

    The base implementation supplies no collator class; plugins override this
    hook to provide one.

    Args:
        cfg: The global axolotl configuration.
        is_eval: Whether the collator is for an eval split.

    Returns:
        The collator class to use, or None when the plugin does not supply one.
    """
    # Explicit so the "no override" contract is visible to callers that
    # compare the result against None.
    return None
def post_trainer_create(self, cfg, trainer):  # pylint: disable=unused-argument
    """Hook invoked after the trainer is created.

    The base implementation is a no-op; plugins override it to act on the
    freshly built trainer.

    Args:
        cfg: The configuration for the plugin.
        trainer: The trainer object for training.

    Returns:
        None
    """
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer.
>>>>>>> f8df1563d (collator cls for plugins)
""" """
# pylint: disable=unused-argument # pylint: disable=unused-argument
@@ -442,6 +517,23 @@ class PluginManager:
return trainer_cls return trainer_cls
return None return None
def get_collator_cls(self, cfg, is_eval=False):
    """Return the first collator class offered by a registered plugin.

    Plugins are queried in the iteration order of ``self.plugins``; querying
    stops as soon as one returns a value that is not None.

    Args:
        cfg: The configuration for the plugins.
        is_eval: Whether this is an eval split.

    Returns:
        The collator class, or None if no plugin provided one.
    """
    # Lazy generator: later plugins are not called once a match is found.
    offered = (
        plugin.get_collator_cls(cfg, is_eval=is_eval)
        for plugin in self.plugins.values()
    )
    return next((cls for cls in offered if cls is not None), None)
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Calls the `post_trainer_create` method of all registered plugins. """Calls the `post_trainer_create` method of all registered plugins.

View File

@@ -35,6 +35,22 @@ class KDPlugin(BasePlugin):
return AxolotlKDTrainer return AxolotlKDTrainer
return None return None
def get_collator_cls(self, cfg, is_eval=False):
    """Select the KD collator class for the given split.

    Args:
        cfg: Configuration object; reads ``kd_trainer``, ``sample_packing``
            and ``eval_sample_packing``.
        is_eval: Whether the collator is for an eval split.

    Returns:
        None when ``cfg.kd_trainer`` is falsy. Otherwise
        ``KDBatchSamplerDataCollatorForSeq2Seq`` when sample packing applies
        to the split, else ``DataCollatorForKD``.
    """
    if not cfg.kd_trainer:
        return None

    # Imported lazily, only once the KD trainer is known to be enabled.
    from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq

    # Training split: only an explicit ``False`` counts, as an identity check.
    packs_train_split = is_eval is False and cfg.sample_packing
    # Eval split: packing must be enabled specifically for evaluation.
    packs_eval_split = cfg.eval_sample_packing and is_eval

    if packs_train_split or packs_eval_split:
        return KDBatchSamplerDataCollatorForSeq2Seq
    return DataCollatorForKD
def pre_model_load(self, cfg): def pre_model_load(self, cfg):
from .kernels.models import apply_kernel from .kernels.models import apply_kernel

View File

@@ -305,8 +305,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
grad_inputs_list.append(grad_input_chunk) grad_inputs_list.append(grad_input_chunk)
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0) grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
print("grad_inputs_combined")
print(grad_inputs_combined.shape)
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc) ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
# For matching None returns in backward for non-tensor/non-grad_requiring inputs # For matching None returns in backward for non-tensor/non-grad_requiring inputs

View File

@@ -1,7 +1,7 @@
"""Data collators for axolotl to pad labels and position_ids for packed sequences""" """Data collators for axolotl to pad labels and position_ids for packed sequences"""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any, List
import numpy as np import numpy as np
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -161,7 +161,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None): def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list): if not isinstance(features[0], list):
features = [features] features: List[List[dict]] = [features]
out_features = [{} for _ in features] out_features = [{} for _ in features]
for i, features_ in enumerate(features): for i, features_ in enumerate(features):
for feature in features_[0].keys(): for feature in features_[0].keys():