diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 515248fce..eb2ced231 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -6,15 +6,15 @@ import signal import sys import weakref from pathlib import Path -from typing import Tuple, Union +from typing import Any, Tuple, Union import torch import transformers.modelcard from accelerate.logging import get_logger from accelerate.utils import save_fsdp_model -from peft import PeftModel +from peft import PeftConfig, PeftModel from pkg_resources import get_distribution # type: ignore -from transformers import PreTrainedModel, PreTrainedTokenizer +from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from axolotl.common.datasets import TrainDatasetMeta @@ -27,11 +27,13 @@ from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.trainer import setup_trainer +# Optional imports with graceful fallbacks try: from optimum.bettertransformer import BetterTransformer except ImportError: BetterTransformer = None +# Project setup project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) @@ -40,9 +42,20 @@ configure_logging() LOG = get_logger(__name__) -def train( - *, cfg: DictDefault, dataset_meta: TrainDatasetMeta -) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: +def setup_model_and_tokenizer( + cfg: DictDefault, +) -> Tuple[ + PreTrainedTokenizer, ProcessorMixin | None, PreTrainedModel, PeftConfig | None +]: + """ + Load the tokenizer, processor (for multimodal models), and model based on configuration. + + Args: + cfg: The configuration dictionary with training parameters + + Returns: + Tuple containing tokenizer, processor (if multimodal, else None), model, and peft_config (if applicable, else None) + """ # Load tokenizer LOG.debug( f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", @@ -55,11 +68,58 @@ def train( if cfg.is_multimodal: processor = load_processor(cfg, tokenizer) - # Get datasets - train_dataset = dataset_meta.train_dataset - eval_dataset = dataset_meta.eval_dataset - total_num_steps = dataset_meta.total_num_steps + # Load the model and peft_config + msg = "loading model" + if cfg.adapter: + msg += " and peft_config..." + LOG.debug(msg) + model, peft_config = load_model(cfg, tokenizer, processor=processor) + if model.generation_config is not None: + model.generation_config.do_sample = True + + # Apply freezing if specified + if cfg.unfrozen_parameters: + freeze_layers_except(model, cfg.unfrozen_parameters) + + return tokenizer, processor, model, peft_config + + +def setup_reference_model( + cfg: DictDefault, tokenizer: PreTrainedTokenizer +) -> PreTrainedModel | None: + """ + Set up the reference model for RL training if needed. + + Args: + cfg: The configuration dictionary + tokenizer: The tokenizer to use for the reference model + + Returns: + Reference model if needed for RL training, None otherwise + """ + model_ref = None + if cfg.rl and cfg.rl != "orpo": + if cfg.adapter and not cfg.rl_adapter_ref_model: + # use built-in trl autounwrap + LOG.debug("Passing model_ref: None to RL trainer") + model_ref = None # explicit setting to None + else: + # load the model again for model_ref/baseline + model_ref, _ = load_model(cfg, tokenizer, reference_model=True) + return model_ref + + +def determine_resume_checkpoint(cfg: DictDefault) -> str | None: + """ + Determine the checkpoint to resume from based on configuration. + + Args: + cfg: The configuration dictionary + + Returns: + Path to the checkpoint to resume from, or None if not resuming + """ if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: possible_checkpoints = [ str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") @@ -73,77 +133,22 @@ def train( LOG.info( f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" ) - resume_from_checkpoint = cfg.resume_from_checkpoint + return cfg.resume_from_checkpoint - # Load the model and tokenizer - msg = "loading model" - if cfg.adapter: - msg += " and peft_config..." - LOG.debug(msg) - model, peft_config = load_model(cfg, tokenizer, processor=processor) - if model.generation_config is not None: - model.generation_config.do_sample = True - model_ref = None - if cfg.rl and cfg.rl != "orpo": - if cfg.adapter and not cfg.rl_adapter_ref_model: - # use built-in trl autounwrap - LOG.debug("Passing model_ref: None to RL trainer") - model_ref = None # explicit setting to None - else: - # load the model again for model_ref/baseline - model_ref, _ = load_model(cfg, tokenizer, reference_model=True) +def setup_signal_handler( + cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool +): + """ + Set up signal handler for graceful termination. - safe_serialization = cfg.save_safetensors is True - - if cfg.unfrozen_parameters: - freeze_layers_except(model, cfg.unfrozen_parameters) - - trainer = setup_trainer( - cfg, - train_dataset, - eval_dataset, - (model, model_ref, peft_config), - tokenizer, - processor, - total_num_steps, - ) - - if cfg.fix_untrained_tokens: - # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args - sig = inspect.signature(fix_untrained_tokens) - # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list - if "token_ids_to_fix" in sig.parameters and isinstance( - cfg.fix_untrained_tokens, list - ): - fix_untrained_tokens( - model, - tokenizer, - train_dataset, - token_ids_to_fix=cfg.fix_untrained_tokens, - ) - else: - fix_untrained_tokens(model, tokenizer, train_dataset) - if cfg.local_rank == 0: - model.save_pretrained( - str(Path(cfg.output_dir)), safe_serialization=safe_serialization - ) - - # go ahead and presave, so we have the adapter config available to inspect - if peft_config: - LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") - peft_config.save_pretrained(cfg.output_dir) - # additionally presave the tokenizer and model configs - if not Path(cfg.output_dir).is_dir(): - os.makedirs(cfg.output_dir, exist_ok=True) - tokenizer.save_pretrained(str(Path(cfg.output_dir))) - if hasattr(model, "config"): - model.config.save_pretrained(str(Path(cfg.output_dir))) - - # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model - if ( - cfg.local_rank == 0 and not cfg.use_ray - ): # ray workers don't have access to this signal + Args: + cfg: The configuration dictionary + model: The model to save on termination + safe_serialization: Whether to use safe serialization when saving + """ + # ray workers don't have access to this signal + if cfg.local_rank == 0 and not cfg.use_ray: def terminate_handler(_, __, model_weakref): if model_weakref() is not None: @@ -161,21 +166,22 @@ def train( lambda signum, frame: terminate_handler(signum, frame, _model_weakref), ) - badge_markdown = """[Built with Axolotl](https://github.com/axolotl-ai-cloud/axolotl)""" - transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" - if getattr(cfg, "axolotl_config_path"): - raw_axolotl_cfg = Path(cfg.axolotl_config_path) - version = get_distribution("axolotl").version - if raw_axolotl_cfg.is_file(): - transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n
See axolotl config\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n

\n" +def execute_training( + cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None +): + """ + Execute the training process with appropriate backend configurations. + Args: + cfg: The configuration dictionary + trainer: The configured trainer object + resume_from_checkpoint: Path to checkpoint to resume from, if applicable + """ LOG.info("Starting trainer...") if cfg.group_by_length: LOG.info("hang tight... sorting dataset for group_by_length") - pretrain_hooks(cfg, trainer) - if cfg.flash_optimum: with torch.backends.cuda.sdp_kernel( # TODO configure these from the YAML w/ sdp_kernel_kwargs: ... @@ -187,15 +193,30 @@ def train( else: trainer.train(resume_from_checkpoint=resume_from_checkpoint) - post_train_hooks(cfg, trainer) +def save_trained_model( + cfg: DictDefault, + trainer: Any, + model: PreTrainedModel, + safe_serialization: bool, +): + """ + Save the trained model according to configuration and training setup. + + Args: + cfg: The configuration dictionary + trainer: The trainer object + model: The trained model to save + safe_serialization: Whether to use safe serialization + """ LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") - # post training + # Post training module hooks for name, module in model.named_modules(): if hasattr(module, "_post_training"): module._post_training(model, name) # pylint: disable=protected-access + # Handle FSDP state dict type state_dict_type = "FULL_STATE_DICT" if trainer.is_fsdp_enabled: if cfg.fsdp_final_state_dict_type: @@ -203,12 +224,13 @@ def train( trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type) LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.") + # Handle ReLoRA early return case if cfg.relora_steps: if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): model = model.merge_and_unload() else: # final model weights have already been saved by `ReLoRACallback.on_train_end` - return model, tokenizer + return # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file @@ -255,6 +277,15 @@ def train( ) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) + +def create_model_card(cfg: DictDefault, trainer: Any): + """ + Create a model card for the trained model if needed. + + Args: + cfg: The configuration dictionary + trainer: The trainer object with model card creation capabilities + """ if not cfg.hub_model_id: try: model_card_kwarg = { @@ -291,22 +322,162 @@ def train( # defensively push to the hub to ensure the model card is updated trainer.push_to_hub() + +def save_initial_configs( + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + model: PreTrainedModel, + peft_config: PeftConfig | None, +): + """ + Save initial configurations before training. + + Args: + cfg: The configuration dictionary + tokenizer: The tokenizer to save + model: The model to save configuration for + peft_config: The PEFT configuration to save if applicable + """ + # go ahead and presave, so we have the adapter config available to inspect + if peft_config: + LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") + peft_config.save_pretrained(cfg.output_dir) + + # additionally presave the tokenizer and model configs + output_dir = Path(cfg.output_dir) + if not output_dir.is_dir(): + os.makedirs(cfg.output_dir, exist_ok=True) + + tokenizer.save_pretrained(str(output_dir)) + if hasattr(model, "config"): + model.config.save_pretrained(str(output_dir)) + + +def setup_badge_for_model_card(): + """Set up the Axolotl badge for the model card.""" + badge_markdown = """[Built with Axolotl](https://github.com/axolotl-ai-cloud/axolotl)""" + transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" + + +def add_config_to_model_card(cfg: DictDefault): + """ + Add the Axolotl configuration to the model card if available. + + Args: + cfg: The configuration dictionary with path to axolotl config file + """ + if getattr(cfg, "axolotl_config_path"): + raw_axolotl_cfg = Path(cfg.axolotl_config_path) + version = get_distribution("axolotl").version + if raw_axolotl_cfg.is_file(): + transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n
See axolotl config\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n

\n" + + +def handle_untrained_tokens_fix( + cfg: DictDefault, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + train_dataset: Any, + safe_serialization: bool, +): + """ + Apply fixes for untrained tokens if configured. + + Args: + cfg: The configuration dictionary + model: The model to apply fixes to + tokenizer: The tokenizer for token identification + train_dataset: The training dataset to analyze + safe_serialization: Whether to use safe serialization when saving + """ + if not cfg.fix_untrained_tokens: + return + + # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args + sig = inspect.signature(fix_untrained_tokens) + + # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list + if "token_ids_to_fix" in sig.parameters and isinstance( + cfg.fix_untrained_tokens, list + ): + fix_untrained_tokens( + model, + tokenizer, + train_dataset, + token_ids_to_fix=cfg.fix_untrained_tokens, + ) + else: + fix_untrained_tokens(model, tokenizer, train_dataset) + + if cfg.local_rank == 0: + model.save_pretrained( + str(Path(cfg.output_dir)), safe_serialization=safe_serialization + ) + + +def train( + *, cfg: DictDefault, dataset_meta: TrainDatasetMeta +) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: + """ + Train a model on the given dataset. + + Args: + cfg: The configuration dictionary with training parameters + dataset_meta: Metadata about the training dataset + + Returns: + Tuple of (model, tokenizer) after training + """ + # Load tokenizer, processor and model + tokenizer, processor, model, peft_config = setup_model_and_tokenizer(cfg) + + # Set up reference model for RL if needed + model_ref = setup_reference_model(cfg, tokenizer) + + # Determine if we need to resume from a checkpoint + resume_from_checkpoint = determine_resume_checkpoint(cfg) + + # Get datasets from metadata + train_dataset = dataset_meta.train_dataset + eval_dataset = dataset_meta.eval_dataset + total_num_steps = dataset_meta.total_num_steps + + # Configuration for saving + safe_serialization = cfg.save_safetensors is True + + # Set up trainer + trainer = setup_trainer( + cfg, + train_dataset, + eval_dataset, + (model, model_ref, peft_config), + tokenizer, + processor, + total_num_steps, + ) + + # Handle untrained tokens if configured + handle_untrained_tokens_fix( + cfg, model, tokenizer, train_dataset, safe_serialization + ) + + # Save initial configs + save_initial_configs(cfg, tokenizer, model, peft_config) + + # Set up signal handler for graceful termination + setup_signal_handler(cfg, model, safe_serialization) + + # Set up badges and config info for model card + setup_badge_for_model_card() + add_config_to_model_card(cfg) + + # Execute the training + execute_training(cfg, trainer, resume_from_checkpoint) + + # Save the trained model + save_trained_model(cfg, trainer, model, safe_serialization) + + # Create model card + create_model_card(cfg, trainer) + return model, tokenizer - - -def pretrain_hooks(_cfg, _trainer): - """ - Run hooks right before kicking off the training - :param cfg: - :param trainer: - :return: - """ - - -def post_train_hooks(_cfg, _trainer): - """ - Run hooks right after training completes - :param cfg: - :param trainer: - :return: - """