review comments

Dan Saunders
2025-03-03 18:44:16 +00:00
parent e1bc18763a
commit 5c0510a876
4 changed files with 157 additions and 93 deletions

View File

@@ -24,8 +24,8 @@ class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
eval_dataset: Dataset | None = None
total_num_steps: int | None = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
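For context, a minimal sketch of the dataclass as it reads after this change; the `X | None` spelling needs Python 3.10+ (or `from __future__ import annotations`) to evaluate at class-definition time:

from dataclasses import dataclass
from datasets import Dataset

@dataclass
class TrainDatasetMeta:
    """Dataclass with fields for training and validation datasets and metadata."""

    train_dataset: Dataset
    eval_dataset: Dataset | None = None
    total_num_steps: int | None = None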

View File

@@ -91,13 +91,11 @@ try:
except ImportError:
pass
LOG = logging.getLogger("axolotl.core.trainer_builder")
LOG = logging.getLogger(__name__)
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
"""
"""Base class for trainer builder."""
_train_dataset = None
_eval_dataset = None
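A small aside on the logger rename: `getLogger(__name__)` keeps the logger inside the package hierarchy, so configuring the parent `axolotl` logger still reaches this module (its name is still "axolotl.core.trainer_builder"). A quick illustration:

import logging

# Child loggers inherit the effective level from the nearest configured ancestor.
logging.getLogger("axolotl").setLevel(logging.DEBUG)
log = logging.getLogger("axolotl.core.trainer_builder")
assert log.getEffectiveLevel() == logging.DEBUG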
@@ -110,9 +108,9 @@ class TrainerBuilderBase(abc.ABC):
self.tokenizer = tokenizer
self.processor = processor
# in case the model supports tagging, add the axolotl tag.
# If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instad of trainer.push_to_hub.
# model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
@@ -872,9 +870,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
"""
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
def get_callbacks(self):
callbacks = super().get_callbacks()

View File

@@ -1,19 +1,20 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import inspect
import os
import signal
import sys
import weakref
from pathlib import Path
from typing import Any, Tuple, Union
from typing import Any
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftConfig, PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
@@ -22,6 +23,7 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
@@ -39,17 +41,18 @@ LOG = get_logger(__name__)
def setup_model_and_tokenizer(
cfg: DictDefault,
) -> Tuple[
PreTrainedTokenizer, ProcessorMixin | None, PreTrainedModel, PeftConfig | None
) -> tuple[
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
]:
"""
Load the tokenizer, processor (for multimodal models), and model based on configuration.
Args:
cfg: The configuration dictionary with training parameters
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Tuple containing tokenizer, processor (if multimodal, else None), model, and peft_config (if applicable, else None)
Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
`None`), and processor (if multimodal, else `None`).
"""
# Load tokenizer
LOG.debug(
@@ -77,7 +80,7 @@ def setup_model_and_tokenizer(
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
return tokenizer, processor, model, peft_config
return model, tokenizer, peft_config, processor
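Caller-side, the reordered return unpacks as below (a sketch of the call site, matching the change in `train()` further down):

model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
# peft_config is None unless LoRA / QLoRA is configured;
# processor is None unless the model is multimodal.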
def setup_reference_model(
@@ -87,11 +90,11 @@ def setup_reference_model(
Set up the reference model for RL training if needed.
Args:
cfg: The configuration dictionary
tokenizer: The tokenizer to use for the reference model
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: The tokenizer to use for the reference model.
Returns:
Reference model if needed for RL training, None otherwise
Reference model if needed for RL training, `None` otherwise.
"""
model_ref = None
if cfg.rl and cfg.rl != "orpo":
@@ -110,10 +113,10 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
Determine the checkpoint to resume from based on configuration.
Args:
cfg: The configuration dictionary
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Path to the checkpoint to resume from, or None if not resuming
Path to the checkpoint to resume from, or `None` if not resuming.
"""
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
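The hunk is cut off here; for readers following along, a plausible sketch of how an auto-resume branch can pick the newest checkpoint (not the exact upstream implementation):

from pathlib import Path

def latest_checkpoint(output_dir: str) -> str | None:
    """Pick the newest HF-style checkpoint-<step> directory, if any."""
    candidates = [p for p in Path(output_dir).glob("checkpoint-*") if p.is_dir()]
    if not candidates:
        return None
    return str(max(candidates, key=lambda p: int(p.name.split("-")[-1])))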
@@ -138,7 +141,7 @@ def setup_signal_handler(
Set up signal handler for graceful termination.
Args:
cfg: The configuration dictionary
cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to save on termination
safe_serialization: Whether to use safe serialization when saving
"""
@@ -169,9 +172,9 @@ def execute_training(
Execute the training process with appropriate backend configurations.
Args:
cfg: The configuration dictionary
trainer: The configured trainer object
resume_from_checkpoint: Path to checkpoint to resume from, if applicable
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
if cfg.group_by_length:
@@ -199,12 +202,12 @@ def save_trained_model(
Save the trained model according to configuration and training setup.
Args:
cfg: The configuration dictionary
trainer: The trainer object
model: The trained model to save
safe_serialization: Whether to use safe serialization
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object.
model: The trained model to save.
safe_serialization: Whether to use safe serialization.
"""
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
# Post training module hooks
for name, module in model.named_modules():
@@ -278,8 +281,8 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
Create a model card for the trained model if needed.
Args:
cfg: The configuration dictionary
trainer: The trainer object with model card creation capabilities
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object with model card creation capabilities.
"""
if not cfg.hub_model_id:
# Guard since create_model_card may fail if dataset_tags is empty list
@@ -289,29 +292,23 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
.encode("utf-8")
.decode("utf-8")
}
if cfg.datasets is not None:
# We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
if cfg.datasets is not None and not rl:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
if dataset_tags:
param_name = (
"dataset_name"
if (
cfg.rl is not None
or cfg.reward_model
or cfg.process_reward_model
)
else "dataset_tags"
)
model_card_kwarg[param_name] = dataset_tags
model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError):
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated
# Defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()
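To make the two exclusions explicit, a standalone sketch of the tag filtering with a hypothetical dataset list (existing local directories and raw URLs are not valid Hub dataset tags; when a TRL trainer is in use the tags are skipped entirely, as in the hunk above):

from pathlib import Path

# Hypothetical dataset list; "tatsu-lab/alpaca" is a Hub id used only for illustration.
datasets = [
    {"path": "tatsu-lab/alpaca"},
    {"path": "https://example.com/data.jsonl"},
]

dataset_tags = [d["path"] for d in datasets if not Path(d["path"]).is_dir()]
dataset_tags = [t for t in dataset_tags if not t.startswith("https://")]
assert dataset_tags == ["tatsu-lab/alpaca"]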
@@ -325,23 +322,26 @@ def save_initial_configs(
Save initial configurations before training.
Args:
cfg: The configuration dictionary
tokenizer: The tokenizer to save
model: The model to save configuration for
peft_config: The PEFT configuration to save if applicable
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: The tokenizer to save.
model: The model to save configuration for.
peft_config: The PEFT configuration to save if applicable.
"""
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# additionally presave the tokenizer and model configs
# Create output_dir if it doesn't already exist
output_dir = Path(cfg.output_dir)
if not output_dir.is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
# Pre-save adapter config so it's available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
peft_config.save_pretrained(cfg.output_dir)
# Pre-save the tokenizer and model configs
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
tokenizer.save_pretrained(str(output_dir))
if hasattr(model, "config"):
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
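As a rough orientation, what the pre-save step leaves in `output_dir` (file names follow the respective `save_pretrained` implementations, so treat this as indicative rather than exhaustive):

# adapter_config.json                                   # only when peft_config is set
# tokenizer_config.json, special_tokens_map.json, ...   # from the tokenizer
# config.json                                           # from model.config, when present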
@@ -350,14 +350,14 @@ def setup_model_card(cfg: DictDefault):
Set up the Axolotl badge and add the Axolotl config to the model card if available.
Args:
cfg: The configuration dictionary with path to axolotl config file
cfg: Dictionary mapping `axolotl` config keys to values.
"""
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = get_distribution("axolotl").version
version = importlib.metadata.version("axolotl")
if raw_axolotl_cfg.is_file():
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
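The `pkg_resources` to `importlib.metadata` swap is the stdlib-recommended replacement. One caveat worth flagging in review: the bare `import importlib` added at the top of this file only resolves `importlib.metadata` if some other import has already loaded that submodule; `import importlib.metadata` is the explicit, safe spelling. For reference:

import importlib.metadata

# Stdlib replacement (Python 3.8+) for the deprecated pkg_resources API:
#   from pkg_resources import get_distribution
#   version = get_distribution("axolotl").version
version = importlib.metadata.version("axolotl")  # PackageNotFoundError if not installed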
@@ -366,26 +366,26 @@ def handle_untrained_tokens_fix(
cfg: DictDefault,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
train_dataset: Any,
train_dataset: Dataset,
safe_serialization: bool,
):
"""
Apply fixes for untrained tokens if configured.
Args:
cfg: The configuration dictionary
model: The model to apply fixes to
tokenizer: The tokenizer for token identification
train_dataset: The training dataset to analyze
safe_serialization: Whether to use safe serialization when saving
cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to apply fixes to.
tokenizer: The tokenizer for token identification.
train_dataset: The training dataset to use.
safe_serialization: Whether to use safe serialization when saving.
"""
if not cfg.fix_untrained_tokens:
return
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
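The `inspect.signature` check is a feature-detection pattern: only pass the newer kwarg if the installed helper accepts it. A self-contained sketch with a hypothetical stand-in function:

import inspect

def fix_untrained_tokens_example(model, tokenizer, dataset, token_ids_to_fix=None):
    """Hypothetical stand-in for the contrib helper, used only to show the check."""

# Only forward the newer kwarg when the installed helper actually accepts it.
kwargs = {}
if "token_ids_to_fix" in inspect.signature(fix_untrained_tokens_example).parameters:
    kwargs["token_ids_to_fix"] = [128002, 128003]  # hypothetical token ids
fix_untrained_tokens_example(None, None, None, **kwargs)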
@@ -404,48 +404,90 @@ def handle_untrained_tokens_fix(
)
def train(
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
def setup_model_and_trainer(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[
HFRLTrainerBuilder | HFCausalTrainerBuilder,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
]:
"""
Train a model on the given dataset.
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
trainer setup.
Args:
cfg: The configuration dictionary with training parameters
dataset_meta: Metadata about the training dataset
cfg: The configuration dictionary with training parameters.
dataset_meta: Object with training, validation datasets and metadata.
Returns:
Tuple of (model, tokenizer) after training
Tuple of:
- Trainer (Causal or RLHF)
- Model
- Tokenizer
- PEFT config
"""
# Load tokenizer, processor and model
tokenizer, processor, model, peft_config = setup_model_and_tokenizer(cfg)
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
# Set up reference model for RL if needed
model_ref = setup_reference_model(cfg, tokenizer)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Get datasets from metadata
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Set up trainer
trainer = setup_trainer(
cfg=cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=model,
tokenizer=tokenizer,
processor=processor,
total_num_steps=total_num_steps,
model_ref=model_ref,
peft_config=peft_config,
)
return (
trainer,
model,
tokenizer,
peft_config,
)
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
"""
Train a model on the given dataset.
Args:
cfg: The configuration dictionary with training parameters
dataset_meta: Object with training, validation datasets and metadata
Returns:
Tuple of (model, tokenizer) after training
"""
# Setup model, tokenizer, (causal or RLHF) trainer etc.
(
trainer,
model,
tokenizer,
peft_config,
) = setup_model_and_trainer(cfg, dataset_meta)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Set up trainer
trainer = setup_trainer(
cfg,
train_dataset,
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)
# Handle untrained tokens if configured
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)

View File

@@ -574,14 +574,40 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
cfg,
train_dataset,
eval_dataset,
model,
tokenizer,
processor,
total_num_steps,
model_ref=None,
peft_config=None,
):
"""
Helper method for instantiating and building a (causal or RLHF) trainer.
Args:
cfg: Axolotl config object containing training parameters.
train_dataset: Dataset to use for training.
eval_dataset: Dataset to use for evaluation.
model: The model to train.
tokenizer: Tokenizer for processing text input.
processor: Processor for data preparation.
total_num_steps: The total number of training steps.
model_ref: Optional reference model for RLHF training. Default is None.
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
Returns:
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model_ref
trainer_builder.peft_config = peft_config
else:
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset
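Finally, a call-site sketch contrasting the old tuple-packed `model` argument with the new explicit keyword arguments (names as at the call site in `train.py`):

# Before this change, the reference model and PEFT config were packed into the
# `model` argument as a tuple:
#   trainer = setup_trainer(cfg, train_dataset, eval_dataset,
#                           (model, model_ref, peft_config),
#                           tokenizer, processor, total_num_steps)
#
# After, they are explicit optional keyword arguments:
trainer = setup_trainer(
    cfg=cfg,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    model=model,
    tokenizer=tokenizer,
    processor=processor,
    total_num_steps=total_num_steps,
    model_ref=model_ref,      # only used for RLHF (cfg.rl set)
    peft_config=peft_config,  # only used when training adapters
)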