collator cls for plugins
This commit is contained in:
@@ -429,18 +429,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
collator_args = [self.tokenizer]
|
collator_args = [self.tokenizer]
|
||||||
if self.cfg.reward_model:
|
|
||||||
collator = RewardDataCollatorWithPadding
|
|
||||||
elif self.cfg.kd_trainer:
|
|
||||||
from axolotl.integrations.kd.collator import (
|
|
||||||
DataCollatorForKD,
|
|
||||||
KDBatchSamplerDataCollatorForSeq2Seq,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.sample_packing and use_batch_sampler_collator:
|
if self.cfg.plugins:
|
||||||
collator = KDBatchSamplerDataCollatorForSeq2Seq
|
plugin_manager = PluginManager.get_instance()
|
||||||
else:
|
collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval)
|
||||||
collator = DataCollatorForKD
|
|
||||||
|
if collator_cls:
|
||||||
|
pass
|
||||||
|
elif self.cfg.reward_model:
|
||||||
|
collator = RewardDataCollatorWithPadding
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
||||||
# supported multipack models, or non-flash-attention llama
|
# supported multipack models, or non-flash-attention llama
|
||||||
|
|||||||
@@ -167,7 +167,82 @@ class BasePlugin:
|
|||||||
trainer: The trainer object for training.
|
trainer: The trainer object for training.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
<<<<<<< HEAD
|
||||||
The created optimizer.
|
The created optimizer.
|
||||||
|
=======
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||||
|
"""
|
||||||
|
Performs actions before LoRA weights are loaded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict): The configuration for the plugin.
|
||||||
|
model (object): The loaded model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||||
|
"""
|
||||||
|
Performs actions after LoRA weights are loaded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict): The configuration for the plugin.
|
||||||
|
model (object): The loaded model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
|
||||||
|
"""
|
||||||
|
Returns a custom class for the trainer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict): The global axolotl configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
class: The class for the trainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_collator_cls(self, cfg, is_eval=False): # pylint: disable=unused-argument):
|
||||||
|
"""
|
||||||
|
Returns a custom class for the collator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict): The global axolotl configuration.
|
||||||
|
is_eval (bool): Whether this is an eval split.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
class: The class for the collator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
|
||||||
|
"""
|
||||||
|
Performs actions after the trainer is created.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict): The configuration for the plugin.
|
||||||
|
trainer (object): The trainer object for training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||||
|
"""
|
||||||
|
Creates and returns an optimizer for training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict): The configuration for the plugin.
|
||||||
|
trainer (object): The trainer object for training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: The created optimizer.
|
||||||
|
>>>>>>> f8df1563d (collator cls for plugins)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
@@ -442,6 +517,23 @@ class PluginManager:
|
|||||||
return trainer_cls
|
return trainer_cls
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_collator_cls(self, cfg, is_eval=False):
|
||||||
|
"""
|
||||||
|
Calls the get_collator_cls method of all registered plugins and returns the first non-None collator class.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
cfg (dict): The configuration for the plugins.
|
||||||
|
is_eval (bool): Whether this is an eval split.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: The collator class, or None if none was found.
|
||||||
|
"""
|
||||||
|
for plugin in self.plugins.values():
|
||||||
|
collator_cls = plugin.get_collator_cls(cfg, is_eval=is_eval)
|
||||||
|
if collator_cls is not None:
|
||||||
|
return collator_cls
|
||||||
|
return None
|
||||||
|
|
||||||
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
||||||
"""Calls the `post_trainer_create` method of all registered plugins.
|
"""Calls the `post_trainer_create` method of all registered plugins.
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,22 @@ class KDPlugin(BasePlugin):
|
|||||||
return AxolotlKDTrainer
|
return AxolotlKDTrainer
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_collator_cls(self, cfg, is_eval=False):
|
||||||
|
if not cfg.kd_trainer:
|
||||||
|
return None
|
||||||
|
|
||||||
|
from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
use_batch_sampler_collator = False
|
||||||
|
if is_eval is False and cfg.sample_packing:
|
||||||
|
use_batch_sampler_collator = True
|
||||||
|
if cfg.eval_sample_packing and is_eval:
|
||||||
|
use_batch_sampler_collator = True
|
||||||
|
|
||||||
|
if use_batch_sampler_collator:
|
||||||
|
return KDBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
return DataCollatorForKD
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
from .kernels.models import apply_kernel
|
from .kernels.models import apply_kernel
|
||||||
|
|
||||||
|
|||||||
@@ -305,8 +305,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
grad_inputs_list.append(grad_input_chunk)
|
grad_inputs_list.append(grad_input_chunk)
|
||||||
|
|
||||||
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
|
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
|
||||||
print("grad_inputs_combined")
|
|
||||||
print(grad_inputs_combined.shape)
|
|
||||||
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
|
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
|
||||||
|
|
||||||
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
|
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
|
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
@@ -161,7 +161,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
if not isinstance(features[0], list):
|
if not isinstance(features[0], list):
|
||||||
features = [features]
|
features: List[List[dict]] = [features]
|
||||||
out_features = [{} for _ in features]
|
out_features = [{} for _ in features]
|
||||||
for i, features_ in enumerate(features):
|
for i, features_ in enumerate(features):
|
||||||
for feature in features_[0].keys():
|
for feature in features_[0].keys():
|
||||||
|
|||||||
Reference in New Issue
Block a user