diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py
index db07eb43b..3e712f772 100644
--- a/src/axolotl/common/datasets.py
+++ b/src/axolotl/common/datasets.py
@@ -24,8 +24,8 @@ class TrainDatasetMeta:
     """Dataclass with fields for training and validation datasets and metadata."""
 
     train_dataset: Dataset
-    eval_dataset: Optional[Dataset] = None
-    total_num_steps: Optional[int] = None
+    eval_dataset: Dataset | None = None
+    total_num_steps: int | None = None
 
 
 def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 12346b8a2..d4ddc9bf3 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -91,13 +91,11 @@ try:
 except ImportError:
     pass
 
-LOG = logging.getLogger("axolotl.core.trainer_builder")
+LOG = logging.getLogger(__name__)
 
 
 class TrainerBuilderBase(abc.ABC):
-    """
-    Base class for trainer builder
-    """
+    """Base class for trainer builder."""
 
     _train_dataset = None
     _eval_dataset = None
@@ -110,9 +108,9 @@ class TrainerBuilderBase(abc.ABC):
         self.tokenizer = tokenizer
         self.processor = processor
 
-        # in case the model supports tagging, add the axolotl tag.
+        # If the model supports tagging, add the axolotl tag.
         # This makes sure the tag is correctly pushed even if a user calls
-        # model.push_to_hub instad of trainer.push_to_hub.
+        # model.push_to_hub instead of trainer.push_to_hub.
         if hasattr(model, "add_model_tags"):
             model.add_model_tags(["axolotl"])
 
@@ -227,8 +225,8 @@ class TrainerBuilderBase(abc.ABC):
 
 class HFCausalTrainerBuilder(TrainerBuilderBase):
     """
-    Build the HuggingFace training args/trainer for causal models
-    and reward modelling using TRL.
+    Build the HuggingFace training args/trainer for causal models and reward modeling
+    using TRL.
     """
 
     def get_callbacks(self):
@@ -872,9 +870,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
 
 class HFRLTrainerBuilder(TrainerBuilderBase):
-    """
-    Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
-    """
+    """Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
 
     def get_callbacks(self):
         callbacks = super().get_callbacks()
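A quick aside on the two mechanical changes above, since they recur throughout this PR. This is a minimal sketch (not part of the diff) showing that both are behavior-preserving on Python 3.10+: PEP 604 unions are the same runtime objects as their `typing.Optional` spellings, and `logging.getLogger(__name__)` resolves to the module's dotted import path, which is exactly the string the old call hard-coded:

```python
import logging
from typing import Optional

# PEP 604 `X | None` compares equal to Optional[X] at runtime (Python >= 3.10).
assert (int | None) == Optional[int]

# getLogger(__name__) derives the logger name from the import path, so inside
# src/axolotl/core/trainer_builder.py it still names the logger
# "axolotl.core.trainer_builder", with no hard-coded string to fall out of sync.
LOG = logging.getLogger(__name__)
```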
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 515248fce..b2f4bf1e9 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -1,26 +1,29 @@
 """Prepare and train a model on a dataset.
 Can also infer from a model or merge lora"""
+import importlib
 import inspect
 import os
 import signal
 import sys
 import weakref
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Any
 
 import torch
 import transformers.modelcard
 from accelerate.logging import get_logger
 from accelerate.utils import save_fsdp_model
-from peft import PeftModel
-from pkg_resources import get_distribution  # type: ignore
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from datasets import Dataset
+from peft import PeftConfig, PeftModel
+from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.trainer import Trainer
 
 from axolotl.common.datasets import TrainDatasetMeta
 from axolotl.contribs.lgpl.unsloth import (  # pylint: disable = no-name-in-module
     fix_untrained_tokens,
 )
+from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.freeze import freeze_layers_except
@@ -32,17 +35,25 @@ try:
 except ImportError:
     BetterTransformer = None
 
-project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-src_dir = os.path.join(project_root, "src")
-sys.path.insert(0, src_dir)
-
 configure_logging()
 LOG = get_logger(__name__)
 
 
-def train(
-    *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
-) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
+def setup_model_and_tokenizer(
+    cfg: DictDefault,
+) -> tuple[
+    PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
+]:
+    """
+    Load the tokenizer, processor (for multimodal models), and model based on configuration.
+
+    Args:
+        cfg: Dictionary mapping `axolotl` config keys to values.
+
+    Returns:
+        Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
+        `None`), and processor (if multimodal, else `None`).
+    """
     # Load tokenizer
     LOG.debug(
         f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -55,11 +66,58 @@ def train(
     if cfg.is_multimodal:
         processor = load_processor(cfg, tokenizer)
 
-    # Get datasets
-    train_dataset = dataset_meta.train_dataset
-    eval_dataset = dataset_meta.eval_dataset
-    total_num_steps = dataset_meta.total_num_steps
+    # Load the model and peft_config
+    msg = "loading model"
+    if cfg.adapter:
+        msg += " and peft_config..."
+    LOG.debug(msg)
+    model, peft_config = load_model(cfg, tokenizer, processor=processor)
+    if model.generation_config is not None:
+        model.generation_config.do_sample = True
+
+    # Apply freezing if specified
+    if cfg.unfrozen_parameters:
+        freeze_layers_except(model, cfg.unfrozen_parameters)
+
+    return model, tokenizer, peft_config, processor
+
+
+def setup_reference_model(
+    cfg: DictDefault, tokenizer: PreTrainedTokenizer
+) -> PreTrainedModel | None:
+    """
+    Set up the reference model for RL training if needed.
+
+    Args:
+        cfg: Dictionary mapping `axolotl` config keys to values.
+        tokenizer: The tokenizer to use for the reference model.
+
+    Returns:
+        Reference model if needed for RL training, `None` otherwise.
+    """
+    model_ref = None
+    if cfg.rl and cfg.rl != "orpo":
+        if cfg.adapter and not cfg.rl_adapter_ref_model:
+            # use built-in trl autounwrap
+            LOG.debug("Passing model_ref: None to RL trainer")
+            model_ref = None  # explicit setting to None
+        else:
+            # load the model again for model_ref/baseline
+            model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
+    return model_ref
+
+
+def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
+    """
+    Determine the checkpoint to resume from based on configuration.
+
+    Args:
+        cfg: Dictionary mapping `axolotl` config keys to values.
+
+    Returns:
+        Path to the checkpoint to resume from, or `None` if not resuming.
+    """
     if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
         possible_checkpoints = [
             str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
         ]
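The new `import importlib` pairs with dropping `pkg_resources` (deprecated in recent setuptools) in favor of the stdlib metadata API; the lookup itself happens in `setup_model_card` further down. A minimal sketch of the migration, assuming `axolotl` is installed. Note the sketch imports the `importlib.metadata` submodule explicitly: a bare `import importlib` only exposes `importlib.metadata` if some other module has already imported it.

```python
import importlib.metadata

# Old (pkg_resources, deprecated):
#   from pkg_resources import get_distribution
#   version = get_distribution("axolotl").version
# New (stdlib since Python 3.8):
version = importlib.metadata.version("axolotl")
# Raises importlib.metadata.PackageNotFoundError when the package is absent.
print(version)
```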
+ """ + model_ref = None + if cfg.rl and cfg.rl != "orpo": + if cfg.adapter and not cfg.rl_adapter_ref_model: + # use built-in trl autounwrap + LOG.debug("Passing model_ref: None to RL trainer") + model_ref = None # explicit setting to None + else: + # load the model again for model_ref/baseline + model_ref, _ = load_model(cfg, tokenizer, reference_model=True) + return model_ref + + +def determine_resume_checkpoint(cfg: DictDefault) -> str | None: + """ + Determine the checkpoint to resume from based on configuration. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + + Returns: + Path to the checkpoint to resume from, or `None` if not resuming. + """ if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: possible_checkpoints = [ str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") @@ -73,77 +131,22 @@ def train( LOG.info( f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" ) - resume_from_checkpoint = cfg.resume_from_checkpoint + return cfg.resume_from_checkpoint - # Load the model and tokenizer - msg = "loading model" - if cfg.adapter: - msg += " and peft_config..." - LOG.debug(msg) - model, peft_config = load_model(cfg, tokenizer, processor=processor) - if model.generation_config is not None: - model.generation_config.do_sample = True - model_ref = None - if cfg.rl and cfg.rl != "orpo": - if cfg.adapter and not cfg.rl_adapter_ref_model: - # use built-in trl autounwrap - LOG.debug("Passing model_ref: None to RL trainer") - model_ref = None # explicit setting to None - else: - # load the model again for model_ref/baseline - model_ref, _ = load_model(cfg, tokenizer, reference_model=True) +def setup_signal_handler( + cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool +): + """ + Set up signal handler for graceful termination. - safe_serialization = cfg.save_safetensors is True - - if cfg.unfrozen_parameters: - freeze_layers_except(model, cfg.unfrozen_parameters) - - trainer = setup_trainer( - cfg, - train_dataset, - eval_dataset, - (model, model_ref, peft_config), - tokenizer, - processor, - total_num_steps, - ) - - if cfg.fix_untrained_tokens: - # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args - sig = inspect.signature(fix_untrained_tokens) - # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list - if "token_ids_to_fix" in sig.parameters and isinstance( - cfg.fix_untrained_tokens, list - ): - fix_untrained_tokens( - model, - tokenizer, - train_dataset, - token_ids_to_fix=cfg.fix_untrained_tokens, - ) - else: - fix_untrained_tokens(model, tokenizer, train_dataset) - if cfg.local_rank == 0: - model.save_pretrained( - str(Path(cfg.output_dir)), safe_serialization=safe_serialization - ) - - # go ahead and presave, so we have the adapter config available to inspect - if peft_config: - LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") - peft_config.save_pretrained(cfg.output_dir) - # additionally presave the tokenizer and model configs - if not Path(cfg.output_dir).is_dir(): - os.makedirs(cfg.output_dir, exist_ok=True) - tokenizer.save_pretrained(str(Path(cfg.output_dir))) - if hasattr(model, "config"): - model.config.save_pretrained(str(Path(cfg.output_dir))) - - # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model - if ( - cfg.local_rank == 0 and not cfg.use_ray - ): # ray workers don't have access to this signal + Args: + cfg: Dictionary mapping `axolotl` config keys to values. 
@@ -161,21 +164,22 @@ def train(
         lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
     )
 
-    badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
-    transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
-    if getattr(cfg, "axolotl_config_path"):
-        raw_axolotl_cfg = Path(cfg.axolotl_config_path)
-        version = get_distribution("axolotl").version
-        if raw_axolotl_cfg.is_file():
-            transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
 
+def execute_training(
+    cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
+):
+    """
+    Execute the training process with appropriate backend configurations.
 
+    Args:
+        cfg: Dictionary mapping `axolotl` config keys to values.
+        trainer: The configured trainer object.
+        resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
+    """
     LOG.info("Starting trainer...")
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")
 
-    pretrain_hooks(cfg, trainer)
-
     if cfg.flash_optimum:
         with torch.backends.cuda.sdp_kernel(
             # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -187,15 +191,30 @@
     else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
-    post_train_hooks(cfg, trainer)
-
-    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+
+def save_trained_model(
+    cfg: DictDefault,
+    trainer: Any,
+    model: PreTrainedModel,
+    safe_serialization: bool,
+):
+    """
+    Save the trained model according to configuration and training setup.
 
-    # post training
+    Args:
+        cfg: Dictionary mapping `axolotl` config keys to values.
+        trainer: The trainer object.
+        model: The trained model to save.
+        safe_serialization: Whether to use safe serialization.
+    """
+    LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
+
+    # Post training module hooks
     for name, module in model.named_modules():
         if hasattr(module, "_post_training"):
             module._post_training(model, name)  # pylint: disable=protected-access
 
+    # Handle FSDP state dict type
     state_dict_type = "FULL_STATE_DICT"
     if trainer.is_fsdp_enabled:
         if cfg.fsdp_final_state_dict_type:
@@ -203,16 +222,18 @@
         trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
         LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
 
+    # Handle ReLoRA early return case
     if cfg.relora_steps:
         if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
             model = model.merge_and_unload()
         else:
             # final model weights have already been saved by `ReLoRACallback.on_train_end`
-            return model, tokenizer
+            return
 
-    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-    # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
+        # TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
+        # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple
+        # processes attempt to write the same file
        if (
             state_dict_type == "SHARDED_STATE_DICT"
             and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
@@ -244,7 +265,6 @@
                 os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
             except FileNotFoundError:
                 pass
-
     elif cfg.local_rank == 0:
         if cfg.flash_optimum and BetterTransformer:
             model = BetterTransformer.reverse(model)
+ """ if not cfg.hub_model_id: + # Guard since create_model_card may fail if dataset_tags is empty list try: model_card_kwarg = { "model_name": cfg.output_dir.lstrip("./") .encode("utf-8") .decode("utf-8") } - if cfg.datasets is not None: - if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model: - dataset_tags = [ - d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() - ] - dataset_tags = [ - d for d in dataset_tags if not d.startswith("https://") - ] - if dataset_tags: - # guard as create_model_card may fail if dataset_tags is empty list - model_card_kwarg["dataset_name"] = dataset_tags - else: - dataset_tags = [ - d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() - ] - dataset_tags = [ - d for d in dataset_tags if not d.startswith("https://") - ] - if dataset_tags: - # guard as create_model_card may fail if dataset_tags is empty list - model_card_kwarg["dataset_tags"] = dataset_tags + + # We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed. + rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model + if cfg.datasets is not None and not rl: + dataset_tags = [ + d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() + ] + dataset_tags = [d for d in dataset_tags if not d.startswith("https://")] + + if dataset_tags: + model_card_kwarg["dataset_tags"] = dataset_tags trainer.create_model_card(**model_card_kwarg) except (AttributeError, UnicodeDecodeError): pass elif cfg.hub_model_id: - # defensively push to the hub to ensure the model card is updated + # Defensively push to the hub to ensure the model card is updated trainer.push_to_hub() + +def save_initial_configs( + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + model: PreTrainedModel, + peft_config: PeftConfig | None, +): + """ + Save initial configurations before training. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + tokenizer: The tokenizer to save. + model: The model to save configuration for. + peft_config: The PEFT configuration to save if applicable. + """ + # Create output_dir if it doesn't already exist + output_dir = Path(cfg.output_dir) + if not output_dir.is_dir(): + os.makedirs(cfg.output_dir, exist_ok=True) + + # Pre-save adapter config so it's available to inspect + if peft_config: + LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...") + peft_config.save_pretrained(cfg.output_dir) + + # Pre-save the tokenizer and model configs + LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...") + tokenizer.save_pretrained(str(output_dir)) + if hasattr(model, "config"): + LOG.info(f"Pre-saving model config to {cfg.output_dir}...") + model.config.save_pretrained(str(output_dir)) + + +def setup_model_card(cfg: DictDefault): + """ + Set up the Axolotl badge and add the Axolotl config to the model card if available. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + """ + badge_markdown = """[Built with Axolotl](https://github.com/axolotl-ai-cloud/axolotl)""" + transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" + + if getattr(cfg, "axolotl_config_path"): + raw_axolotl_cfg = Path(cfg.axolotl_config_path) + version = importlib.metadata.version("axolotl") + if raw_axolotl_cfg.is_file(): + transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n
See axolotl config\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n

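The tag filtering inside `create_model_card` above is worth a standalone look. This sketch uses made-up dataset entries and shows why both filters exist: existing local directories and raw URLs are not valid Hub dataset ids, so they cannot be used as card tags:

```python
import tempfile
from pathlib import Path

local_dir = tempfile.mkdtemp()  # stands in for a local dataset directory
datasets = [
    {"path": "tatsu-lab/alpaca"},  # Hub-style dataset id: kept
    {"path": local_dir},  # existing local directory: dropped
    {"path": "https://example.com/data.jsonl"},  # raw URL: dropped
]

dataset_tags = [d["path"] for d in datasets if not Path(d["path"]).is_dir()]
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
assert dataset_tags == ["tatsu-lab/alpaca"]
```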
\n" + + +def handle_untrained_tokens_fix( + cfg: DictDefault, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + train_dataset: Dataset, + safe_serialization: bool, +): + """ + Apply fixes for untrained tokens if configured. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + model: The model to apply fixes to. + tokenizer: The tokenizer for token identification. + train_dataset: The training dataset to use. + safe_serialization: Whether to use safe serialization when saving. + """ + if not cfg.fix_untrained_tokens: + return + + # Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args + sig = inspect.signature(fix_untrained_tokens) + + # If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list + if "token_ids_to_fix" in sig.parameters and isinstance( + cfg.fix_untrained_tokens, list + ): + fix_untrained_tokens( + model, + tokenizer, + train_dataset, + token_ids_to_fix=cfg.fix_untrained_tokens, + ) + else: + fix_untrained_tokens(model, tokenizer, train_dataset) + + if cfg.local_rank == 0: + model.save_pretrained( + str(Path(cfg.output_dir)), safe_serialization=safe_serialization + ) + + +def setup_model_and_trainer( + cfg: DictDefault, dataset_meta: TrainDatasetMeta +) -> tuple[ + HFRLTrainerBuilder | HFCausalTrainerBuilder, + PeftModel | PreTrainedModel, + PreTrainedTokenizer, + PeftConfig | None, +]: + """ + Load model, tokenizer, trainer, etc. Helper function to encapsulate the full + trainer setup. + + Args: + cfg: The configuration dictionary with training parameters. + dataset_meta: Object with training, validation datasets and metadata. + + Returns: + Tuple of: + - Trainer (Causal or RLHF) + - Model + - Tokenizer + - PEFT config + """ + # Load tokenizer, processor and model + model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg) + + # Set up reference model for RL if needed + model_ref = setup_reference_model(cfg, tokenizer) + + # Get datasets from metadata + train_dataset = dataset_meta.train_dataset + eval_dataset = dataset_meta.eval_dataset + total_num_steps = dataset_meta.total_num_steps + + # Set up trainer + trainer = setup_trainer( + cfg=cfg, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + model=model, + tokenizer=tokenizer, + processor=processor, + total_num_steps=total_num_steps, + model_ref=model_ref, + peft_config=peft_config, + ) + + return ( + trainer, + model, + tokenizer, + peft_config, + ) + + +def train( + cfg: DictDefault, dataset_meta: TrainDatasetMeta +) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]: + """ + Train a model on the given dataset. + + Args: + cfg: The configuration dictionary with training parameters + dataset_meta: Object with training, validation datasets and metadata + + Returns: + Tuple of (model, tokenizer) after training + """ + # Setup model, tokenizer, (causal or RLHF) trainer etc. 
+
+
+def train(
+    cfg: DictDefault, dataset_meta: TrainDatasetMeta
+) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
+    """
+    Train a model on the given dataset.
+
+    Args:
+        cfg: The configuration dictionary with training parameters
+        dataset_meta: Object with training, validation datasets and metadata
+
+    Returns:
+        Tuple of (model, tokenizer) after training
+    """
+    # Setup model, tokenizer, (causal or RLHF) trainer etc.
+    (
+        trainer,
+        model,
+        tokenizer,
+        peft_config,
+    ) = setup_model_and_trainer(cfg, dataset_meta)
+
+    # Determine if we need to resume from a checkpoint
+    resume_from_checkpoint = determine_resume_checkpoint(cfg)
+
+    # Configuration for saving
+    safe_serialization = cfg.save_safetensors is True
+
+    # Handle untrained tokens if configured
+    train_dataset = dataset_meta.train_dataset
+    handle_untrained_tokens_fix(
+        cfg, model, tokenizer, train_dataset, safe_serialization
+    )
+
+    # Save initial configs
+    save_initial_configs(cfg, tokenizer, model, peft_config)
+
+    # Set up signal handler for graceful termination
+    setup_signal_handler(cfg, model, safe_serialization)
+
+    # Set up badges and config info for model card
+    setup_model_card(cfg)
+
+    # Execute the training
+    execute_training(cfg, trainer, resume_from_checkpoint)
+
+    # Save the trained model
+    save_trained_model(cfg, trainer, model, safe_serialization)
+
+    # Create model card
+    create_model_card(cfg, trainer)
+
     return model, tokenizer
-
-
-def pretrain_hooks(_cfg, _trainer):
-    """
-    Run hooks right before kicking off the training
-    :param cfg:
-    :param trainer:
-    :return:
-    """
-
-
-def post_train_hooks(_cfg, _trainer):
-    """
-    Run hooks right after training completes
-    :param cfg:
-    :param trainer:
-    :return:
-    """
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 8553339b9..8cee3d124 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -574,14 +574,40 @@ def prepare_opinionated_env(cfg):
 
 
 def setup_trainer(
-    cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
+    cfg,
+    train_dataset,
+    eval_dataset,
+    model,
+    tokenizer,
+    processor,
+    total_num_steps,
+    model_ref=None,
+    peft_config=None,
 ):
+    """
+    Helper method for instantiating and building a (causal or RLHF) trainer.
+
+    Args:
+        cfg: Axolotl config object containing training parameters.
+        train_dataset: Dataset to use for training.
+        eval_dataset: Dataset to use for evaluation.
+        model: The model to train.
+        tokenizer: Tokenizer for processing text input.
+        processor: Processor for data preparation.
+        total_num_steps: The total number of training steps.
+        model_ref: Optional reference model for RLHF training. Default is None.
+        peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
+
+    Returns:
+        A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
+        on the provided parameters.
+    """
     if cfg.rl:
-        trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
-        trainer_builder.model_ref = model[1]
-        trainer_builder.peft_config = model[2]
+        trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
+        trainer_builder.model_ref = model_ref
+        trainer_builder.peft_config = peft_config
     else:
-        trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
+        trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
 
     trainer_builder.train_dataset = train_dataset
     trainer_builder.eval_dataset = eval_dataset
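Finally, the two public API changes callers see after this PR, summarized as call-site sketches (object construction elided, names illustrative): `train` loses its keyword-only marker, and `setup_trainer` replaces the packed 3-tuple with explicit keyword arguments:

```python
# train: previously keyword-only
#   model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
# now also callable positionally:
#   model, tokenizer = train(cfg, dataset_meta)

# setup_trainer: the (model, model_ref, peft_config) tuple is unpacked into
# explicit keywords with None defaults, so causal (non-RL) callers can simply
# omit model_ref / peft_config:
#   trainer = setup_trainer(
#       cfg=cfg,
#       train_dataset=train_dataset,
#       eval_dataset=eval_dataset,
#       model=model,              # was model[0]
#       tokenizer=tokenizer,
#       processor=processor,
#       total_num_steps=total_num_steps,
#       model_ref=model_ref,      # was model[1]; optional now
#       peft_config=peft_config,  # was model[2]; optional now
#   )
```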