collator cls for plugins
@@ -429,18 +429,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ]
         ]
         collator_args = [self.tokenizer]
-        if self.cfg.reward_model:
-            collator = RewardDataCollatorWithPadding
-        elif self.cfg.kd_trainer:
-            from axolotl.integrations.kd.collator import (
-                DataCollatorForKD,
-                KDBatchSamplerDataCollatorForSeq2Seq,
-            )
-
-            if self.cfg.sample_packing and use_batch_sampler_collator:
-                collator = KDBatchSamplerDataCollatorForSeq2Seq
-            else:
-                collator = DataCollatorForKD
+        collator_cls = None
+        if self.cfg.plugins:
+            plugin_manager = PluginManager.get_instance()
+            collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval)
+
+        if collator_cls:
+            collator = collator_cls
+        elif self.cfg.reward_model:
+            collator = RewardDataCollatorWithPadding
         elif use_batch_sampler_collator:
             # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
             # supported multipack models, or non-flash-attention llama

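Note for plugin authors reading the hunk above: whatever class the plugin hook returns is instantiated with `collator_args` (the tokenizer first), so it should accept the same constructor shape as the built-in collators. A minimal sketch under that assumed contract; the `MyPluginCollator` name is hypothetical:

    from dataclasses import dataclass

    from transformers import PreTrainedTokenizerBase


    @dataclass
    class MyPluginCollator:
        # First positional arg mirrors collator_args = [self.tokenizer] above.
        tokenizer: PreTrainedTokenizerBase
        return_tensors: str = "pt"

        def __call__(self, features):
            # Pad the list of feature dicts into a single tensorized batch.
            return self.tokenizer.pad(features, return_tensors=self.return_tensors)
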
@@ -167,7 +167,78 @@ class BasePlugin:
             trainer: The trainer object for training.

         Returns:
-            The created optimizer.
+            None
         """

+    def pre_lora_load(self, cfg, model):  # pylint: disable=unused-argument
+        """
+        Performs actions before LoRA weights are loaded.
+
+        Args:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.
+
+        Returns:
+            None
+        """
+
+    def post_lora_load(self, cfg, model):  # pylint: disable=unused-argument
+        """
+        Performs actions after LoRA weights are loaded.
+
+        Args:
+            cfg (dict): The configuration for the plugin.
+            model (object): The loaded model.
+
+        Returns:
+            None
+        """
+
+    def get_trainer_cls(self, cfg):  # pylint: disable=unused-argument
+        """
+        Returns a custom class for the trainer.
+
+        Args:
+            cfg (dict): The global axolotl configuration.
+
+        Returns:
+            class: The class for the trainer.
+        """
+
+    def get_collator_cls(self, cfg, is_eval=False):  # pylint: disable=unused-argument
+        """
+        Returns a custom class for the collator.
+
+        Args:
+            cfg (dict): The global axolotl configuration.
+            is_eval (bool): Whether this is an eval split.
+
+        Returns:
+            class: The class for the collator.
+        """
+
+    def post_trainer_create(self, cfg, trainer):  # pylint: disable=unused-argument
+        """
+        Performs actions after the trainer is created.
+
+        Args:
+            cfg (dict): The configuration for the plugin.
+            trainer (object): The trainer object for training.
+
+        Returns:
+            None
+        """
+
+    def create_optimizer(self, cfg, trainer):  # pylint: disable=unused-argument
+        """
+        Creates and returns an optimizer for training.
+
+        Args:
+            cfg (dict): The configuration for the plugin.
+            trainer (object): The trainer object for training.
+
+        Returns:
+            object: The created optimizer.
+        """
+
     # pylint: disable=unused-argument

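A minimal sketch of a plugin overriding the new `get_collator_cls` hook; the plugin class and config flag are hypothetical, and returning None opts out so other plugins or the trainer-builder defaults still apply:

    from transformers import DefaultDataCollator

    from axolotl.integrations.base import BasePlugin


    class MyCollatorPlugin(BasePlugin):
        """Toy plugin that supplies a custom collator class."""

        def get_collator_cls(self, cfg, is_eval=False):
            # Opt in only when our (hypothetical) config flag is set.
            if not cfg.my_custom_collator:
                return None
            return DefaultDataCollator
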
@@ -442,6 +517,23 @@ class PluginManager:
             return trainer_cls
         return None

+    def get_collator_cls(self, cfg, is_eval=False):
+        """
+        Calls the get_collator_cls method of each registered plugin and returns the first non-None collator class.
+
+        Parameters:
+            cfg (dict): The configuration for the plugins.
+            is_eval (bool): Whether this is an eval split.
+
+        Returns:
+            object: The collator class, or None if no plugin provides one.
+        """
+        for plugin in self.plugins.values():
+            collator_cls = plugin.get_collator_cls(cfg, is_eval=is_eval)
+            if collator_cls is not None:
+                return collator_cls
+        return None
+
     def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
         """Calls the `post_trainer_create` method of all registered plugins.

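Since the manager stops at the first non-None result, registration order decides which plugin wins. A hypothetical illustration; the string-based register call mirrors how axolotl loads plugins by dotted path, but treat it as an assumption:

    plugin_manager = PluginManager.get_instance()
    plugin_manager.register("my_pkg.MyCollatorPlugin")  # assumed registration API
    plugin_manager.register("axolotl.integrations.kd.KDPlugin")

    # MyCollatorPlugin is consulted first; KDPlugin only wins if it returns None.
    collator_cls = plugin_manager.get_collator_cls(cfg, is_eval=False)
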
@@ -35,6 +35,22 @@ class KDPlugin(BasePlugin):
             return AxolotlKDTrainer
         return None

+    def get_collator_cls(self, cfg, is_eval=False):
+        if not cfg.kd_trainer:
+            return None
+
+        from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
+
+        use_batch_sampler_collator = False
+        if not is_eval and cfg.sample_packing:
+            use_batch_sampler_collator = True
+        if is_eval and cfg.eval_sample_packing:
+            use_batch_sampler_collator = True
+
+        if use_batch_sampler_collator:
+            return KDBatchSamplerDataCollatorForSeq2Seq
+        return DataCollatorForKD
+
     def pre_model_load(self, cfg):
         from .kernels.models import apply_kernel

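The two flag checks above reduce to one boolean: the batch-sampler collator is used when packing is enabled for the current split (`sample_packing` for train, `eval_sample_packing` for eval). An equivalent one-liner, shown only for readability (behavior unchanged):

    use_batch_sampler_collator = (not is_eval and cfg.sample_packing) or (
        is_eval and cfg.eval_sample_packing
    )
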
@@ -305,8 +305,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
             grad_inputs_list.append(grad_input_chunk)

         grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
-        print("grad_inputs_combined")
-        print(grad_inputs_combined.shape)
         ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)

         # For matching None returns in backward for non-tensor/non-grad_requiring inputs

@@ -1,7 +1,7 @@
 """Data collators for axolotl to pad labels and position_ids for packed sequences"""

 from dataclasses import dataclass
-from typing import Any
+from typing import Any, List

 import numpy as np
 from transformers import PreTrainedTokenizerBase

@@ -161,7 +161,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

     def __call__(self, features, return_tensors=None):
         if not isinstance(features[0], list):
-            features = [features]
+            features: List[List[dict]] = [features]
         out_features = [{} for _ in features]
         for i, features_ in enumerate(features):
             for feature in features_[0].keys():
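For intuition on the new annotation: after the isinstance check, the body always iterates over a list of sub-batches, whether the caller passed a flat batch or pre-grouped packs. Hypothetical shapes:

    flat = [{"input_ids": [1, 2]}, {"input_ids": [3, 4]}]
    # the branch above wraps this into [[{...}, {...}]]: a single sub-batch

    packed = [[{"input_ids": [1, 2]}], [{"input_ids": [3, 4]}]]
    # already List[List[dict]]: used as-is, two sub-batches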