diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py
index db07eb43b..3e712f772 100644
--- a/src/axolotl/common/datasets.py
+++ b/src/axolotl/common/datasets.py
@@ -24,8 +24,8 @@ class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
- eval_dataset: Optional[Dataset] = None
- total_num_steps: Optional[int] = None
+ eval_dataset: Dataset | None = None
+ total_num_steps: int | None = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 07753876f..d4ddc9bf3 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -91,13 +91,11 @@ try:
except ImportError:
pass
-LOG = logging.getLogger("axolotl.core.trainer_builder")
+LOG = logging.getLogger(__name__)
class TrainerBuilderBase(abc.ABC):
- """
- Base class for trainer builder
- """
+ """Base class for trainer builder."""
_train_dataset = None
_eval_dataset = None
@@ -110,9 +108,9 @@ class TrainerBuilderBase(abc.ABC):
self.tokenizer = tokenizer
self.processor = processor
- # in case the model supports tagging, add the axolotl tag.
+ # If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
- # model.push_to_hub instad of trainer.push_to_hub.
+ # model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
@@ -872,9 +870,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase):
- """
- Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
- """
+ """Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
def get_callbacks(self):
callbacks = super().get_callbacks()
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index d61e200c6..b2f4bf1e9 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -1,19 +1,20 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
+import importlib
import inspect
import os
import signal
import sys
import weakref
from pathlib import Path
-from typing import Any, Tuple, Union
+from typing import Any
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
+from datasets import Dataset
from peft import PeftConfig, PeftModel
-from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
@@ -22,6 +23,7 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
+from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
@@ -39,17 +41,18 @@ LOG = get_logger(__name__)
def setup_model_and_tokenizer(
cfg: DictDefault,
-) -> Tuple[
- PreTrainedTokenizer, ProcessorMixin | None, PreTrainedModel, PeftConfig | None
+) -> tuple[
+ PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
]:
"""
Load the tokenizer, processor (for multimodal models), and model based on configuration.
Args:
- cfg: The configuration dictionary with training parameters
+ cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
- Tuple containing tokenizer, processor (if multimodal, else None), model, and peft_config (if applicable, else None)
+ Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
+ `None`), and processor (if multimodal, else `None`).
"""
# Load tokenizer
LOG.debug(
@@ -77,7 +80,7 @@ def setup_model_and_tokenizer(
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
- return tokenizer, processor, model, peft_config
+ return model, tokenizer, peft_config, processor
def setup_reference_model(
@@ -87,11 +90,11 @@ def setup_reference_model(
Set up the reference model for RL training if needed.
Args:
- cfg: The configuration dictionary
- tokenizer: The tokenizer to use for the reference model
+ cfg: Dictionary mapping `axolotl` config keys to values.
+ tokenizer: The tokenizer to use for the reference model.
Returns:
- Reference model if needed for RL training, None otherwise
+ Reference model if needed for RL training, `None` otherwise.
"""
model_ref = None
if cfg.rl and cfg.rl != "orpo":
@@ -110,10 +113,10 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
Determine the checkpoint to resume from based on configuration.
Args:
- cfg: The configuration dictionary
+ cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
- Path to the checkpoint to resume from, or None if not resuming
+ Path to the checkpoint to resume from, or `None` if not resuming.
"""
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
@@ -138,7 +141,7 @@ def setup_signal_handler(
Set up signal handler for graceful termination.
Args:
- cfg: The configuration dictionary
+ cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to save on termination
safe_serialization: Whether to use safe serialization when saving
"""
@@ -169,9 +172,9 @@ def execute_training(
Execute the training process with appropriate backend configurations.
Args:
- cfg: The configuration dictionary
- trainer: The configured trainer object
- resume_from_checkpoint: Path to checkpoint to resume from, if applicable
+ cfg: Dictionary mapping `axolotl` config keys to values.
+ trainer: The configured trainer object.
+ resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
if cfg.group_by_length:
@@ -199,12 +202,12 @@ def save_trained_model(
Save the trained model according to configuration and training setup.
Args:
- cfg: The configuration dictionary
- trainer: The trainer object
- model: The trained model to save
- safe_serialization: Whether to use safe serialization
+ cfg: Dictionary mapping `axolotl` config keys to values.
+ trainer: The trainer object.
+ model: The trained model to save.
+ safe_serialization: Whether to use safe serialization.
"""
- LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+ LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
# Post training module hooks
for name, module in model.named_modules():
@@ -278,8 +281,8 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
Create a model card for the trained model if needed.
Args:
- cfg: The configuration dictionary
- trainer: The trainer object with model card creation capabilities
+ cfg: Dictionary mapping `axolotl` config keys to values.
+ trainer: The trainer object with model card creation capabilities.
"""
if not cfg.hub_model_id:
# Guard since create_model_card may fail if dataset_tags is empty list
@@ -289,29 +292,23 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
.encode("utf-8")
.decode("utf-8")
}
- if cfg.datasets is not None:
+
+ # We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
+ rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
+ if cfg.datasets is not None and not rl:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
if dataset_tags:
- param_name = (
- "dataset_name"
- if (
- cfg.rl is not None
- or cfg.reward_model
- or cfg.process_reward_model
- )
- else "dataset_tags"
- )
- model_card_kwarg[param_name] = dataset_tags
+ model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError):
pass
elif cfg.hub_model_id:
- # defensively push to the hub to ensure the model card is updated
+ # Defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()
@@ -325,23 +322,26 @@ def save_initial_configs(
Save initial configurations before training.
Args:
- cfg: The configuration dictionary
- tokenizer: The tokenizer to save
- model: The model to save configuration for
- peft_config: The PEFT configuration to save if applicable
+ cfg: Dictionary mapping `axolotl` config keys to values.
+ tokenizer: The tokenizer to save.
+ model: The model to save configuration for.
+ peft_config: The PEFT configuration to save if applicable.
"""
- # go ahead and presave, so we have the adapter config available to inspect
- if peft_config:
- LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
- peft_config.save_pretrained(cfg.output_dir)
-
- # additionally presave the tokenizer and model configs
+ # Create output_dir if it doesn't already exist
output_dir = Path(cfg.output_dir)
if not output_dir.is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
+ # Pre-save adapter config so it's available to inspect
+ if peft_config:
+ LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
+ peft_config.save_pretrained(cfg.output_dir)
+
+ # Pre-save the tokenizer and model configs
+ LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
tokenizer.save_pretrained(str(output_dir))
if hasattr(model, "config"):
+ LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
@@ -350,14 +350,14 @@ def setup_model_card(cfg: DictDefault):
Set up the Axolotl badge and add the Axolotl config to the model card if available.
Args:
- cfg: The configuration dictionary with path to axolotl config file
+ cfg: Dictionary mapping `axolotl` config keys to values.
"""
badge_markdown = """[
](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
- version = get_distribution("axolotl").version
+ version = importlib.metadata.version("axolotl")
if raw_axolotl_cfg.is_file():
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\nSee axolotl config
\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n
\n"
@@ -366,26 +366,26 @@ def handle_untrained_tokens_fix(
cfg: DictDefault,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
- train_dataset: Any,
+ train_dataset: Dataset,
safe_serialization: bool,
):
"""
Apply fixes for untrained tokens if configured.
Args:
- cfg: The configuration dictionary
- model: The model to apply fixes to
- tokenizer: The tokenizer for token identification
- train_dataset: The training dataset to analyze
- safe_serialization: Whether to use safe serialization when saving
+ cfg: Dictionary mapping `axolotl` config keys to values.
+ model: The model to apply fixes to.
+ tokenizer: The tokenizer for token identification.
+ train_dataset: The training dataset to use.
+ safe_serialization: Whether to use safe serialization when saving.
"""
if not cfg.fix_untrained_tokens:
return
- # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
+ # Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
- # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
+ # If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
@@ -404,48 +404,90 @@ def handle_untrained_tokens_fix(
)
-def train(
- *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
-) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
+def setup_model_and_trainer(
+ cfg: DictDefault, dataset_meta: TrainDatasetMeta
+) -> tuple[
+ HFRLTrainerBuilder | HFCausalTrainerBuilder,
+ PeftModel | PreTrainedModel,
+ PreTrainedTokenizer,
+ PeftConfig | None,
+]:
"""
- Train a model on the given dataset.
+ Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
+ trainer setup.
Args:
- cfg: The configuration dictionary with training parameters
- dataset_meta: Metadata about the training dataset
+ cfg: The configuration dictionary with training parameters.
+ dataset_meta: Object with training, validation datasets and metadata.
Returns:
- Tuple of (model, tokenizer) after training
+ Tuple of:
+ - Trainer (Causal or RLHF)
+ - Model
+ - Tokenizer
+ - PEFT config
"""
# Load tokenizer, processor and model
- tokenizer, processor, model, peft_config = setup_model_and_tokenizer(cfg)
+ model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
# Set up reference model for RL if needed
model_ref = setup_reference_model(cfg, tokenizer)
- # Determine if we need to resume from a checkpoint
- resume_from_checkpoint = determine_resume_checkpoint(cfg)
-
# Get datasets from metadata
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
+ # Set up trainer
+ trainer = setup_trainer(
+ cfg=cfg,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ model=model,
+ tokenizer=tokenizer,
+ processor=processor,
+ total_num_steps=total_num_steps,
+ model_ref=model_ref,
+ peft_config=peft_config,
+ )
+
+ return (
+ trainer,
+ model,
+ tokenizer,
+ peft_config,
+ )
+
+
+def train(
+ cfg: DictDefault, dataset_meta: TrainDatasetMeta
+) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
+ """
+ Train a model on the given dataset.
+
+ Args:
+ cfg: The configuration dictionary with training parameters
+ dataset_meta: Object with training, validation datasets and metadata
+
+ Returns:
+ Tuple of (model, tokenizer) after training
+ """
+ # Setup model, tokenizer, (causal or RLHF) trainer etc.
+ (
+ trainer,
+ model,
+ tokenizer,
+ peft_config,
+ ) = setup_model_and_trainer(cfg, dataset_meta)
+
+ # Determine if we need to resume from a checkpoint
+ resume_from_checkpoint = determine_resume_checkpoint(cfg)
+
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
- # Set up trainer
- trainer = setup_trainer(
- cfg,
- train_dataset,
- eval_dataset,
- (model, model_ref, peft_config),
- tokenizer,
- processor,
- total_num_steps,
- )
-
# Handle untrained tokens if configured
+ train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 8553339b9..8cee3d124 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -574,14 +574,40 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
- cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
+ cfg,
+ train_dataset,
+ eval_dataset,
+ model,
+ tokenizer,
+ processor,
+ total_num_steps,
+ model_ref=None,
+ peft_config=None,
):
+ """
+ Helper method for instantiating and building a (causal or RLHF) trainer.
+
+ Args:
+ cfg: Axolotl config object containing training parameters.
+ train_dataset: Dataset to use for training.
+ eval_dataset: Dataset to use for evaluation.
+ model: The model to train.
+ tokenizer: Tokenizer for processing text input.
+ processor: Processor for data preparation.
+ total_num_steps: The total number of training steps.
+ model_ref: Optional reference model for RLHF training. Default is None.
+ peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
+
+ Returns:
+ A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
+ on the provided parameters.
+ """
if cfg.rl:
- trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
- trainer_builder.model_ref = model[1]
- trainer_builder.peft_config = model[2]
+ trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
+ trainer_builder.model_ref = model_ref
+ trainer_builder.peft_config = peft_config
else:
- trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
+ trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset