diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 515248fce..eb2ced231 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -6,15 +6,15 @@ import signal
import sys
import weakref
from pathlib import Path
-from typing import Tuple, Union
+from typing import Any, Tuple, Union
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
-from peft import PeftModel
+from peft import PeftConfig, PeftModel
from pkg_resources import get_distribution # type: ignore
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.datasets import TrainDatasetMeta
@@ -27,11 +27,13 @@ from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
+# Optional imports with graceful fallbacks
try:
from optimum.bettertransformer import BetterTransformer
except ImportError:
BetterTransformer = None
+# Project setup
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
@@ -40,9 +42,20 @@ configure_logging()
LOG = get_logger(__name__)
-def train(
- *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
-) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
+def setup_model_and_tokenizer(
+ cfg: DictDefault,
+) -> Tuple[
+ PreTrainedTokenizer, ProcessorMixin | None, PreTrainedModel, PeftConfig | None
+]:
+ """
+ Load the tokenizer, processor (for multimodal models), and model based on configuration.
+
+ Args:
+ cfg: The configuration dictionary with training parameters
+
+ Returns:
+ Tuple containing tokenizer, processor (if multimodal, else None), model, and peft_config (if applicable, else None)
+ """
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -55,11 +68,58 @@ def train(
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
- # Get datasets
- train_dataset = dataset_meta.train_dataset
- eval_dataset = dataset_meta.eval_dataset
- total_num_steps = dataset_meta.total_num_steps
+ # Load the model and peft_config
+ msg = "loading model"
+ if cfg.adapter:
+ msg += " and peft_config..."
+ LOG.debug(msg)
+ model, peft_config = load_model(cfg, tokenizer, processor=processor)
+ if model.generation_config is not None:
+ model.generation_config.do_sample = True
+
+ # Apply freezing if specified
+ if cfg.unfrozen_parameters:
+ freeze_layers_except(model, cfg.unfrozen_parameters)
+
+ return tokenizer, processor, model, peft_config
+
+
+def setup_reference_model(
+ cfg: DictDefault, tokenizer: PreTrainedTokenizer
+) -> PreTrainedModel | None:
+ """
+ Set up the reference model for RL training if needed.
+
+ Args:
+ cfg: The configuration dictionary
+ tokenizer: The tokenizer to use for the reference model
+
+ Returns:
+ Reference model if needed for RL training, None otherwise
+ """
+ model_ref = None
+ if cfg.rl and cfg.rl != "orpo":
+ if cfg.adapter and not cfg.rl_adapter_ref_model:
+ # use built-in trl autounwrap
+ LOG.debug("Passing model_ref: None to RL trainer")
+ model_ref = None # explicit setting to None
+ else:
+ # load the model again for model_ref/baseline
+ model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
+ return model_ref
+
+
+def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
+ """
+ Determine the checkpoint to resume from based on configuration.
+
+ Args:
+ cfg: The configuration dictionary
+
+ Returns:
+ Path to the checkpoint to resume from, or None if not resuming
+ """
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
@@ -73,77 +133,22 @@ def train(
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
- resume_from_checkpoint = cfg.resume_from_checkpoint
+ return cfg.resume_from_checkpoint
- # Load the model and tokenizer
- msg = "loading model"
- if cfg.adapter:
- msg += " and peft_config..."
- LOG.debug(msg)
- model, peft_config = load_model(cfg, tokenizer, processor=processor)
- if model.generation_config is not None:
- model.generation_config.do_sample = True
- model_ref = None
- if cfg.rl and cfg.rl != "orpo":
- if cfg.adapter and not cfg.rl_adapter_ref_model:
- # use built-in trl autounwrap
- LOG.debug("Passing model_ref: None to RL trainer")
- model_ref = None # explicit setting to None
- else:
- # load the model again for model_ref/baseline
- model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
+def setup_signal_handler(
+ cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
+):
+ """
+ Set up signal handler for graceful termination.
- safe_serialization = cfg.save_safetensors is True
-
- if cfg.unfrozen_parameters:
- freeze_layers_except(model, cfg.unfrozen_parameters)
-
- trainer = setup_trainer(
- cfg,
- train_dataset,
- eval_dataset,
- (model, model_ref, peft_config),
- tokenizer,
- processor,
- total_num_steps,
- )
-
- if cfg.fix_untrained_tokens:
- # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
- sig = inspect.signature(fix_untrained_tokens)
- # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
- if "token_ids_to_fix" in sig.parameters and isinstance(
- cfg.fix_untrained_tokens, list
- ):
- fix_untrained_tokens(
- model,
- tokenizer,
- train_dataset,
- token_ids_to_fix=cfg.fix_untrained_tokens,
- )
- else:
- fix_untrained_tokens(model, tokenizer, train_dataset)
- if cfg.local_rank == 0:
- model.save_pretrained(
- str(Path(cfg.output_dir)), safe_serialization=safe_serialization
- )
-
- # go ahead and presave, so we have the adapter config available to inspect
- if peft_config:
- LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
- peft_config.save_pretrained(cfg.output_dir)
- # additionally presave the tokenizer and model configs
- if not Path(cfg.output_dir).is_dir():
- os.makedirs(cfg.output_dir, exist_ok=True)
- tokenizer.save_pretrained(str(Path(cfg.output_dir)))
- if hasattr(model, "config"):
- model.config.save_pretrained(str(Path(cfg.output_dir)))
-
- # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
- if (
- cfg.local_rank == 0 and not cfg.use_ray
- ): # ray workers don't have access to this signal
+ Args:
+ cfg: The configuration dictionary
+ model: The model to save on termination
+ safe_serialization: Whether to use safe serialization when saving
+ """
+ # ray workers don't have access to this signal
+ if cfg.local_rank == 0 and not cfg.use_ray:
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
@@ -161,21 +166,22 @@ def train(
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
)
- badge_markdown = """[
](https://github.com/axolotl-ai-cloud/axolotl)"""
- transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
- if getattr(cfg, "axolotl_config_path"):
- raw_axolotl_cfg = Path(cfg.axolotl_config_path)
- version = get_distribution("axolotl").version
- if raw_axolotl_cfg.is_file():
- transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\nSee axolotl config
\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n
\n"
+def execute_training(
+ cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
+):
+ """
+ Execute the training process with appropriate backend configurations.
+ Args:
+ cfg: The configuration dictionary
+ trainer: The configured trainer object
+ resume_from_checkpoint: Path to checkpoint to resume from, if applicable
+ """
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
- pretrain_hooks(cfg, trainer)
-
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -187,15 +193,30 @@ def train(
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
- post_train_hooks(cfg, trainer)
+def save_trained_model(
+ cfg: DictDefault,
+ trainer: Any,
+ model: PreTrainedModel,
+ safe_serialization: bool,
+):
+ """
+ Save the trained model according to configuration and training setup.
+
+ Args:
+ cfg: The configuration dictionary
+ trainer: The trainer object
+ model: The trained model to save
+ safe_serialization: Whether to use safe serialization
+ """
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
- # post training
+ # Post training module hooks
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
+ # Handle FSDP state dict type
state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled:
if cfg.fsdp_final_state_dict_type:
@@ -203,12 +224,13 @@ def train(
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
+ # Handle ReLoRA early return case
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
- return model, tokenizer
+ return
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
@@ -255,6 +277,15 @@ def train(
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+
+def create_model_card(cfg: DictDefault, trainer: Any):
+ """
+ Create a model card for the trained model if needed.
+
+ Args:
+ cfg: The configuration dictionary
+ trainer: The trainer object with model card creation capabilities
+ """
if not cfg.hub_model_id:
try:
model_card_kwarg = {
@@ -291,22 +322,162 @@ def train(
# defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()
+
+def save_initial_configs(
+ cfg: DictDefault,
+ tokenizer: PreTrainedTokenizer,
+ model: PreTrainedModel,
+ peft_config: PeftConfig | None,
+):
+ """
+ Save initial configurations before training.
+
+ Args:
+ cfg: The configuration dictionary
+ tokenizer: The tokenizer to save
+ model: The model to save configuration for
+ peft_config: The PEFT configuration to save if applicable
+ """
+ # go ahead and presave, so we have the adapter config available to inspect
+ if peft_config:
+ LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
+ peft_config.save_pretrained(cfg.output_dir)
+
+ # additionally presave the tokenizer and model configs
+ output_dir = Path(cfg.output_dir)
+ if not output_dir.is_dir():
+ os.makedirs(cfg.output_dir, exist_ok=True)
+
+ tokenizer.save_pretrained(str(output_dir))
+ if hasattr(model, "config"):
+ model.config.save_pretrained(str(output_dir))
+
+
+def setup_badge_for_model_card():
+ """Set up the Axolotl badge for the model card."""
+ badge_markdown = """[
](https://github.com/axolotl-ai-cloud/axolotl)"""
+ transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
+
+
+def add_config_to_model_card(cfg: DictDefault):
+ """
+ Add the Axolotl configuration to the model card if available.
+
+ Args:
+ cfg: The configuration dictionary with path to axolotl config file
+ """
+ if getattr(cfg, "axolotl_config_path"):
+ raw_axolotl_cfg = Path(cfg.axolotl_config_path)
+ version = get_distribution("axolotl").version
+ if raw_axolotl_cfg.is_file():
+ transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\nSee axolotl config
\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n
\n"
+
+
+def handle_untrained_tokens_fix(
+ cfg: DictDefault,
+ model: PreTrainedModel,
+ tokenizer: PreTrainedTokenizer,
+ train_dataset: Any,
+ safe_serialization: bool,
+):
+ """
+ Apply fixes for untrained tokens if configured.
+
+ Args:
+ cfg: The configuration dictionary
+ model: The model to apply fixes to
+ tokenizer: The tokenizer for token identification
+ train_dataset: The training dataset to analyze
+ safe_serialization: Whether to use safe serialization when saving
+ """
+ if not cfg.fix_untrained_tokens:
+ return
+
+ # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
+ sig = inspect.signature(fix_untrained_tokens)
+
+ # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
+ if "token_ids_to_fix" in sig.parameters and isinstance(
+ cfg.fix_untrained_tokens, list
+ ):
+ fix_untrained_tokens(
+ model,
+ tokenizer,
+ train_dataset,
+ token_ids_to_fix=cfg.fix_untrained_tokens,
+ )
+ else:
+ fix_untrained_tokens(model, tokenizer, train_dataset)
+
+ if cfg.local_rank == 0:
+ model.save_pretrained(
+ str(Path(cfg.output_dir)), safe_serialization=safe_serialization
+ )
+
+
+def train(
+ *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
+) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
+ """
+ Train a model on the given dataset.
+
+ Args:
+ cfg: The configuration dictionary with training parameters
+ dataset_meta: Metadata about the training dataset
+
+ Returns:
+ Tuple of (model, tokenizer) after training
+ """
+ # Load tokenizer, processor and model
+ tokenizer, processor, model, peft_config = setup_model_and_tokenizer(cfg)
+
+ # Set up reference model for RL if needed
+ model_ref = setup_reference_model(cfg, tokenizer)
+
+ # Determine if we need to resume from a checkpoint
+ resume_from_checkpoint = determine_resume_checkpoint(cfg)
+
+ # Get datasets from metadata
+ train_dataset = dataset_meta.train_dataset
+ eval_dataset = dataset_meta.eval_dataset
+ total_num_steps = dataset_meta.total_num_steps
+
+ # Configuration for saving
+ safe_serialization = cfg.save_safetensors is True
+
+ # Set up trainer
+ trainer = setup_trainer(
+ cfg,
+ train_dataset,
+ eval_dataset,
+ (model, model_ref, peft_config),
+ tokenizer,
+ processor,
+ total_num_steps,
+ )
+
+ # Handle untrained tokens if configured
+ handle_untrained_tokens_fix(
+ cfg, model, tokenizer, train_dataset, safe_serialization
+ )
+
+ # Save initial configs
+ save_initial_configs(cfg, tokenizer, model, peft_config)
+
+ # Set up signal handler for graceful termination
+ setup_signal_handler(cfg, model, safe_serialization)
+
+ # Set up badges and config info for model card
+ setup_badge_for_model_card()
+ add_config_to_model_card(cfg)
+
+ # Execute the training
+ execute_training(cfg, trainer, resume_from_checkpoint)
+
+ # Save the trained model
+ save_trained_model(cfg, trainer, model, safe_serialization)
+
+ # Create model card
+ create_model_card(cfg, trainer)
+
return model, tokenizer
-
-
-def pretrain_hooks(_cfg, _trainer):
- """
- Run hooks right before kicking off the training
- :param cfg:
- :param trainer:
- :return:
- """
-
-
-def post_train_hooks(_cfg, _trainer):
- """
- Run hooks right after training completes
- :param cfg:
- :param trainer:
- :return:
- """