|
|
|
|
@@ -1,19 +1,20 @@
|
|
|
|
|
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
|
|
|
|
|
|
|
|
|
import importlib
|
|
|
|
|
import inspect
|
|
|
|
|
import os
|
|
|
|
|
import signal
|
|
|
|
|
import sys
|
|
|
|
|
import weakref
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Any, Tuple, Union
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import transformers.modelcard
|
|
|
|
|
from accelerate.logging import get_logger
|
|
|
|
|
from accelerate.utils import save_fsdp_model
|
|
|
|
|
from datasets import Dataset
|
|
|
|
|
from peft import PeftConfig, PeftModel
|
|
|
|
|
from pkg_resources import get_distribution # type: ignore
|
|
|
|
|
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
|
|
|
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
|
|
|
|
from transformers.trainer import Trainer
|
|
|
|
|
@@ -22,6 +23,7 @@ from axolotl.common.datasets import TrainDatasetMeta
|
|
|
|
|
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
|
|
|
|
fix_untrained_tokens,
|
|
|
|
|
)
|
|
|
|
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
|
|
|
|
from axolotl.logging_config import configure_logging
|
|
|
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
from axolotl.utils.freeze import freeze_layers_except
|
|
|
|
|
@@ -39,17 +41,18 @@ LOG = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
def setup_model_and_tokenizer(
|
|
|
|
|
cfg: DictDefault,
|
|
|
|
|
) -> Tuple[
|
|
|
|
|
PreTrainedTokenizer, ProcessorMixin | None, PreTrainedModel, PeftConfig | None
|
|
|
|
|
) -> tuple[
|
|
|
|
|
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
|
|
|
|
|
]:
|
|
|
|
|
"""
|
|
|
|
|
Load the tokenizer, processor (for multimodal models), and model based on configuration.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary with training parameters
|
|
|
|
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple containing tokenizer, processor (if multimodal, else None), model, and peft_config (if applicable, else None)
|
|
|
|
|
Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
|
|
|
|
|
`None`), and processor (if multimodal, else `None`).
|
|
|
|
|
"""
|
|
|
|
|
# Load tokenizer
|
|
|
|
|
LOG.debug(
|
|
|
|
|
@@ -77,7 +80,7 @@ def setup_model_and_tokenizer(
|
|
|
|
|
if cfg.unfrozen_parameters:
|
|
|
|
|
freeze_layers_except(model, cfg.unfrozen_parameters)
|
|
|
|
|
|
|
|
|
|
return tokenizer, processor, model, peft_config
|
|
|
|
|
return model, tokenizer, peft_config, processor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_reference_model(
|
|
|
|
|
@@ -87,11 +90,11 @@ def setup_reference_model(
|
|
|
|
|
Set up the reference model for RL training if needed.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary
|
|
|
|
|
tokenizer: The tokenizer to use for the reference model
|
|
|
|
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
|
|
|
tokenizer: The tokenizer to use for the reference model.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Reference model if needed for RL training, None otherwise
|
|
|
|
|
Reference model if needed for RL training, `None` otherwise.
|
|
|
|
|
"""
|
|
|
|
|
model_ref = None
|
|
|
|
|
if cfg.rl and cfg.rl != "orpo":
|
|
|
|
|
@@ -110,10 +113,10 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
|
|
|
|
|
Determine the checkpoint to resume from based on configuration.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary
|
|
|
|
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Path to the checkpoint to resume from, or None if not resuming
|
|
|
|
|
Path to the checkpoint to resume from, or `None` if not resuming.
|
|
|
|
|
"""
|
|
|
|
|
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
|
|
|
|
possible_checkpoints = [
|
|
|
|
|
@@ -138,7 +141,7 @@ def setup_signal_handler(
|
|
|
|
|
Set up signal handler for graceful termination.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary
|
|
|
|
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
|
|
|
model: The model to save on termination
|
|
|
|
|
safe_serialization: Whether to use safe serialization when saving
|
|
|
|
|
"""
|
|
|
|
|
@@ -169,9 +172,9 @@ def execute_training(
|
|
|
|
|
Execute the training process with appropriate backend configurations.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary
|
|
|
|
|
trainer: The configured trainer object
|
|
|
|
|
resume_from_checkpoint: Path to checkpoint to resume from, if applicable
|
|
|
|
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
|
|
|
trainer: The configured trainer object.
|
|
|
|
|
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
|
|
|
|
"""
|
|
|
|
|
LOG.info("Starting trainer...")
|
|
|
|
|
if cfg.group_by_length:
|
|
|
|
|
@@ -199,12 +202,12 @@ def save_trained_model(
|
|
|
|
|
Save the trained model according to configuration and training setup.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary
|
|
|
|
|
trainer: The trainer object
|
|
|
|
|
model: The trained model to save
|
|
|
|
|
safe_serialization: Whether to use safe serialization
|
|
|
|
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
|
|
|
trainer: The trainer object.
|
|
|
|
|
model: The trained model to save.
|
|
|
|
|
safe_serialization: Whether to use safe serialization.
|
|
|
|
|
"""
|
|
|
|
|
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
|
|
|
|
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
|
|
|
|
|
|
|
|
|
|
# Post training module hooks
|
|
|
|
|
for name, module in model.named_modules():
|
|
|
|
|
@@ -278,8 +281,8 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
|
|
|
|
|
Create a model card for the trained model if needed.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary
|
|
|
|
|
trainer: The trainer object with model card creation capabilities
|
|
|
|
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
|
|
|
trainer: The trainer object with model card creation capabilities.
|
|
|
|
|
"""
|
|
|
|
|
if not cfg.hub_model_id:
|
|
|
|
|
# Guard since create_model_card may fail if dataset_tags is empty list
|
|
|
|
|
@@ -289,29 +292,23 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
|
|
|
|
|
.encode("utf-8")
|
|
|
|
|
.decode("utf-8")
|
|
|
|
|
}
|
|
|
|
|
if cfg.datasets is not None:
|
|
|
|
|
|
|
|
|
|
# We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
|
|
|
|
|
rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
|
|
|
|
|
if cfg.datasets is not None and not rl:
|
|
|
|
|
dataset_tags = [
|
|
|
|
|
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
|
|
|
|
]
|
|
|
|
|
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
|
|
|
|
|
|
|
|
|
|
if dataset_tags:
|
|
|
|
|
param_name = (
|
|
|
|
|
"dataset_name"
|
|
|
|
|
if (
|
|
|
|
|
cfg.rl is not None
|
|
|
|
|
or cfg.reward_model
|
|
|
|
|
or cfg.process_reward_model
|
|
|
|
|
)
|
|
|
|
|
else "dataset_tags"
|
|
|
|
|
)
|
|
|
|
|
model_card_kwarg[param_name] = dataset_tags
|
|
|
|
|
model_card_kwarg["dataset_tags"] = dataset_tags
|
|
|
|
|
|
|
|
|
|
trainer.create_model_card(**model_card_kwarg)
|
|
|
|
|
except (AttributeError, UnicodeDecodeError):
|
|
|
|
|
pass
|
|
|
|
|
elif cfg.hub_model_id:
|
|
|
|
|
# defensively push to the hub to ensure the model card is updated
|
|
|
|
|
# Defensively push to the hub to ensure the model card is updated
|
|
|
|
|
trainer.push_to_hub()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -325,23 +322,26 @@ def save_initial_configs(
|
|
|
|
|
Save initial configurations before training.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary
|
|
|
|
|
tokenizer: The tokenizer to save
|
|
|
|
|
model: The model to save configuration for
|
|
|
|
|
peft_config: The PEFT configuration to save if applicable
|
|
|
|
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
|
|
|
tokenizer: The tokenizer to save.
|
|
|
|
|
model: The model to save configuration for.
|
|
|
|
|
peft_config: The PEFT configuration to save if applicable.
|
|
|
|
|
"""
|
|
|
|
|
# go ahead and presave, so we have the adapter config available to inspect
|
|
|
|
|
if peft_config:
|
|
|
|
|
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
|
|
|
|
peft_config.save_pretrained(cfg.output_dir)
|
|
|
|
|
|
|
|
|
|
# additionally presave the tokenizer and model configs
|
|
|
|
|
# Create output_dir if it doesn't already exist
|
|
|
|
|
output_dir = Path(cfg.output_dir)
|
|
|
|
|
if not output_dir.is_dir():
|
|
|
|
|
os.makedirs(cfg.output_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
# Pre-save adapter config so it's available to inspect
|
|
|
|
|
if peft_config:
|
|
|
|
|
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
|
|
|
|
|
peft_config.save_pretrained(cfg.output_dir)
|
|
|
|
|
|
|
|
|
|
# Pre-save the tokenizer and model configs
|
|
|
|
|
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
|
|
|
|
|
tokenizer.save_pretrained(str(output_dir))
|
|
|
|
|
if hasattr(model, "config"):
|
|
|
|
|
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
|
|
|
|
model.config.save_pretrained(str(output_dir))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -350,14 +350,14 @@ def setup_model_card(cfg: DictDefault):
|
|
|
|
|
Set up the Axolotl badge and add the Axolotl config to the model card if available.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary with path to axolotl config file
|
|
|
|
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
|
|
|
"""
|
|
|
|
|
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
|
|
|
|
|
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
|
|
|
|
|
|
|
|
|
if getattr(cfg, "axolotl_config_path"):
|
|
|
|
|
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
|
|
|
|
|
version = get_distribution("axolotl").version
|
|
|
|
|
version = importlib.metadata.version("axolotl")
|
|
|
|
|
if raw_axolotl_cfg.is_file():
|
|
|
|
|
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
|
|
|
|
|
|
|
|
|
|
@@ -366,26 +366,26 @@ def handle_untrained_tokens_fix(
|
|
|
|
|
cfg: DictDefault,
|
|
|
|
|
model: PreTrainedModel,
|
|
|
|
|
tokenizer: PreTrainedTokenizer,
|
|
|
|
|
train_dataset: Any,
|
|
|
|
|
train_dataset: Dataset,
|
|
|
|
|
safe_serialization: bool,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Apply fixes for untrained tokens if configured.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary
|
|
|
|
|
model: The model to apply fixes to
|
|
|
|
|
tokenizer: The tokenizer for token identification
|
|
|
|
|
train_dataset: The training dataset to analyze
|
|
|
|
|
safe_serialization: Whether to use safe serialization when saving
|
|
|
|
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
|
|
|
model: The model to apply fixes to.
|
|
|
|
|
tokenizer: The tokenizer for token identification.
|
|
|
|
|
train_dataset: The training dataset to use.
|
|
|
|
|
safe_serialization: Whether to use safe serialization when saving.
|
|
|
|
|
"""
|
|
|
|
|
if not cfg.fix_untrained_tokens:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
|
|
|
|
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
|
|
|
|
sig = inspect.signature(fix_untrained_tokens)
|
|
|
|
|
|
|
|
|
|
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
|
|
|
|
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
|
|
|
|
if "token_ids_to_fix" in sig.parameters and isinstance(
|
|
|
|
|
cfg.fix_untrained_tokens, list
|
|
|
|
|
):
|
|
|
|
|
@@ -404,48 +404,90 @@ def handle_untrained_tokens_fix(
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(
|
|
|
|
|
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
|
|
|
|
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
|
|
|
|
def setup_model_and_trainer(
|
|
|
|
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
|
|
|
|
) -> tuple[
|
|
|
|
|
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
|
|
|
|
PeftModel | PreTrainedModel,
|
|
|
|
|
PreTrainedTokenizer,
|
|
|
|
|
PeftConfig | None,
|
|
|
|
|
]:
|
|
|
|
|
"""
|
|
|
|
|
Train a model on the given dataset.
|
|
|
|
|
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
|
|
|
|
|
trainer setup.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary with training parameters
|
|
|
|
|
dataset_meta: Metadata about the training dataset
|
|
|
|
|
cfg: The configuration dictionary with training parameters.
|
|
|
|
|
dataset_meta: Object with training, validation datasets and metadata.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple of (model, tokenizer) after training
|
|
|
|
|
Tuple of:
|
|
|
|
|
- Trainer (Causal or RLHF)
|
|
|
|
|
- Model
|
|
|
|
|
- Tokenizer
|
|
|
|
|
- PEFT config
|
|
|
|
|
"""
|
|
|
|
|
# Load tokenizer, processor and model
|
|
|
|
|
tokenizer, processor, model, peft_config = setup_model_and_tokenizer(cfg)
|
|
|
|
|
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
|
|
|
|
|
|
|
|
|
|
# Set up reference model for RL if needed
|
|
|
|
|
model_ref = setup_reference_model(cfg, tokenizer)
|
|
|
|
|
|
|
|
|
|
# Determine if we need to resume from a checkpoint
|
|
|
|
|
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
|
|
|
|
|
|
|
|
|
# Get datasets from metadata
|
|
|
|
|
train_dataset = dataset_meta.train_dataset
|
|
|
|
|
eval_dataset = dataset_meta.eval_dataset
|
|
|
|
|
total_num_steps = dataset_meta.total_num_steps
|
|
|
|
|
|
|
|
|
|
# Set up trainer
|
|
|
|
|
trainer = setup_trainer(
|
|
|
|
|
cfg=cfg,
|
|
|
|
|
train_dataset=train_dataset,
|
|
|
|
|
eval_dataset=eval_dataset,
|
|
|
|
|
model=model,
|
|
|
|
|
tokenizer=tokenizer,
|
|
|
|
|
processor=processor,
|
|
|
|
|
total_num_steps=total_num_steps,
|
|
|
|
|
model_ref=model_ref,
|
|
|
|
|
peft_config=peft_config,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return (
|
|
|
|
|
trainer,
|
|
|
|
|
model,
|
|
|
|
|
tokenizer,
|
|
|
|
|
peft_config,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(
|
|
|
|
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
|
|
|
|
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
|
|
|
|
|
"""
|
|
|
|
|
Train a model on the given dataset.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cfg: The configuration dictionary with training parameters
|
|
|
|
|
dataset_meta: Object with training, validation datasets and metadata
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple of (model, tokenizer) after training
|
|
|
|
|
"""
|
|
|
|
|
# Setup model, tokenizer, (causal or RLHF) trainer etc.
|
|
|
|
|
(
|
|
|
|
|
trainer,
|
|
|
|
|
model,
|
|
|
|
|
tokenizer,
|
|
|
|
|
peft_config,
|
|
|
|
|
) = setup_model_and_trainer(cfg, dataset_meta)
|
|
|
|
|
|
|
|
|
|
# Determine if we need to resume from a checkpoint
|
|
|
|
|
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
|
|
|
|
|
|
|
|
|
# Configuration for saving
|
|
|
|
|
safe_serialization = cfg.save_safetensors is True
|
|
|
|
|
|
|
|
|
|
# Set up trainer
|
|
|
|
|
trainer = setup_trainer(
|
|
|
|
|
cfg,
|
|
|
|
|
train_dataset,
|
|
|
|
|
eval_dataset,
|
|
|
|
|
(model, model_ref, peft_config),
|
|
|
|
|
tokenizer,
|
|
|
|
|
processor,
|
|
|
|
|
total_num_steps,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Handle untrained tokens if configured
|
|
|
|
|
train_dataset = dataset_meta.train_dataset
|
|
|
|
|
handle_untrained_tokens_fix(
|
|
|
|
|
cfg, model, tokenizer, train_dataset, safe_serialization
|
|
|
|
|
)
|
|
|
|
|
|