From 5c0510a876dcb3ee2835748d8851e3f4eceba1b4 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 3 Mar 2025 18:44:16 +0000 Subject: [PATCH] review comments --- src/axolotl/common/datasets.py | 4 +- src/axolotl/core/trainer_builder.py | 14 +- src/axolotl/train.py | 196 +++++++++++++++++----------- src/axolotl/utils/trainer.py | 36 ++++- 4 files changed, 157 insertions(+), 93 deletions(-) diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index db07eb43b..3e712f772 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -24,8 +24,8 @@ class TrainDatasetMeta: """Dataclass with fields for training and validation datasets and metadata.""" train_dataset: Dataset - eval_dataset: Optional[Dataset] = None - total_num_steps: Optional[int] = None + eval_dataset: Dataset | None = None + total_num_steps: int | None = None def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 07753876f..d4ddc9bf3 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -91,13 +91,11 @@ try: except ImportError: pass -LOG = logging.getLogger("axolotl.core.trainer_builder") +LOG = logging.getLogger(__name__) class TrainerBuilderBase(abc.ABC): - """ - Base class for trainer builder - """ + """Base class for trainer builder.""" _train_dataset = None _eval_dataset = None @@ -110,9 +108,9 @@ class TrainerBuilderBase(abc.ABC): self.tokenizer = tokenizer self.processor = processor - # in case the model supports tagging, add the axolotl tag. + # If the model supports tagging, add the axolotl tag. # This makes sure the tag is correctly pushed even if a user calls - # model.push_to_hub instad of trainer.push_to_hub. + # model.push_to_hub instead of trainer.push_to_hub. if hasattr(model, "add_model_tags"): model.add_model_tags(["axolotl"]) @@ -872,9 +870,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): class HFRLTrainerBuilder(TrainerBuilderBase): - """ - Trainer factory class for TRL-based RLHF trainers (e.g. DPO) - """ + """Trainer factory class for TRL-based RLHF trainers (e.g. DPO)""" def get_callbacks(self): callbacks = super().get_callbacks() diff --git a/src/axolotl/train.py b/src/axolotl/train.py index d61e200c6..b2f4bf1e9 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -1,19 +1,20 @@ """Prepare and train a model on a dataset. Can also infer from a model or merge lora""" +import importlib import inspect import os import signal import sys import weakref from pathlib import Path -from typing import Any, Tuple, Union +from typing import Any import torch import transformers.modelcard from accelerate.logging import get_logger from accelerate.utils import save_fsdp_model +from datasets import Dataset from peft import PeftConfig, PeftModel -from pkg_resources import get_distribution # type: ignore from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.trainer import Trainer @@ -22,6 +23,7 @@ from axolotl.common.datasets import TrainDatasetMeta from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) +from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except @@ -39,17 +41,18 @@ LOG = get_logger(__name__) def setup_model_and_tokenizer( cfg: DictDefault, -) -> Tuple[ - PreTrainedTokenizer, ProcessorMixin | None, PreTrainedModel, PeftConfig | None +) -> tuple[ + PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None ]: """ Load the tokenizer, processor (for multimodal models), and model based on configuration. Args: - cfg: The configuration dictionary with training parameters + cfg: Dictionary mapping `axolotl` config keys to values. Returns: - Tuple containing tokenizer, processor (if multimodal, else None), model, and peft_config (if applicable, else None) + Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else + `None`), and processor (if multimodal, else `None`). """ # Load tokenizer LOG.debug( @@ -77,7 +80,7 @@ def setup_model_and_tokenizer( if cfg.unfrozen_parameters: freeze_layers_except(model, cfg.unfrozen_parameters) - return tokenizer, processor, model, peft_config + return model, tokenizer, peft_config, processor def setup_reference_model( @@ -87,11 +90,11 @@ def setup_reference_model( Set up the reference model for RL training if needed. Args: - cfg: The configuration dictionary - tokenizer: The tokenizer to use for the reference model + cfg: Dictionary mapping `axolotl` config keys to values. + tokenizer: The tokenizer to use for the reference model. Returns: - Reference model if needed for RL training, None otherwise + Reference model if needed for RL training, `None` otherwise. """ model_ref = None if cfg.rl and cfg.rl != "orpo": @@ -110,10 +113,10 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None: Determine the checkpoint to resume from based on configuration. Args: - cfg: The configuration dictionary + cfg: Dictionary mapping `axolotl` config keys to values. Returns: - Path to the checkpoint to resume from, or None if not resuming + Path to the checkpoint to resume from, or `None` if not resuming. """ if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: possible_checkpoints = [ @@ -138,7 +141,7 @@ def setup_signal_handler( Set up signal handler for graceful termination. Args: - cfg: The configuration dictionary + cfg: Dictionary mapping `axolotl` config keys to values. model: The model to save on termination safe_serialization: Whether to use safe serialization when saving """ @@ -169,9 +172,9 @@ def execute_training( Execute the training process with appropriate backend configurations. Args: - cfg: The configuration dictionary - trainer: The configured trainer object - resume_from_checkpoint: Path to checkpoint to resume from, if applicable + cfg: Dictionary mapping `axolotl` config keys to values. + trainer: The configured trainer object. + resume_from_checkpoint: Path to checkpoint to resume from, if applicable. """ LOG.info("Starting trainer...") if cfg.group_by_length: @@ -199,12 +202,12 @@ def save_trained_model( Save the trained model according to configuration and training setup. Args: - cfg: The configuration dictionary - trainer: The trainer object - model: The trained model to save - safe_serialization: Whether to use safe serialization + cfg: Dictionary mapping `axolotl` config keys to values. + trainer: The trainer object. + model: The trained model to save. + safe_serialization: Whether to use safe serialization. """ - LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") + LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.") # Post training module hooks for name, module in model.named_modules(): @@ -278,8 +281,8 @@ def create_model_card(cfg: DictDefault, trainer: Trainer): Create a model card for the trained model if needed. Args: - cfg: The configuration dictionary - trainer: The trainer object with model card creation capabilities + cfg: Dictionary mapping `axolotl` config keys to values. + trainer: The trainer object with model card creation capabilities. """ if not cfg.hub_model_id: # Guard since create_model_card may fail if dataset_tags is empty list @@ -289,29 +292,23 @@ def create_model_card(cfg: DictDefault, trainer: Trainer): .encode("utf-8") .decode("utf-8") } - if cfg.datasets is not None: + + # We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed. + rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model + if cfg.datasets is not None and not rl: dataset_tags = [ d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() ] dataset_tags = [d for d in dataset_tags if not d.startswith("https://")] if dataset_tags: - param_name = ( - "dataset_name" - if ( - cfg.rl is not None - or cfg.reward_model - or cfg.process_reward_model - ) - else "dataset_tags" - ) - model_card_kwarg[param_name] = dataset_tags + model_card_kwarg["dataset_tags"] = dataset_tags trainer.create_model_card(**model_card_kwarg) except (AttributeError, UnicodeDecodeError): pass elif cfg.hub_model_id: - # defensively push to the hub to ensure the model card is updated + # Defensively push to the hub to ensure the model card is updated trainer.push_to_hub() @@ -325,23 +322,26 @@ def save_initial_configs( Save initial configurations before training. Args: - cfg: The configuration dictionary - tokenizer: The tokenizer to save - model: The model to save configuration for - peft_config: The PEFT configuration to save if applicable + cfg: Dictionary mapping `axolotl` config keys to values. + tokenizer: The tokenizer to save. + model: The model to save configuration for. + peft_config: The PEFT configuration to save if applicable. """ - # go ahead and presave, so we have the adapter config available to inspect - if peft_config: - LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") - peft_config.save_pretrained(cfg.output_dir) - - # additionally presave the tokenizer and model configs + # Create output_dir if it doesn't already exist output_dir = Path(cfg.output_dir) if not output_dir.is_dir(): os.makedirs(cfg.output_dir, exist_ok=True) + # Pre-save adapter config so it's available to inspect + if peft_config: + LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...") + peft_config.save_pretrained(cfg.output_dir) + + # Pre-save the tokenizer and model configs + LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...") tokenizer.save_pretrained(str(output_dir)) if hasattr(model, "config"): + LOG.info(f"Pre-saving model config to {cfg.output_dir}...") model.config.save_pretrained(str(output_dir)) @@ -350,14 +350,14 @@ def setup_model_card(cfg: DictDefault): Set up the Axolotl badge and add the Axolotl config to the model card if available. Args: - cfg: The configuration dictionary with path to axolotl config file + cfg: Dictionary mapping `axolotl` config keys to values. """ badge_markdown = """[Built with Axolotl](https://github.com/axolotl-ai-cloud/axolotl)""" transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" if getattr(cfg, "axolotl_config_path"): raw_axolotl_cfg = Path(cfg.axolotl_config_path) - version = get_distribution("axolotl").version + version = importlib.metadata.version("axolotl") if raw_axolotl_cfg.is_file(): transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n
See axolotl config\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n

\n" @@ -366,26 +366,26 @@ def handle_untrained_tokens_fix( cfg: DictDefault, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, - train_dataset: Any, + train_dataset: Dataset, safe_serialization: bool, ): """ Apply fixes for untrained tokens if configured. Args: - cfg: The configuration dictionary - model: The model to apply fixes to - tokenizer: The tokenizer for token identification - train_dataset: The training dataset to analyze - safe_serialization: Whether to use safe serialization when saving + cfg: Dictionary mapping `axolotl` config keys to values. + model: The model to apply fixes to. + tokenizer: The tokenizer for token identification. + train_dataset: The training dataset to use. + safe_serialization: Whether to use safe serialization when saving. """ if not cfg.fix_untrained_tokens: return - # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args + # Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args sig = inspect.signature(fix_untrained_tokens) - # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list + # If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list if "token_ids_to_fix" in sig.parameters and isinstance( cfg.fix_untrained_tokens, list ): @@ -404,48 +404,90 @@ def handle_untrained_tokens_fix( ) -def train( - *, cfg: DictDefault, dataset_meta: TrainDatasetMeta -) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: +def setup_model_and_trainer( + cfg: DictDefault, dataset_meta: TrainDatasetMeta +) -> tuple[ + HFRLTrainerBuilder | HFCausalTrainerBuilder, + PeftModel | PreTrainedModel, + PreTrainedTokenizer, + PeftConfig | None, +]: """ - Train a model on the given dataset. + Load model, tokenizer, trainer, etc. Helper function to encapsulate the full + trainer setup. Args: - cfg: The configuration dictionary with training parameters - dataset_meta: Metadata about the training dataset + cfg: The configuration dictionary with training parameters. + dataset_meta: Object with training, validation datasets and metadata. Returns: - Tuple of (model, tokenizer) after training + Tuple of: + - Trainer (Causal or RLHF) + - Model + - Tokenizer + - PEFT config """ # Load tokenizer, processor and model - tokenizer, processor, model, peft_config = setup_model_and_tokenizer(cfg) + model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg) # Set up reference model for RL if needed model_ref = setup_reference_model(cfg, tokenizer) - # Determine if we need to resume from a checkpoint - resume_from_checkpoint = determine_resume_checkpoint(cfg) - # Get datasets from metadata train_dataset = dataset_meta.train_dataset eval_dataset = dataset_meta.eval_dataset total_num_steps = dataset_meta.total_num_steps + # Set up trainer + trainer = setup_trainer( + cfg=cfg, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + model=model, + tokenizer=tokenizer, + processor=processor, + total_num_steps=total_num_steps, + model_ref=model_ref, + peft_config=peft_config, + ) + + return ( + trainer, + model, + tokenizer, + peft_config, + ) + + +def train( + cfg: DictDefault, dataset_meta: TrainDatasetMeta +) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]: + """ + Train a model on the given dataset. + + Args: + cfg: The configuration dictionary with training parameters + dataset_meta: Object with training, validation datasets and metadata + + Returns: + Tuple of (model, tokenizer) after training + """ + # Setup model, tokenizer, (causal or RLHF) trainer etc. + ( + trainer, + model, + tokenizer, + peft_config, + ) = setup_model_and_trainer(cfg, dataset_meta) + + # Determine if we need to resume from a checkpoint + resume_from_checkpoint = determine_resume_checkpoint(cfg) + # Configuration for saving safe_serialization = cfg.save_safetensors is True - # Set up trainer - trainer = setup_trainer( - cfg, - train_dataset, - eval_dataset, - (model, model_ref, peft_config), - tokenizer, - processor, - total_num_steps, - ) - # Handle untrained tokens if configured + train_dataset = dataset_meta.train_dataset handle_untrained_tokens_fix( cfg, model, tokenizer, train_dataset, safe_serialization ) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 8553339b9..8cee3d124 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -574,14 +574,40 @@ def prepare_opinionated_env(cfg): def setup_trainer( - cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps + cfg, + train_dataset, + eval_dataset, + model, + tokenizer, + processor, + total_num_steps, + model_ref=None, + peft_config=None, ): + """ + Helper method for instantiating and building a (causal or RLHF) trainer. + + Args: + cfg: Axolotl config object containing training parameters. + train_dataset: Dataset to use for training. + eval_dataset: Dataset to use for evaluation. + model: The model to train. + tokenizer: Tokenizer for processing text input. + processor: Processor for data preparation. + total_num_steps: The total number of training steps. + model_ref: Optional reference model for RLHF training. Default is None. + peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None. + + Returns: + A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based + on the provided parameters. + """ if cfg.rl: - trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor) - trainer_builder.model_ref = model[1] - trainer_builder.peft_config = model[2] + trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor) + trainer_builder.model_ref = model_ref + trainer_builder.peft_config = peft_config else: - trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor) + trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor) trainer_builder.train_dataset = train_dataset trainer_builder.eval_dataset = eval_dataset