refactor train.py

This commit is contained in:
Dan Saunders
2025-02-26 19:37:42 +00:00
parent 75cbd15301
commit c4104fc10c

View File

@@ -6,15 +6,15 @@ import signal
import sys import sys
import weakref import weakref
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import Any, Tuple, Union
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
from peft import PeftModel from peft import PeftConfig, PeftModel
from pkg_resources import get_distribution # type: ignore from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.datasets import TrainDatasetMeta from axolotl.common.datasets import TrainDatasetMeta
@@ -27,11 +27,13 @@ from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
# Optional imports with graceful fallbacks
try: try:
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
except ImportError: except ImportError:
BetterTransformer = None BetterTransformer = None
# Project setup
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src") src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
@@ -40,9 +42,20 @@ configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)
def train( def setup_model_and_tokenizer(
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta cfg: DictDefault,
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: ) -> Tuple[
PreTrainedTokenizer, ProcessorMixin | None, PreTrainedModel, PeftConfig | None
]:
"""
Load the tokenizer, processor (for multimodal models), and model based on configuration.
Args:
cfg: The configuration dictionary with training parameters
Returns:
Tuple containing tokenizer, processor (if multimodal, else None), model, and peft_config (if applicable, else None)
"""
# Load tokenizer # Load tokenizer
LOG.debug( LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -55,11 +68,58 @@ def train(
if cfg.is_multimodal: if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer) processor = load_processor(cfg, tokenizer)
# Get datasets # Load the model and peft_config
train_dataset = dataset_meta.train_dataset msg = "loading model"
eval_dataset = dataset_meta.eval_dataset if cfg.adapter:
total_num_steps = dataset_meta.total_num_steps msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
# Apply freezing if specified
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
return tokenizer, processor, model, peft_config
def setup_reference_model(
cfg: DictDefault, tokenizer: PreTrainedTokenizer
) -> PreTrainedModel | None:
"""
Set up the reference model for RL training if needed.
Args:
cfg: The configuration dictionary
tokenizer: The tokenizer to use for the reference model
Returns:
Reference model if needed for RL training, None otherwise
"""
model_ref = None
if cfg.rl and cfg.rl != "orpo":
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
return model_ref
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
"""
Determine the checkpoint to resume from based on configuration.
Args:
cfg: The configuration dictionary
Returns:
Path to the checkpoint to resume from, or None if not resuming
"""
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [ possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
@@ -73,77 +133,22 @@ def train(
LOG.info( LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
) )
resume_from_checkpoint = cfg.resume_from_checkpoint return cfg.resume_from_checkpoint
# Load the model and tokenizer
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
model_ref = None def setup_signal_handler(
if cfg.rl and cfg.rl != "orpo": cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
if cfg.adapter and not cfg.rl_adapter_ref_model: ):
# use built-in trl autounwrap """
LOG.debug("Passing model_ref: None to RL trainer") Set up signal handler for graceful termination.
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
safe_serialization = cfg.save_safetensors is True Args:
cfg: The configuration dictionary
if cfg.unfrozen_parameters: model: The model to save on termination
freeze_layers_except(model, cfg.unfrozen_parameters) safe_serialization: Whether to use safe serialization when saving
"""
trainer = setup_trainer( # ray workers don't have access to this signal
cfg, if cfg.local_rank == 0 and not cfg.use_ray:
train_dataset,
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)
if cfg.fix_untrained_tokens:
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# additionally presave the tokenizer and model configs
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if (
cfg.local_rank == 0 and not cfg.use_ray
): # ray workers don't have access to this signal
def terminate_handler(_, __, model_weakref): def terminate_handler(_, __, model_weakref):
if model_weakref() is not None: if model_weakref() is not None:
@@ -161,21 +166,22 @@ def train(
lambda signum, frame: terminate_handler(signum, frame, _model_weakref), lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
) )
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
if getattr(cfg, "axolotl_config_path"): def execute_training(
raw_axolotl_cfg = Path(cfg.axolotl_config_path) cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
version = get_distribution("axolotl").version ):
if raw_axolotl_cfg.is_file(): """
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n" Execute the training process with appropriate backend configurations.
Args:
cfg: The configuration dictionary
trainer: The configured trainer object
resume_from_checkpoint: Path to checkpoint to resume from, if applicable
"""
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
if cfg.group_by_length: if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length") LOG.info("hang tight... sorting dataset for group_by_length")
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum: if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel( with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ... # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -187,15 +193,30 @@ def train(
else: else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)
post_train_hooks(cfg, trainer)
def save_trained_model(
cfg: DictDefault,
trainer: Any,
model: PreTrainedModel,
safe_serialization: bool,
):
"""
Save the trained model according to configuration and training setup.
Args:
cfg: The configuration dictionary
trainer: The trainer object
model: The trained model to save
safe_serialization: Whether to use safe serialization
"""
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# post training # Post training module hooks
for name, module in model.named_modules(): for name, module in model.named_modules():
if hasattr(module, "_post_training"): if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access module._post_training(model, name) # pylint: disable=protected-access
# Handle FSDP state dict type
state_dict_type = "FULL_STATE_DICT" state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled: if trainer.is_fsdp_enabled:
if cfg.fsdp_final_state_dict_type: if cfg.fsdp_final_state_dict_type:
@@ -203,12 +224,13 @@ def train(
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type) trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.") LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
# Handle ReLoRA early return case
if cfg.relora_steps: if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload() model = model.merge_and_unload()
else: else:
# final model weights have already been saved by `ReLoRACallback.on_train_end` # final model weights have already been saved by `ReLoRACallback.on_train_end`
return model, tokenizer return
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
@@ -255,6 +277,15 @@ def train(
) )
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def create_model_card(cfg: DictDefault, trainer: Any):
"""
Create a model card for the trained model if needed.
Args:
cfg: The configuration dictionary
trainer: The trainer object with model card creation capabilities
"""
if not cfg.hub_model_id: if not cfg.hub_model_id:
try: try:
model_card_kwarg = { model_card_kwarg = {
@@ -291,22 +322,162 @@ def train(
# defensively push to the hub to ensure the model card is updated # defensively push to the hub to ensure the model card is updated
trainer.push_to_hub() trainer.push_to_hub()
def save_initial_configs(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
peft_config: PeftConfig | None,
):
"""
Save initial configurations before training.
Args:
cfg: The configuration dictionary
tokenizer: The tokenizer to save
model: The model to save configuration for
peft_config: The PEFT configuration to save if applicable
"""
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# additionally presave the tokenizer and model configs
output_dir = Path(cfg.output_dir)
if not output_dir.is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(output_dir))
if hasattr(model, "config"):
model.config.save_pretrained(str(output_dir))
def setup_badge_for_model_card():
"""Set up the Axolotl badge for the model card."""
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
def add_config_to_model_card(cfg: DictDefault):
"""
Add the Axolotl configuration to the model card if available.
Args:
cfg: The configuration dictionary with path to axolotl config file
"""
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = get_distribution("axolotl").version
if raw_axolotl_cfg.is_file():
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
def handle_untrained_tokens_fix(
cfg: DictDefault,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
train_dataset: Any,
safe_serialization: bool,
):
"""
Apply fixes for untrained tokens if configured.
Args:
cfg: The configuration dictionary
model: The model to apply fixes to
tokenizer: The tokenizer for token identification
train_dataset: The training dataset to analyze
safe_serialization: Whether to use safe serialization when saving
"""
if not cfg.fix_untrained_tokens:
return
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
def train(
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
"""
Train a model on the given dataset.
Args:
cfg: The configuration dictionary with training parameters
dataset_meta: Metadata about the training dataset
Returns:
Tuple of (model, tokenizer) after training
"""
# Load tokenizer, processor and model
tokenizer, processor, model, peft_config = setup_model_and_tokenizer(cfg)
# Set up reference model for RL if needed
model_ref = setup_reference_model(cfg, tokenizer)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Get datasets from metadata
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Set up trainer
trainer = setup_trainer(
cfg,
train_dataset,
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)
# Handle untrained tokens if configured
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
# Save initial configs
save_initial_configs(cfg, tokenizer, model, peft_config)
# Set up signal handler for graceful termination
setup_signal_handler(cfg, model, safe_serialization)
# Set up badges and config info for model card
setup_badge_for_model_card()
add_config_to_model_card(cfg)
# Execute the training
execute_training(cfg, trainer, resume_from_checkpoint)
# Save the trained model
save_trained_model(cfg, trainer, model, safe_serialization)
# Create model card
create_model_card(cfg, trainer)
return model, tokenizer return model, tokenizer
def pretrain_hooks(_cfg, _trainer):
"""
Run hooks right before kicking off the training
:param cfg:
:param trainer:
:return:
"""
def post_train_hooks(_cfg, _trainer):
"""
Run hooks right after training completes
:param cfg:
:param trainer:
:return:
"""