refactor train.py
@@ -6,15 +6,15 @@ import signal
 import sys
 import weakref
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Any, Tuple, Union
 
 import torch
 import transformers.modelcard
 from accelerate.logging import get_logger
 from accelerate.utils import save_fsdp_model
-from peft import PeftModel
+from peft import PeftConfig, PeftModel
 from pkg_resources import get_distribution  # type: ignore
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.common.datasets import TrainDatasetMeta
@@ -27,11 +27,13 @@ from axolotl.utils.freeze import freeze_layers_except
 from axolotl.utils.models import load_model, load_processor, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 
+# Optional imports with graceful fallbacks
 try:
     from optimum.bettertransformer import BetterTransformer
 except ImportError:
     BetterTransformer = None
 
+# Project setup
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
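The try/except above makes `optimum` a soft dependency: call sites must treat `BetterTransformer` as possibly `None`. A minimal sketch of the pattern at a call site (the `maybe_transform` helper is illustrative, not part of this commit):

```python
# Illustrative only: guard on the optional dependency before using it,
# mirroring the graceful-fallback import above.
try:
    from optimum.bettertransformer import BetterTransformer
except ImportError:
    BetterTransformer = None


def maybe_transform(model):
    """Return a BetterTransformer-wrapped model when optimum is installed."""
    if BetterTransformer is None:
        return model  # optimum missing: run the model unchanged
    return BetterTransformer.transform(model)
```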
@@ -40,9 +42,20 @@ configure_logging()
 LOG = get_logger(__name__)
 
 
-def train(
-    *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
-) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
+def setup_model_and_tokenizer(
+    cfg: DictDefault,
+) -> Tuple[
+    PreTrainedTokenizer, ProcessorMixin | None, PreTrainedModel, PeftConfig | None
+]:
+    """
+    Load the tokenizer, processor (for multimodal models), and model based on configuration.
+
+    Args:
+        cfg: The configuration dictionary with training parameters
+
+    Returns:
+        Tuple containing tokenizer, processor (if multimodal, else None), model, and peft_config (if applicable, else None)
+    """
     # Load tokenizer
     LOG.debug(
         f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -55,11 +68,58 @@ def train(
     if cfg.is_multimodal:
         processor = load_processor(cfg, tokenizer)
 
-    # Get datasets
-    train_dataset = dataset_meta.train_dataset
-    eval_dataset = dataset_meta.eval_dataset
-    total_num_steps = dataset_meta.total_num_steps
-
+    # Load the model and peft_config
+    msg = "loading model"
+    if cfg.adapter:
+        msg += " and peft_config..."
+    LOG.debug(msg)
+
+    model, peft_config = load_model(cfg, tokenizer, processor=processor)
+    if model.generation_config is not None:
+        model.generation_config.do_sample = True
+
+    # Apply freezing if specified
+    if cfg.unfrozen_parameters:
+        freeze_layers_except(model, cfg.unfrozen_parameters)
+
+    return tokenizer, processor, model, peft_config
+
+
+def setup_reference_model(
+    cfg: DictDefault, tokenizer: PreTrainedTokenizer
+) -> PreTrainedModel | None:
+    """
+    Set up the reference model for RL training if needed.
+
+    Args:
+        cfg: The configuration dictionary
+        tokenizer: The tokenizer to use for the reference model
+
+    Returns:
+        Reference model if needed for RL training, None otherwise
+    """
+    model_ref = None
+    if cfg.rl and cfg.rl != "orpo":
+        if cfg.adapter and not cfg.rl_adapter_ref_model:
+            # use built-in trl autounwrap
+            LOG.debug("Passing model_ref: None to RL trainer")
+            model_ref = None  # explicit setting to None
+        else:
+            # load the model again for model_ref/baseline
+            model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
+    return model_ref
+
+
+def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
+    """
+    Determine the checkpoint to resume from based on configuration.
+
+    Args:
+        cfg: The configuration dictionary
+
+    Returns:
+        Path to the checkpoint to resume from, or None if not resuming
+    """
     if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
         possible_checkpoints = [
             str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
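The hunk boundary cuts off how `determine_resume_checkpoint` chooses among the globbed `checkpoint-*` directories. A hedged sketch of the usual selection rule (sort numerically by global step, so `checkpoint-1000` beats `checkpoint-999`; this helper is an assumption, not the commit's code):

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(output_dir: str) -> Optional[str]:
    # HF Trainer writes directories named checkpoint-<global_step>
    checkpoints = [
        cp
        for cp in Path(output_dir).glob("checkpoint-*")
        if cp.name.split("-")[-1].isdigit()
    ]
    if not checkpoints:
        return None
    # Numeric sort by step; a plain lexicographic sort would mis-order steps
    return str(max(checkpoints, key=lambda cp: int(cp.name.split("-")[-1])))
```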
@@ -73,77 +133,22 @@ def train(
             LOG.info(
                 f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
             )
-    resume_from_checkpoint = cfg.resume_from_checkpoint
-
-    # Load the model and tokenizer
-    msg = "loading model"
-    if cfg.adapter:
-        msg += " and peft_config..."
-    LOG.debug(msg)
-    model, peft_config = load_model(cfg, tokenizer, processor=processor)
-    if model.generation_config is not None:
-        model.generation_config.do_sample = True
-
-    model_ref = None
-    if cfg.rl and cfg.rl != "orpo":
-        if cfg.adapter and not cfg.rl_adapter_ref_model:
-            # use built-in trl autounwrap
-            LOG.debug("Passing model_ref: None to RL trainer")
-            model_ref = None  # explicit setting to None
-        else:
-            # load the model again for model_ref/baseline
-            model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
-
-    safe_serialization = cfg.save_safetensors is True
-
-    if cfg.unfrozen_parameters:
-        freeze_layers_except(model, cfg.unfrozen_parameters)
-
-    trainer = setup_trainer(
-        cfg,
-        train_dataset,
-        eval_dataset,
-        (model, model_ref, peft_config),
-        tokenizer,
-        processor,
-        total_num_steps,
-    )
-
-    if cfg.fix_untrained_tokens:
-        # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
-        sig = inspect.signature(fix_untrained_tokens)
-        # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
-        if "token_ids_to_fix" in sig.parameters and isinstance(
-            cfg.fix_untrained_tokens, list
-        ):
-            fix_untrained_tokens(
-                model,
-                tokenizer,
-                train_dataset,
-                token_ids_to_fix=cfg.fix_untrained_tokens,
-            )
-        else:
-            fix_untrained_tokens(model, tokenizer, train_dataset)
-        if cfg.local_rank == 0:
-            model.save_pretrained(
-                str(Path(cfg.output_dir)), safe_serialization=safe_serialization
-            )
-
-    # go ahead and presave, so we have the adapter config available to inspect
-    if peft_config:
-        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
-        peft_config.save_pretrained(cfg.output_dir)
-    # additionally presave the tokenizer and model configs
-    if not Path(cfg.output_dir).is_dir():
-        os.makedirs(cfg.output_dir, exist_ok=True)
-    tokenizer.save_pretrained(str(Path(cfg.output_dir)))
-    if hasattr(model, "config"):
-        model.config.save_pretrained(str(Path(cfg.output_dir)))
-
-    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
-    if (
-        cfg.local_rank == 0 and not cfg.use_ray
-    ):  # ray workers don't have access to this signal
+    return cfg.resume_from_checkpoint
+
+
+def setup_signal_handler(
+    cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
+):
+    """
+    Set up signal handler for graceful termination.
+
+    Args:
+        cfg: The configuration dictionary
+        model: The model to save on termination
+        safe_serialization: Whether to use safe serialization when saving
+    """
+    # ray workers don't have access to this signal
+    if cfg.local_rank == 0 and not cfg.use_ray:
 
         def terminate_handler(_, __, model_weakref):
             if model_weakref() is not None:
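For context, the weakref-plus-signal pattern that `setup_signal_handler` builds on, reduced to a runnable sketch (the `DummyModel` class and output path are placeholders): holding only a weak reference means the handler never keeps the model alive, and the save is skipped if the model has already been garbage-collected.

```python
import signal
import weakref


class DummyModel:
    """Placeholder for a transformers model with save_pretrained()."""

    def save_pretrained(self, output_dir: str) -> None:
        print(f"saving to {output_dir}")


model = DummyModel()
_model_weakref = weakref.ref(model)


def terminate_handler(_signum, _frame, model_weakref):
    # Only save if the model object is still alive
    if model_weakref() is not None:
        model_weakref().save_pretrained("./output")
    raise SystemExit(0)


signal.signal(
    signal.SIGINT,
    lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
)
```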
@@ -161,21 +166,22 @@ def train(
             lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
         )
 
-    badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
-    transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
-
-    if getattr(cfg, "axolotl_config_path"):
-        raw_axolotl_cfg = Path(cfg.axolotl_config_path)
-        version = get_distribution("axolotl").version
-        if raw_axolotl_cfg.is_file():
-            transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
+
+
+def execute_training(
+    cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
+):
+    """
+    Execute the training process with appropriate backend configurations.
+
+    Args:
+        cfg: The configuration dictionary
+        trainer: The configured trainer object
+        resume_from_checkpoint: Path to checkpoint to resume from, if applicable
+    """
     LOG.info("Starting trainer...")
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")
 
     pretrain_hooks(cfg, trainer)
 
     if cfg.flash_optimum:
         with torch.backends.cuda.sdp_kernel(
             # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
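The exact `sdp_kernel` flags are elided by the hunk (hence the TODO about `sdp_kernel_kwargs`). For reference, a hedged example of the context manager, with the flag values here chosen as an assumption; note that `torch.backends.cuda.sdp_kernel` is deprecated in newer PyTorch in favor of `torch.nn.attention.sdpa_kernel`:

```python
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,           # prefer FlashAttention kernels
        enable_math=False,           # disable the eager math fallback
        enable_mem_efficient=False,  # disable the memory-efficient kernel
    ):
        # (batch, heads, seq_len, head_dim) in fp16, as FlashAttention expects
        q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
        out = F.scaled_dot_product_attention(q, q, q)
```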
@@ -187,15 +193,30 @@ def train(
     else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
     post_train_hooks(cfg, trainer)
 
+
+def save_trained_model(
+    cfg: DictDefault,
+    trainer: Any,
+    model: PreTrainedModel,
+    safe_serialization: bool,
+):
+    """
+    Save the trained model according to configuration and training setup.
+
+    Args:
+        cfg: The configuration dictionary
+        trainer: The trainer object
+        model: The trained model to save
+        safe_serialization: Whether to use safe serialization
+    """
     LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
-    # post training
+    # Post training module hooks
     for name, module in model.named_modules():
         if hasattr(module, "_post_training"):
             module._post_training(model, name)  # pylint: disable=protected-access
 
+    # Handle FSDP state dict type
     state_dict_type = "FULL_STATE_DICT"
     if trainer.is_fsdp_enabled:
         if cfg.fsdp_final_state_dict_type:
@@ -203,12 +224,13 @@ def train(
         trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
         LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
 
+    # Handle ReLoRA early return case
     if cfg.relora_steps:
         if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
             model = model.merge_and_unload()
         else:
             # final model weights have already been saved by `ReLoRACallback.on_train_end`
-            return model, tokenizer
+            return
 
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
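On the ReLoRA branch above: `merge_and_unload()` folds the LoRA adapter weights into the base model and returns a plain `transformers` model, which is why the merged result can then be saved like any full model. A hedged, standalone sketch (model and adapter paths are placeholders):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("base-model-id")     # placeholder id
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")  # placeholder path
merged = model.merge_and_unload()  # base model with LoRA deltas merged in
merged.save_pretrained("./merged-model")
```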
@@ -255,6 +277,15 @@ def train(
         )
     model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
+
+def create_model_card(cfg: DictDefault, trainer: Any):
+    """
+    Create a model card for the trained model if needed.
+
+    Args:
+        cfg: The configuration dictionary
+        trainer: The trainer object with model card creation capabilities
+    """
     if not cfg.hub_model_id:
         try:
             model_card_kwarg = {
@@ -291,22 +322,162 @@ def train(
         # defensively push to the hub to ensure the model card is updated
         trainer.push_to_hub()
 
 
+def save_initial_configs(
+    cfg: DictDefault,
+    tokenizer: PreTrainedTokenizer,
+    model: PreTrainedModel,
+    peft_config: PeftConfig | None,
+):
+    """
+    Save initial configurations before training.
+
+    Args:
+        cfg: The configuration dictionary
+        tokenizer: The tokenizer to save
+        model: The model to save configuration for
+        peft_config: The PEFT configuration to save if applicable
+    """
+    # go ahead and presave, so we have the adapter config available to inspect
+    if peft_config:
+        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
+        peft_config.save_pretrained(cfg.output_dir)
+
+    # additionally presave the tokenizer and model configs
+    output_dir = Path(cfg.output_dir)
+    if not output_dir.is_dir():
+        os.makedirs(cfg.output_dir, exist_ok=True)
+
+    tokenizer.save_pretrained(str(output_dir))
+    if hasattr(model, "config"):
+        model.config.save_pretrained(str(output_dir))
+
+
+def setup_badge_for_model_card():
+    """Set up the Axolotl badge for the model card."""
+    badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
+    transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
+
+
+def add_config_to_model_card(cfg: DictDefault):
+    """
+    Add the Axolotl configuration to the model card if available.
+
+    Args:
+        cfg: The configuration dictionary with path to axolotl config file
+    """
+    if getattr(cfg, "axolotl_config_path"):
+        raw_axolotl_cfg = Path(cfg.axolotl_config_path)
+        version = get_distribution("axolotl").version
+        if raw_axolotl_cfg.is_file():
+            transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
+
+
+def handle_untrained_tokens_fix(
+    cfg: DictDefault,
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    train_dataset: Any,
+    safe_serialization: bool,
+):
+    """
+    Apply fixes for untrained tokens if configured.
+
+    Args:
+        cfg: The configuration dictionary
+        model: The model to apply fixes to
+        tokenizer: The tokenizer for token identification
+        train_dataset: The training dataset to analyze
+        safe_serialization: Whether to use safe serialization when saving
+    """
+    if not cfg.fix_untrained_tokens:
+        return
+
+    # check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
+    sig = inspect.signature(fix_untrained_tokens)
+
+    # if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
+    if "token_ids_to_fix" in sig.parameters and isinstance(
+        cfg.fix_untrained_tokens, list
+    ):
+        fix_untrained_tokens(
+            model,
+            tokenizer,
+            train_dataset,
+            token_ids_to_fix=cfg.fix_untrained_tokens,
+        )
+    else:
+        fix_untrained_tokens(model, tokenizer, train_dataset)
+
+    if cfg.local_rank == 0:
+        model.save_pretrained(
+            str(Path(cfg.output_dir)), safe_serialization=safe_serialization
+        )
+
+
+def train(
+    *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
+) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
+    """
+    Train a model on the given dataset.
+
+    Args:
+        cfg: The configuration dictionary with training parameters
+        dataset_meta: Metadata about the training dataset
+
+    Returns:
+        Tuple of (model, tokenizer) after training
+    """
+    # Load tokenizer, processor and model
+    tokenizer, processor, model, peft_config = setup_model_and_tokenizer(cfg)
+
+    # Set up reference model for RL if needed
+    model_ref = setup_reference_model(cfg, tokenizer)
+
+    # Determine if we need to resume from a checkpoint
+    resume_from_checkpoint = determine_resume_checkpoint(cfg)
+
+    # Get datasets from metadata
+    train_dataset = dataset_meta.train_dataset
+    eval_dataset = dataset_meta.eval_dataset
+    total_num_steps = dataset_meta.total_num_steps
+
+    # Configuration for saving
+    safe_serialization = cfg.save_safetensors is True
+
+    # Set up trainer
+    trainer = setup_trainer(
+        cfg,
+        train_dataset,
+        eval_dataset,
+        (model, model_ref, peft_config),
+        tokenizer,
+        processor,
+        total_num_steps,
+    )
+
+    # Handle untrained tokens if configured
+    handle_untrained_tokens_fix(
+        cfg, model, tokenizer, train_dataset, safe_serialization
+    )
+
+    # Save initial configs
+    save_initial_configs(cfg, tokenizer, model, peft_config)
+
+    # Set up signal handler for graceful termination
+    setup_signal_handler(cfg, model, safe_serialization)
+
+    # Set up badges and config info for model card
+    setup_badge_for_model_card()
+    add_config_to_model_card(cfg)
+
+    # Execute the training
+    execute_training(cfg, trainer, resume_from_checkpoint)
+
+    # Save the trained model
+    save_trained_model(cfg, trainer, model, safe_serialization)
+
+    # Create model card
+    create_model_card(cfg, trainer)
+
+    return model, tokenizer
+
+
 def pretrain_hooks(_cfg, _trainer):
     """
     Run hooks right before kicking off the training
     :param cfg:
     :param trainer:
     :return:
     """
 
 
 def post_train_hooks(_cfg, _trainer):
     """
     Run hooks right after training completes
     :param cfg:
     :param trainer:
     :return:
     """
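For orientation, a hypothetical call site for the refactored entry point; the config and dataset loader names below are assumptions based on this file's imports, not part of the diff:

```python
from axolotl.cli.config import load_cfg            # assumed loader location
from axolotl.common.datasets import load_datasets  # assumed helper
from axolotl.train import train

cfg = load_cfg("examples/config.yaml")  # placeholder config path
dataset_meta = load_datasets(cfg=cfg)   # returns a TrainDatasetMeta
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
```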