Compare commits

...

5 Commits

Author SHA1 Message Date
Dan Saunders
5c0510a876 review comments 2025-03-03 18:44:16 +00:00
Dan Saunders
e1bc18763a combine like functions 2025-02-28 17:47:39 +00:00
Dan Saunders
ed5178cd3d update 2025-02-26 21:03:44 +00:00
Dan Saunders
a3224c7c3c updates 2025-02-26 20:31:54 +00:00
Dan Saunders
c4104fc10c refactor train.py 2025-02-26 19:37:42 +00:00
4 changed files with 380 additions and 157 deletions

View File

@@ -24,8 +24,8 @@ class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata.""" """Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset train_dataset: Dataset
eval_dataset: Optional[Dataset] = None eval_dataset: Dataset | None = None
total_num_steps: Optional[int] = None total_num_steps: int | None = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:

View File

@@ -91,13 +91,11 @@ try:
except ImportError: except ImportError:
pass pass
LOG = logging.getLogger("axolotl.core.trainer_builder") LOG = logging.getLogger(__name__)
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """Base class for trainer builder."""
Base class for trainer builder
"""
_train_dataset = None _train_dataset = None
_eval_dataset = None _eval_dataset = None
@@ -110,9 +108,9 @@ class TrainerBuilderBase(abc.ABC):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.processor = processor self.processor = processor
# in case the model supports tagging, add the axolotl tag. # If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls # This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instad of trainer.push_to_hub. # model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"): if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"]) model.add_model_tags(["axolotl"])
@@ -227,8 +225,8 @@ class TrainerBuilderBase(abc.ABC):
class HFCausalTrainerBuilder(TrainerBuilderBase): class HFCausalTrainerBuilder(TrainerBuilderBase):
""" """
Build the HuggingFace training args/trainer for causal models Build the HuggingFace training args/trainer for causal models and reward modeling
and reward modelling using TRL. using TRL.
""" """
def get_callbacks(self): def get_callbacks(self):
@@ -872,9 +870,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase): class HFRLTrainerBuilder(TrainerBuilderBase):
""" """Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
"""
def get_callbacks(self): def get_callbacks(self):
callbacks = super().get_callbacks() callbacks = super().get_callbacks()

View File

@@ -1,26 +1,29 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import inspect import inspect
import os import os
import signal import signal
import sys import sys
import weakref import weakref
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import Any
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
from peft import PeftModel from datasets import Dataset
from pkg_resources import get_distribution # type: ignore from peft import PeftConfig, PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
) )
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except
@@ -32,17 +35,25 @@ try:
except ImportError: except ImportError:
BetterTransformer = None BetterTransformer = None
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging() configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)
def train( def setup_model_and_tokenizer(
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta cfg: DictDefault,
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: ) -> tuple[
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
]:
"""
Load the tokenizer, processor (for multimodal models), and model based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
`None`), and processor (if multimodal, else `None`).
"""
# Load tokenizer # Load tokenizer
LOG.debug( LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -55,11 +66,58 @@ def train(
if cfg.is_multimodal: if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer) processor = load_processor(cfg, tokenizer)
# Get datasets # Load the model and peft_config
train_dataset = dataset_meta.train_dataset msg = "loading model"
eval_dataset = dataset_meta.eval_dataset if cfg.adapter:
total_num_steps = dataset_meta.total_num_steps msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
# Apply freezing if specified
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
return model, tokenizer, peft_config, processor
def setup_reference_model(
    cfg: DictDefault, tokenizer: PreTrainedTokenizer
) -> PreTrainedModel | None:
    """
    Load the reference (baseline) model for RL training when one is required.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        tokenizer: Tokenizer forwarded to `load_model` for the reference model.

    Returns:
        The model loaded with `reference_model=True`, or `None` when `cfg.rl` is
        unset, `cfg.rl` is `"orpo"`, or an adapter is configured without
        `cfg.rl_adapter_ref_model`.
    """
    # No RL, or ORPO: no reference model is loaded.
    if not cfg.rl or cfg.rl == "orpo":
        return None

    if cfg.adapter and not cfg.rl_adapter_ref_model:
        # use built-in trl autounwrap
        LOG.debug("Passing model_ref: None to RL trainer")
        return None

    # Load the model a second time to serve as model_ref / baseline.
    reference, _ = load_model(cfg, tokenizer, reference_model=True)
    return reference
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
"""
Determine the checkpoint to resume from based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Path to the checkpoint to resume from, or `None` if not resuming.
"""
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [ possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
@@ -73,77 +131,22 @@ def train(
LOG.info( LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
) )
resume_from_checkpoint = cfg.resume_from_checkpoint return cfg.resume_from_checkpoint
# Load the model and tokenizer
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
model_ref = None def setup_signal_handler(
if cfg.rl and cfg.rl != "orpo": cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
if cfg.adapter and not cfg.rl_adapter_ref_model: ):
# use built-in trl autounwrap """
LOG.debug("Passing model_ref: None to RL trainer") Set up signal handler for graceful termination.
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
safe_serialization = cfg.save_safetensors is True Args:
cfg: Dictionary mapping `axolotl` config keys to values.
if cfg.unfrozen_parameters: model: The model to save on termination
freeze_layers_except(model, cfg.unfrozen_parameters) safe_serialization: Whether to use safe serialization when saving
"""
trainer = setup_trainer( # ray workers don't have access to this signal
cfg, if cfg.local_rank == 0 and not cfg.use_ray:
train_dataset,
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)
if cfg.fix_untrained_tokens:
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# additionally presave the tokenizer and model configs
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if (
cfg.local_rank == 0 and not cfg.use_ray
): # ray workers don't have access to this signal
def terminate_handler(_, __, model_weakref): def terminate_handler(_, __, model_weakref):
if model_weakref() is not None: if model_weakref() is not None:
@@ -161,21 +164,22 @@ def train(
lambda signum, frame: terminate_handler(signum, frame, _model_weakref), lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
) )
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
if getattr(cfg, "axolotl_config_path"): def execute_training(
raw_axolotl_cfg = Path(cfg.axolotl_config_path) cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
version = get_distribution("axolotl").version ):
if raw_axolotl_cfg.is_file(): """
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n" Execute the training process with appropriate backend configurations.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
if cfg.group_by_length: if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length") LOG.info("hang tight... sorting dataset for group_by_length")
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum: if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel( with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ... # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -187,15 +191,30 @@ def train(
else: else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)
post_train_hooks(cfg, trainer)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") def save_trained_model(
cfg: DictDefault,
trainer: Any,
model: PreTrainedModel,
safe_serialization: bool,
):
"""
Save the trained model according to configuration and training setup.
# post training Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object.
model: The trained model to save.
safe_serialization: Whether to use safe serialization.
"""
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
# Post training module hooks
for name, module in model.named_modules(): for name, module in model.named_modules():
if hasattr(module, "_post_training"): if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access module._post_training(model, name) # pylint: disable=protected-access
# Handle FSDP state dict type
state_dict_type = "FULL_STATE_DICT" state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled: if trainer.is_fsdp_enabled:
if cfg.fsdp_final_state_dict_type: if cfg.fsdp_final_state_dict_type:
@@ -203,16 +222,18 @@ def train(
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type) trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.") LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
# Handle ReLoRA early return case
if cfg.relora_steps: if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload() model = model.merge_and_unload()
else: else:
# final model weights have already been saved by `ReLoRACallback.on_train_end` # final model weights have already been saved by `ReLoRACallback.on_train_end`
return model, tokenizer return
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp: if cfg.fsdp:
# TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple
# processes attempt to write the same file
if ( if (
state_dict_type == "SHARDED_STATE_DICT" state_dict_type == "SHARDED_STATE_DICT"
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT" and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
@@ -244,7 +265,6 @@ def train(
os.remove(os.path.join(cfg.output_dir, "model.safetensors")) os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError: except FileNotFoundError:
pass pass
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer: if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
@@ -255,58 +275,239 @@ def train(
) )
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def create_model_card(cfg: DictDefault, trainer: Trainer):
"""
Create a model card for the trained model if needed.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object with model card creation capabilities.
"""
if not cfg.hub_model_id: if not cfg.hub_model_id:
# Guard since create_model_card may fail if dataset_tags is empty list
try: try:
model_card_kwarg = { model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./") "model_name": cfg.output_dir.lstrip("./")
.encode("utf-8") .encode("utf-8")
.decode("utf-8") .decode("utf-8")
} }
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model: # We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
dataset_tags = [ rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() if cfg.datasets is not None and not rl:
] dataset_tags = [
dataset_tags = [ d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
d for d in dataset_tags if not d.startswith("https://") ]
] dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list if dataset_tags:
model_card_kwarg["dataset_name"] = dataset_tags model_card_kwarg["dataset_tags"] = dataset_tags
else:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [
d for d in dataset_tags if not d.startswith("https://")
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg) trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError): except (AttributeError, UnicodeDecodeError):
pass pass
elif cfg.hub_model_id: elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated # Defensively push to the hub to ensure the model card is updated
trainer.push_to_hub() trainer.push_to_hub()
def save_initial_configs(
    cfg: DictDefault,
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel,
    peft_config: PeftConfig | None,
):
    """
    Write adapter, tokenizer and model configs to `cfg.output_dir` before training.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        tokenizer: Tokenizer whose files are saved to the output directory.
        model: Model whose `config` attribute (if it has one) is saved.
        peft_config: PEFT configuration to save; skipped when falsy.
    """
    target_dir = Path(cfg.output_dir)
    # Idempotent: creates missing parents and is a no-op if the directory exists.
    target_dir.mkdir(parents=True, exist_ok=True)
    destination = str(target_dir)

    # Adapter config goes first so it is available for inspection early.
    if peft_config:
        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
        peft_config.save_pretrained(cfg.output_dir)

    LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
    tokenizer.save_pretrained(destination)

    if not hasattr(model, "config"):
        return

    LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
    model.config.save_pretrained(destination)
def setup_model_card(cfg: DictDefault):
    """
    Set up the Axolotl badge and add the Axolotl config to the model card if available.

    Appends to the module-level `transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT`
    string, so the effect is process-wide.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values. Reads
            `axolotl_config_path`; a missing or empty value skips the config section.

    Raises:
        importlib.metadata.PackageNotFoundError: If the config file exists but the
            `axolotl` distribution metadata cannot be found.
    """
    # A bare `import importlib` does not load the `metadata` submodule; import it
    # explicitly so `importlib.metadata.version` cannot fail with AttributeError.
    import importlib.metadata  # pylint: disable=import-outside-toplevel

    badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
    transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"

    # Default to None so config objects without this attribute do not raise.
    config_path = getattr(cfg, "axolotl_config_path", None)
    if not config_path:
        return

    raw_axolotl_cfg = Path(config_path)
    if not raw_axolotl_cfg.is_file():
        return

    # Only look up the installed version once we know it will be used.
    version = importlib.metadata.version("axolotl")
    transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
def handle_untrained_tokens_fix(
    cfg: DictDefault,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    train_dataset: Dataset,
    safe_serialization: bool,
):
    """
    Run `fix_untrained_tokens` when `cfg.fix_untrained_tokens` is set, then save.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        model: The model passed to `fix_untrained_tokens` and saved afterwards.
        tokenizer: The tokenizer passed to `fix_untrained_tokens`.
        train_dataset: The training dataset passed to `fix_untrained_tokens`.
        safe_serialization: Forwarded to `model.save_pretrained`.
    """
    token_setting = cfg.fix_untrained_tokens
    if not token_setting:
        return

    # Only forward explicit token ids when the installed helper accepts the kwarg
    # and the config value is actually a list of ids.
    helper_params = inspect.signature(fix_untrained_tokens).parameters
    optional_kwargs = {}
    if "token_ids_to_fix" in helper_params and isinstance(token_setting, list):
        optional_kwargs["token_ids_to_fix"] = token_setting

    fix_untrained_tokens(model, tokenizer, train_dataset, **optional_kwargs)

    # Only local rank 0 writes the model to disk.
    if cfg.local_rank != 0:
        return

    model.save_pretrained(
        str(Path(cfg.output_dir)), safe_serialization=safe_serialization
    )
def setup_model_and_trainer(
    cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[
    HFRLTrainerBuilder | HFCausalTrainerBuilder,
    PeftModel | PreTrainedModel,
    PreTrainedTokenizer,
    PeftConfig | None,
]:
    """
    Load the model, tokenizer and optional reference model, then build the trainer.

    Args:
        cfg: The configuration dictionary with training parameters.
        dataset_meta: Object with training, validation datasets and metadata.

    Returns:
        Tuple of:
            - The object returned by `setup_trainer` (causal or RLHF)
            - Model
            - Tokenizer
            - PEFT config (`None` when no adapter is configured)
    """
    model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)

    # The reference model is only loaded for RL setups that need one.
    reference_model = setup_reference_model(cfg, tokenizer)

    trainer = setup_trainer(
        cfg=cfg,
        train_dataset=dataset_meta.train_dataset,
        eval_dataset=dataset_meta.eval_dataset,
        model=model,
        tokenizer=tokenizer,
        processor=processor,
        total_num_steps=dataset_meta.total_num_steps,
        model_ref=reference_model,
        peft_config=peft_config,
    )

    return trainer, model, tokenizer, peft_config
def train(
    cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
    """
    Run the full training pipeline: setup, pre-save, train, save, model card.

    Args:
        cfg: The configuration dictionary with training parameters.
        dataset_meta: Object with training, validation datasets and metadata.

    Returns:
        Tuple of (model, tokenizer) after training.
    """
    # Build everything needed for training (causal or RLHF trainer).
    trainer, model, tokenizer, peft_config = setup_model_and_trainer(
        cfg, dataset_meta
    )

    resume_from_checkpoint = determine_resume_checkpoint(cfg)

    # Safe serialization is used only when `save_safetensors` is exactly True.
    safe_serialization = cfg.save_safetensors is True

    # Pre-training steps: token fix, config pre-save, signal handler, model card text.
    handle_untrained_tokens_fix(
        cfg, model, tokenizer, dataset_meta.train_dataset, safe_serialization
    )
    save_initial_configs(cfg, tokenizer, model, peft_config)
    setup_signal_handler(cfg, model, safe_serialization)
    setup_model_card(cfg)

    # Train, then save the result and create the model card.
    execute_training(cfg, trainer, resume_from_checkpoint)
    save_trained_model(cfg, trainer, model, safe_serialization)
    create_model_card(cfg, trainer)

    return model, tokenizer
def pretrain_hooks(_cfg, _trainer):
    """
    Hook point that runs right before training is kicked off.

    The body is empty, so calling it currently has no effect.

    Args:
        _cfg: Training configuration (unused).
        _trainer: The trainer about to be run (unused).

    Returns:
        None.
    """
def post_train_hooks(_cfg, _trainer):
    """
    Hook point that runs right after training completes.

    The body is empty, so calling it currently has no effect.

    Args:
        _cfg: Training configuration (unused).
        _trainer: The trainer that has just finished (unused).

    Returns:
        None.
    """

View File

@@ -574,14 +574,40 @@ def prepare_opinionated_env(cfg):
def setup_trainer( def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps cfg,
train_dataset,
eval_dataset,
model,
tokenizer,
processor,
total_num_steps,
model_ref=None,
peft_config=None,
): ):
"""
Helper method for instantiating and building a (causal or RLHF) trainer.
Args:
cfg: Axolotl config object containing training parameters.
train_dataset: Dataset to use for training.
eval_dataset: Dataset to use for evaluation.
model: The model to train.
tokenizer: Tokenizer for processing text input.
processor: Processor for data preparation.
total_num_steps: The total number of training steps.
model_ref: Optional reference model for RLHF training. Default is None.
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
Returns:
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
if cfg.rl: if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model[1] trainer_builder.model_ref = model_ref
trainer_builder.peft_config = model[2] trainer_builder.peft_config = peft_config
else: else:
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.train_dataset = train_dataset trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset trainer_builder.eval_dataset = eval_dataset