review comments

This commit is contained in:
Dan Saunders
2025-03-03 18:44:16 +00:00
parent e1bc18763a
commit 5c0510a876
4 changed files with 157 additions and 93 deletions

View File

@@ -24,8 +24,8 @@ class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata.""" """Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset train_dataset: Dataset
eval_dataset: Optional[Dataset] = None eval_dataset: Dataset | None = None
total_num_steps: Optional[int] = None total_num_steps: int | None = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:

View File

@@ -91,13 +91,11 @@ try:
except ImportError: except ImportError:
pass pass
LOG = logging.getLogger("axolotl.core.trainer_builder") LOG = logging.getLogger(__name__)
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """Base class for trainer builder."""
Base class for trainer builder
"""
_train_dataset = None _train_dataset = None
_eval_dataset = None _eval_dataset = None
@@ -110,9 +108,9 @@ class TrainerBuilderBase(abc.ABC):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.processor = processor self.processor = processor
# in case the model supports tagging, add the axolotl tag. # If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls # This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instad of trainer.push_to_hub. # model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"): if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"]) model.add_model_tags(["axolotl"])
@@ -872,9 +870,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase): class HFRLTrainerBuilder(TrainerBuilderBase):
""" """Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
"""
def get_callbacks(self): def get_callbacks(self):
callbacks = super().get_callbacks() callbacks = super().get_callbacks()

View File

@@ -1,19 +1,20 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import inspect import inspect
import os import os
import signal import signal
import sys import sys
import weakref import weakref
from pathlib import Path from pathlib import Path
from typing import Any, Tuple, Union from typing import Any
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftConfig, PeftModel from peft import PeftConfig, PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer from transformers.trainer import Trainer
@@ -22,6 +23,7 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
) )
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except
@@ -39,17 +41,18 @@ LOG = get_logger(__name__)
def setup_model_and_tokenizer( def setup_model_and_tokenizer(
cfg: DictDefault, cfg: DictDefault,
) -> Tuple[ ) -> tuple[
PreTrainedTokenizer, ProcessorMixin | None, PreTrainedModel, PeftConfig | None PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
]: ]:
""" """
Load the tokenizer, processor (for multimodal models), and model based on configuration. Load the tokenizer, processor (for multimodal models), and model based on configuration.
Args: Args:
cfg: The configuration dictionary with training parameters cfg: Dictionary mapping `axolotl` config keys to values.
Returns: Returns:
Tuple containing tokenizer, processor (if multimodal, else None), model, and peft_config (if applicable, else None) Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
`None`), and processor (if multimodal, else `None`).
""" """
# Load tokenizer # Load tokenizer
LOG.debug( LOG.debug(
@@ -77,7 +80,7 @@ def setup_model_and_tokenizer(
if cfg.unfrozen_parameters: if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters) freeze_layers_except(model, cfg.unfrozen_parameters)
return tokenizer, processor, model, peft_config return model, tokenizer, peft_config, processor
def setup_reference_model( def setup_reference_model(
@@ -87,11 +90,11 @@ def setup_reference_model(
Set up the reference model for RL training if needed. Set up the reference model for RL training if needed.
Args: Args:
cfg: The configuration dictionary cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: The tokenizer to use for the reference model tokenizer: The tokenizer to use for the reference model.
Returns: Returns:
Reference model if needed for RL training, None otherwise Reference model if needed for RL training, `None` otherwise.
""" """
model_ref = None model_ref = None
if cfg.rl and cfg.rl != "orpo": if cfg.rl and cfg.rl != "orpo":
@@ -110,10 +113,10 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
Determine the checkpoint to resume from based on configuration. Determine the checkpoint to resume from based on configuration.
Args: Args:
cfg: The configuration dictionary cfg: Dictionary mapping `axolotl` config keys to values.
Returns: Returns:
Path to the checkpoint to resume from, or None if not resuming Path to the checkpoint to resume from, or `None` if not resuming.
""" """
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [ possible_checkpoints = [
@@ -138,7 +141,7 @@ def setup_signal_handler(
Set up signal handler for graceful termination. Set up signal handler for graceful termination.
Args: Args:
cfg: The configuration dictionary cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to save on termination model: The model to save on termination
safe_serialization: Whether to use safe serialization when saving safe_serialization: Whether to use safe serialization when saving
""" """
@@ -169,9 +172,9 @@ def execute_training(
Execute the training process with appropriate backend configurations. Execute the training process with appropriate backend configurations.
Args: Args:
cfg: The configuration dictionary cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The configured trainer object trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
""" """
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
if cfg.group_by_length: if cfg.group_by_length:
@@ -199,12 +202,12 @@ def save_trained_model(
Save the trained model according to configuration and training setup. Save the trained model according to configuration and training setup.
Args: Args:
cfg: The configuration dictionary cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object trainer: The trainer object.
model: The trained model to save model: The trained model to save.
safe_serialization: Whether to use safe serialization safe_serialization: Whether to use safe serialization.
""" """
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
# Post training module hooks # Post training module hooks
for name, module in model.named_modules(): for name, module in model.named_modules():
@@ -278,8 +281,8 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
Create a model card for the trained model if needed. Create a model card for the trained model if needed.
Args: Args:
cfg: The configuration dictionary cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object with model card creation capabilities trainer: The trainer object with model card creation capabilities.
""" """
if not cfg.hub_model_id: if not cfg.hub_model_id:
# Guard since create_model_card may fail if dataset_tags is empty list # Guard since create_model_card may fail if dataset_tags is empty list
@@ -289,29 +292,23 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
.encode("utf-8") .encode("utf-8")
.decode("utf-8") .decode("utf-8")
} }
if cfg.datasets is not None:
# We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
if cfg.datasets is not None and not rl:
dataset_tags = [ dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
] ]
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")] dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
if dataset_tags: if dataset_tags:
param_name = ( model_card_kwarg["dataset_tags"] = dataset_tags
"dataset_name"
if (
cfg.rl is not None
or cfg.reward_model
or cfg.process_reward_model
)
else "dataset_tags"
)
model_card_kwarg[param_name] = dataset_tags
trainer.create_model_card(**model_card_kwarg) trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError): except (AttributeError, UnicodeDecodeError):
pass pass
elif cfg.hub_model_id: elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated # Defensively push to the hub to ensure the model card is updated
trainer.push_to_hub() trainer.push_to_hub()
@@ -325,23 +322,26 @@ def save_initial_configs(
Save initial configurations before training. Save initial configurations before training.
Args: Args:
cfg: The configuration dictionary cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: The tokenizer to save tokenizer: The tokenizer to save.
model: The model to save configuration for model: The model to save configuration for.
peft_config: The PEFT configuration to save if applicable peft_config: The PEFT configuration to save if applicable.
""" """
# go ahead and presave, so we have the adapter config available to inspect # Create output_dir if it doesn't already exist
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# additionally presave the tokenizer and model configs
output_dir = Path(cfg.output_dir) output_dir = Path(cfg.output_dir)
if not output_dir.is_dir(): if not output_dir.is_dir():
os.makedirs(cfg.output_dir, exist_ok=True) os.makedirs(cfg.output_dir, exist_ok=True)
# Pre-save adapter config so it's available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
peft_config.save_pretrained(cfg.output_dir)
# Pre-save the tokenizer and model configs
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
tokenizer.save_pretrained(str(output_dir)) tokenizer.save_pretrained(str(output_dir))
if hasattr(model, "config"): if hasattr(model, "config"):
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir)) model.config.save_pretrained(str(output_dir))
@@ -350,14 +350,14 @@ def setup_model_card(cfg: DictDefault):
Set up the Axolotl badge and add the Axolotl config to the model card if available. Set up the Axolotl badge and add the Axolotl config to the model card if available.
Args: Args:
cfg: The configuration dictionary with path to axolotl config file cfg: Dictionary mapping `axolotl` config keys to values.
""" """
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)""" badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
if getattr(cfg, "axolotl_config_path"): if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path) raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = get_distribution("axolotl").version version = importlib.metadata.version("axolotl")
if raw_axolotl_cfg.is_file(): if raw_axolotl_cfg.is_file():
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n" transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
@@ -366,26 +366,26 @@ def handle_untrained_tokens_fix(
cfg: DictDefault, cfg: DictDefault,
model: PreTrainedModel, model: PreTrainedModel,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
train_dataset: Any, train_dataset: Dataset,
safe_serialization: bool, safe_serialization: bool,
): ):
""" """
Apply fixes for untrained tokens if configured. Apply fixes for untrained tokens if configured.
Args: Args:
cfg: The configuration dictionary cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to apply fixes to model: The model to apply fixes to.
tokenizer: The tokenizer for token identification tokenizer: The tokenizer for token identification.
train_dataset: The training dataset to analyze train_dataset: The training dataset to use.
safe_serialization: Whether to use safe serialization when saving safe_serialization: Whether to use safe serialization when saving.
""" """
if not cfg.fix_untrained_tokens: if not cfg.fix_untrained_tokens:
return return
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args # Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens) sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list # If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance( if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list cfg.fix_untrained_tokens, list
): ):
@@ -404,48 +404,90 @@ def handle_untrained_tokens_fix(
) )
def train( def setup_model_and_trainer(
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: ) -> tuple[
HFRLTrainerBuilder | HFCausalTrainerBuilder,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
]:
""" """
Train a model on the given dataset. Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
trainer setup.
Args: Args:
cfg: The configuration dictionary with training parameters cfg: The configuration dictionary with training parameters.
dataset_meta: Metadata about the training dataset dataset_meta: Object with training, validation datasets and metadata.
Returns: Returns:
Tuple of (model, tokenizer) after training Tuple of:
- Trainer (Causal or RLHF)
- Model
- Tokenizer
- PEFT config
""" """
# Load tokenizer, processor and model # Load tokenizer, processor and model
tokenizer, processor, model, peft_config = setup_model_and_tokenizer(cfg) model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
# Set up reference model for RL if needed # Set up reference model for RL if needed
model_ref = setup_reference_model(cfg, tokenizer) model_ref = setup_reference_model(cfg, tokenizer)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Get datasets from metadata # Get datasets from metadata
train_dataset = dataset_meta.train_dataset train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps total_num_steps = dataset_meta.total_num_steps
# Set up trainer
trainer = setup_trainer(
cfg=cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=model,
tokenizer=tokenizer,
processor=processor,
total_num_steps=total_num_steps,
model_ref=model_ref,
peft_config=peft_config,
)
return (
trainer,
model,
tokenizer,
peft_config,
)
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
"""
Train a model on the given dataset.
Args:
cfg: The configuration dictionary with training parameters
dataset_meta: Object with training, validation datasets and metadata
Returns:
Tuple of (model, tokenizer) after training
"""
# Setup model, tokenizer, (causal or RLHF) trainer etc.
(
trainer,
model,
tokenizer,
peft_config,
) = setup_model_and_trainer(cfg, dataset_meta)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving # Configuration for saving
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
# Set up trainer
trainer = setup_trainer(
cfg,
train_dataset,
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)
# Handle untrained tokens if configured # Handle untrained tokens if configured
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix( handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization cfg, model, tokenizer, train_dataset, safe_serialization
) )

View File

@@ -574,14 +574,40 @@ def prepare_opinionated_env(cfg):
def setup_trainer( def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps cfg,
train_dataset,
eval_dataset,
model,
tokenizer,
processor,
total_num_steps,
model_ref=None,
peft_config=None,
): ):
"""
Helper method for instantiating and building a (causal or RLHF) trainer.
Args:
cfg: Axolotl config object containing training parameters.
train_dataset: Dataset to use for training.
eval_dataset: Dataset to use for evaluation.
model: The model to train.
tokenizer: Tokenizer for processing text input.
processor: Processor for data preparation.
total_num_steps: The total number of training steps.
model_ref: Optional reference model for RLHF training. Default is None.
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
Returns:
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
if cfg.rl: if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model[1] trainer_builder.model_ref = model_ref
trainer_builder.peft_config = model[2] trainer_builder.peft_config = peft_config
else: else:
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.train_dataset = train_dataset trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset trainer_builder.eval_dataset = eval_dataset