Compare commits

...

3 Commits

Author SHA1 Message Date
Dan Saunders
103edc7211 refactor build() into smaller fns 2025-05-12 20:36:52 +00:00
Wing Lian
c7b6790614 Various fixes for CI, save_only_model for RL, prevent packing multiprocessing deadlocks (#2661)
* lean mistral ft tests, remove e2e torch 2.4.1 test

* make sure to pass save_only_model for RL

* more tests to make ci leaner, add cleanup to modal ci

* fix module for import in e2e tests

* use mp spawn to prevent deadlocks with packing

* make sure cleanup shell script is executable when cloned out
2025-05-12 10:51:18 -04:00
Dan Saunders
47e0e71bc8 don't sort multipack sampler (#2657)
* don't sort multipack sampler

* increased packing efficiency increases loss

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-09 20:28:58 -04:00
24 changed files with 1545 additions and 1361 deletions

View File

@@ -335,12 +335,6 @@ jobs:
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: llmcompressor
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -377,3 +371,43 @@ jobs:
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
docker-e2e-cleanup:
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [docker-e2e-tests]
strategy:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.cleanup

0
cicd/__init__.py Normal file
View File

View File

@@ -18,7 +18,7 @@ pytest -v --durations=10 \
--cov-append
# Run patched tests excluding lora kernels with coverage append
pytest -v --durations=10 \
pytest --full-trace -vvv --durations=10 \
--ignore=tests/e2e/patched/lora_kernels \
/workspace/axolotl/tests/e2e/patched \
--cov=axolotl \

19
cicd/cleanup.py Normal file
View File

@@ -0,0 +1,19 @@
"""Modal app to run axolotl GPU cleanup"""
from .single_gpu import VOLUME_CONFIG, app, cicd_image, run_cmd
@app.function(
image=cicd_image,
timeout=60 * 60,
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,
)
def cleanup():
run_cmd("./cicd/cleanup.sh", "/workspace/axolotl")
@app.local_entrypoint()
def main():
cleanup.remote()

6
cicd/cleanup.sh Executable file
View File

@@ -0,0 +1,6 @@
#!/bin/bash
set -e
# cleanup old cache files for datasets processing and intermediate mappings
find /workspace/data/huggingface-cache/hub/datasets -name "cache-*" -type f -mtime +1 -exec rm {} \;
find /workspace/data/huggingface-cache/hub/datasets -name "*.lock" -type f -mtime +1 -exec rm {} \;

View File

@@ -1,69 +1,6 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
@app.function(

66
cicd/single_gpu.py Normal file
View File

@@ -0,0 +1,66 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Init for axolotl.core.trainers.builders"""
# pylint: disable=unused-import
# flake8: noqa
from .causal import HFCausalTrainerBuilder
from .rl import HFRLTrainerBuilder

View File

@@ -0,0 +1,331 @@
"""Base class trainer / training args builder implementation"""
import abc
from typing import Any
from torch import Type
from transformers import TrainerCallback
from transformers.training_args import TrainingArguments
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import GCCallback, SaveAxolotlConfigtoWandBCallback
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
PLUGIN_MANAGER = PluginManager.get_instance()
class TrainerBuilderBase(abc.ABC):
"""Base class for trainer builder."""
_train_dataset = None
_eval_dataset = None
_model_ref = None
_peft_config = None
def __init__(self, cfg, model, tokenizer, processor=None):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
self.processor = processor
# If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
patch_trainer_get_lr()
@property
def model_ref(self):
return self._model_ref
@model_ref.setter
def model_ref(self, model):
self._model_ref = model
@property
def train_dataset(self):
return self._train_dataset
@train_dataset.setter
def train_dataset(self, dataset):
self._train_dataset = dataset
@property
def eval_dataset(self):
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, dataset):
self._eval_dataset = dataset
@property
def peft_config(self):
return self._peft_config
@peft_config.setter
def peft_config(self, peft_config):
self._peft_config = peft_config
@abc.abstractmethod
def build(self, total_num_steps):
pass
def get_common_training_args_kwargs(
self, total_num_steps: int | None = None
) -> dict[str, Any]:
"""Get common training arguments kwargs used across different trainer types."""
training_args_kwargs = {}
# Common parameters
for arg in [
"adam_beta1",
"adam_beta2",
"adam_epsilon",
"max_grad_norm",
"dataloader_num_workers",
"dataloader_pin_memory",
"dataloader_prefetch_factor",
"dataloader_drop_last",
"remove_unused_columns",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
# Add Hub integration arguments if needed
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
training_args_kwargs["hub_private_repo"] = True
training_args_kwargs["hub_always_push"] = True
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
# BF16/FP16 settings
if hasattr(self.cfg, "bf16") and self.cfg.bf16:
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
training_args_kwargs["bf16"] = self.cfg.bf16
elif hasattr(self.cfg, "bfloat16") and self.cfg.bfloat16:
training_args_kwargs["bf16"] = True
if hasattr(self.cfg, "fp16"):
training_args_kwargs["fp16"] = (
getattr(self.cfg, "fp16", False)
and not getattr(self.cfg, "bf16", False)
) or False
# Set save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
# Handle safetensors
if self.cfg.save_safetensors is not None:
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
# Handle gradient checkpointing
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
# Common optimizer and LR scheduler settings
training_args_kwargs["optim"] = self.cfg.optimizer
if hasattr(self.cfg, "lr_scheduler") and self.cfg.lr_scheduler:
training_args_kwargs["lr_scheduler_type"] = self.cfg.lr_scheduler
else:
training_args_kwargs["lr_scheduler_type"] = "cosine"
if hasattr(self.cfg, "lr_scheduler_kwargs") and self.cfg.lr_scheduler_kwargs:
training_args_kwargs["lr_scheduler_kwargs"] = self.cfg.lr_scheduler_kwargs
else:
training_args_kwargs["lr_scheduler_kwargs"] = {}
# LoRA+ specific settings
if hasattr(self.cfg, "loraplus_lr_ratio"):
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
if hasattr(self.cfg, "loraplus_lr_embedding"):
training_args_kwargs["loraplus_lr_embedding"] = (
self.cfg.loraplus_lr_embedding
)
# Reporting tools
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.wandb_name:
training_args_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
if report_to:
training_args_kwargs["report_to"] = report_to
# Basic training settings
if hasattr(self.cfg, "sequence_len"):
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["save_only_model"] = getattr(
self.cfg, "save_only_model", False
)
training_args_kwargs["save_total_limit"] = getattr(
self.cfg, "save_total_limit", 5
)
# Compute warmup steps
if hasattr(self.cfg, "warmup_steps") and self.cfg.warmup_steps is not None:
training_args_kwargs["warmup_steps"] = self.cfg.warmup_steps
elif (
total_num_steps
and hasattr(self.cfg, "warmup_ratio")
and self.cfg.warmup_ratio is not None
):
training_args_kwargs["warmup_steps"] = max(
int(self.cfg.warmup_ratio * total_num_steps), 0
)
elif total_num_steps:
training_args_kwargs["warmup_steps"] = min(int(0.03 * total_num_steps), 100)
return training_args_kwargs
def create_training_args(
self,
args_cls: Type[TrainingArguments],
total_num_steps: int | None = None,
**additional_kwargs,
) -> TrainingArguments:
"""Create training arguments with common logic."""
# Get common trainings args and update with trainer-specific args
training_args_kwargs = self.get_common_training_args_kwargs(total_num_steps)
training_args_kwargs.update(additional_kwargs)
# Create training args with pre- and post-creation hooks
training_args_kwargs = self.hook_pre_create_training_args(training_args_kwargs)
training_args = args_cls(**training_args_kwargs)
training_args = self.hook_post_create_training_args(training_args)
# Unset run_name so wandb sets up experiment names properly
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = None
return training_args
def create_trainer(
self, trainer_cls, training_args, trainer_args=None, trainer_kwargs=None
):
"""Create trainer with common logic."""
if trainer_args is None:
trainer_args = []
if trainer_kwargs is None:
trainer_kwargs = {}
# Create trainer with pre- and post- creation hooks
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
*trainer_args,
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
# Add post-creation callbacks
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
return trainer
def get_callbacks(self) -> list[TrainerCallback]:
callbacks = []
callbacks.extend(
PLUGIN_MANAGER.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
"""Callbacks added after the trainer is created, usually because these need
access to the trainer.
"""
callbacks = []
if self.cfg.plugins:
callbacks.extend(
[
cb
for cb in PLUGIN_MANAGER.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
def hook_post_create_training_args(self, training_arguments):
# TODO
return training_arguments
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
# TODO
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
# TODO
return trainer

View File

@@ -0,0 +1,619 @@
"""Causal trainer / training args builder implementation"""
import importlib
import inspect
import logging
import math
import os
import sys
from pathlib import Path
from typing import Type
import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
)
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.core.trainers.builders.base import TrainerBuilderBase
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
from axolotl.core.trainers.relora import ReLoRATrainer
from axolotl.core.trainers.trl import AxolotlPRMTrainer, AxolotlRewardTrainer
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators.batching import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mamba import MambaDataCollator
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""Build the HuggingFace training args / trainer for causal models and reward
modeling using TRL.
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(GPUStatsCallback(self.cfg))
callbacks.append(EvalFirstStepCallback())
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback())
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
)
callbacks.append(LogPredictionCallback(self.cfg))
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
self.cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
def _get_trainer_cls(self):
if self.cfg.plugins:
trainer_cls = PLUGIN_MANAGER.get_trainer_cls(self.cfg)
if trainer_cls:
return trainer_cls
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
return AxolotlTrainer
def build(self, total_num_steps):
"""Build and return a causal trainer instance using the refactored base class."""
# Get trainer class
trainer_cls = self._get_trainer_cls()
# Prepare training arguments
training_args = self._prepare_training_args(total_num_steps)
# Prepare data collators
data_collator_kwargs = self._prepare_data_collator_kwargs()
# Prepare trainer kwargs
trainer_kwargs = self._prepare_trainer_kwargs(
trainer_cls=trainer_cls,
data_collator_kwargs=data_collator_kwargs,
training_args=training_args,
)
# Create the trainer
trainer = self.create_trainer(
trainer_cls=trainer_cls,
training_args=training_args,
trainer_kwargs={
"model": self.model,
"data_collator": self.build_collator(
training_args, **data_collator_kwargs
),
**trainer_kwargs,
},
)
# Handle DeepSpeed config for sample packing if needed
if self.cfg.deepspeed and self.cfg.sample_packing:
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
"train_micro_batch_size_per_gpu"
] = self.cfg.micro_batch_size
return trainer
def _prepare_training_args(self, total_num_steps):
"""Prepare and return training arguments."""
# Base training arguments
training_args_kwargs = self._get_base_training_args()
# Add feature configurations
self._add_feature_configs(training_args_kwargs)
# Handle optimizer configuration
self._configure_optimizer(training_args_kwargs)
# Create training args using the base class method
training_args_cls = self._get_training_args_cls()
return self.create_training_args(
args_cls=training_args_cls,
total_num_steps=total_num_steps,
**training_args_kwargs,
)
def _get_base_training_args(self):
"""Return the base training arguments."""
return {
"max_steps": self.cfg.max_steps if self.cfg.max_steps else -1,
"max_seq_length": self.cfg.sequence_len,
"per_device_train_batch_size": self.cfg.micro_batch_size,
"gradient_accumulation_steps": self.cfg.gradient_accumulation_steps,
"eval_accumulation_steps": self.cfg.gradient_accumulation_steps,
"num_train_epochs": self.cfg.num_epochs,
"learning_rate": self.cfg.learning_rate,
"output_dir": self.cfg.output_dir,
"weight_decay": (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
),
"model_type": self.cfg.model_config_type,
"pretraining": bool(self.cfg.pretraining_dataset),
"sequence_parallel_degree": self.cfg.sequence_parallel_degree,
"ring_attn_func": self.cfg.ring_attn_func,
"embedding_lr": self.cfg.embedding_lr,
"embedding_lr_scale": self.cfg.embedding_lr_scale,
"loraplus_lr_ratio": self.cfg.loraplus_lr_ratio,
"loraplus_lr_embedding": self.cfg.loraplus_lr_embedding,
"lr_groups": self.cfg.lr_groups,
}
def _add_feature_configs(self, training_args_kwargs):
"""Add various feature configurations."""
# Sample packing configurations
self._add_sample_packing_configs(training_args_kwargs)
# Batch size configurations
if self.cfg.eval_batch_size:
training_args_kwargs["per_device_eval_batch_size"] = (
self.cfg.eval_batch_size
)
if self.cfg.auto_find_batch_size is not None:
training_args_kwargs["auto_find_batch_size"] = self.cfg.auto_find_batch_size
# Advanced training techniques (ReLoRA & Lisa)
self._add_advanced_training_configs(training_args_kwargs)
# Model-specific configurations
self._add_model_specific_configs(training_args_kwargs)
def _add_sample_packing_configs(self, training_args_kwargs):
"""Add sample packing configurations if applicable."""
if hasattr(self.cfg, "sample_packing") and self.cfg.sample_packing:
training_args_kwargs.update(
{
"sample_packing": bool(self.cfg.sample_packing),
"multipack_real_batches": not self.cfg.flash_attention
or self.cfg.multipack_real_batches,
"eval_sample_packing": bool(self.cfg.eval_sample_packing),
}
)
if self.cfg.sample_packing_bin_size is not None:
training_args_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size
)
if self.cfg.sample_packing_group_size is not None:
training_args_kwargs["sample_packing_group_size"] = (
self.cfg.sample_packing_group_size
)
if self.cfg.sample_packing_eff_est:
training_args_kwargs["sample_packing_efficiency"] = (
self.cfg.sample_packing_eff_est
)
def _add_advanced_training_configs(self, training_args_kwargs):
"""Add advanced training techniques configurations (ReLoRA & Lisa)."""
# ReLoRA configurations
if self.cfg.relora_steps:
training_args_kwargs.update(
{
"relora_steps": self.cfg.relora_steps,
"relora_warmup_steps": self.cfg.relora_warmup_steps,
}
)
if self.cfg.relora_anneal_steps:
training_args_kwargs["relora_anneal_steps"] = (
self.cfg.relora_anneal_steps
)
if self.cfg.relora_prune_ratio:
training_args_kwargs["relora_prune_ratio"] = self.cfg.relora_prune_ratio
# Lisa configurations
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_args_kwargs.update(
{
"lisa_n_layers": self.cfg.lisa_n_layers,
"lisa_step_interval": self.cfg.lisa_step_interval,
"lisa_layers_attribute": self.cfg.lisa_layers_attribute,
}
)
def _add_model_specific_configs(self, training_args_kwargs):
"""Add model-specific configurations."""
# Chat template
if self.cfg.chat_template:
training_args_kwargs["chat_template"] = get_chat_template_from_config(
cfg=self.cfg,
tokenizer=self.tokenizer,
)
# NEFTune
if self.cfg.neftune_noise_alpha is not None:
training_args_kwargs["neftune_noise_alpha"] = self.cfg.neftune_noise_alpha
# Knowledge distillation configurations
if self.cfg.kd_ce_alpha is not None:
training_args_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
training_args_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_temperature is not None:
training_args_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_args_kwargs["kd_zscore_base_temp"] = self.cfg.kd_zscore_base_temp
if self.cfg.kd_top_k_before_softmax is not None:
training_args_kwargs["kd_top_k_before_softmax"] = (
self.cfg.kd_top_k_before_softmax
)
# Image configurations
if self.cfg.image_size:
training_args_kwargs["image_size"] = self.cfg.image_size
if self.cfg.image_resize_algorithm:
training_args_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
# Accelerator configuration
if self.cfg.accelerator_config:
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
def _configure_optimizer(self, training_args_kwargs):
"""Configure optimizer settings."""
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
# Use custom optimizer implementation
self._configure_custom_optimizer(training_args_kwargs)
else:
# Use transformers' optimizer
training_args_kwargs["optim"] = self.cfg.optimizer
self._add_optimizer_args(training_args_kwargs)
# Handle optimizer targeting specific modules
if self.cfg.optim_target_modules:
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
# Special case for anyprecision optimizer
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
def _configure_custom_optimizer(self, training_args_kwargs):
"""Configure custom optimizer settings."""
# Common optimizer kwargs
optimizer_kwargs = {
"lr": training_args_kwargs.get("learning_rate"),
"weight_decay": training_args_kwargs.get("weight_decay"),
}
# Add Adam-specific kwargs if available
adam_kwargs = self._get_adam_kwargs(training_args_kwargs)
# Get optimizer class and update kwargs based on optimizer type
optimizer_cls = self._get_optimizer_class(
training_args_kwargs, optimizer_kwargs, adam_kwargs
)
# Add any additional optimizer args from config
self._update_optimizer_kwargs_from_config(optimizer_kwargs)
training_args_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
)
def _get_adam_kwargs(self, training_args_kwargs):
"""Get Adam-specific kwargs if available."""
adam_kwargs = {}
if training_args_kwargs.get("adam_beta1") and training_args_kwargs.get(
"adam_beta2"
):
adam_kwargs["betas"] = (
training_args_kwargs.get("adam_beta1"),
training_args_kwargs.get("adam_beta2"),
)
if training_args_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
return adam_kwargs
def _get_optimizer_class(self, training_args_kwargs, optimizer_kwargs, adam_kwargs):
"""Get optimizer class based on configuration."""
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import MuonOptimizerFactory # pylint: disable=no-name-in-module
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "came_pytorch":
from came_pytorch import CAME
optimizer_cls = CAME
beta1 = training_args_kwargs.get("adam_beta1", 0.9)
beta2 = training_args_kwargs.get("adam_beta2", 0.999)
beta3 = training_args_kwargs.get("adam_beta2", 0.9999)
eps1 = training_args_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_args_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
else:
# Default case or unsupported optimizer
optimizer_cls = None
return optimizer_cls
def _update_optimizer_kwargs_from_config(self, optimizer_kwargs):
"""Update optimizer kwargs from config."""
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
def _add_optimizer_args(self, training_args_kwargs):
"""Add optimizer arguments if available."""
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_args_kwargs["optim_args"] = optim_args
def _get_training_args_cls(self):
"""Get the appropriate training arguments class."""
if self.cfg.reward_model:
return AxolotlRewardConfig
if self.cfg.process_reward_model:
return AxolotlPRMConfig
return AxolotlTrainingArguments
def _prepare_data_collator_kwargs(self):
"""Prepare data collator kwargs."""
data_collator_kwargs = {"padding": True} # True/"longest" is the default
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
self.cfg.sequence_len / 64
)
else:
data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len
return data_collator_kwargs
def _prepare_trainer_kwargs(self, trainer_cls, data_collator_kwargs, training_args):
"""Prepare trainer kwargs."""
trainer_kwargs = {}
# Handle special data collators for evaluation
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["eval_data_collator"] = eval_data_collator
# Add bench data collator if needed
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
# Add tokenizer or processing class
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
# Add dataset tags if available
if (
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
and self.cfg.datasets is not None
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
return trainer_kwargs
def build_collator(
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if (
self.cfg.pretraining_sample_concatenation is False
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)
use_batch_sampler_collator = False
if is_eval is False and training_args.sample_packing:
use_batch_sampler_collator = True
if is_eval and training_args.eval_sample_packing:
use_batch_sampler_collator = True
collator: Type[
V2BatchSamplerDataCollatorForSeq2Seq
| BatchSamplerDataCollatorForSeq2Seq
| DataCollatorForSeq2Seq
| DataCollatorWithFlattening
| RewardDataCollatorWithPadding
]
collator_args = [self.tokenizer]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.flex_attention:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
):
collator = V2BatchSamplerDataCollatorForSeq2Seq
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
kwargs["processing_strategy"] = get_processing_strategy(
self.processor,
training_args.chat_template,
self.cfg.chat_template,
image_size=training_args.image_size,
image_resize_algorithm=training_args.image_resize_algorithm,
)
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import (
DataCollatorForKD,
KDBatchSamplerDataCollatorForSeq2Seq,
)
if self.cfg.sample_packing:
collator = KDBatchSamplerDataCollatorForSeq2Seq
else:
collator = DataCollatorForKD
else:
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
return collator(
*collator_args,
**kwargs,
)

View File

@@ -0,0 +1,367 @@
"""RL trainer / training args builder implementation"""
import inspect
from pathlib import Path
from axolotl.core.trainers.builders.base import TrainerBuilderBase
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
)
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
)
from axolotl.utils.models import ensure_dtype
class HFRLTrainerBuilder(TrainerBuilderBase):
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build_training_arguments(self, total_num_steps):
training_args_kwargs = {}
for arg in [
"adam_beta1",
"adam_beta2",
"adam_epsilon",
"dataloader_num_workers",
"dataloader_pin_memory",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
training_args_kwargs["hub_private_repo"] = True
training_args_kwargs["hub_always_push"] = True
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors is not None:
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.eval_dataset:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
else:
training_args_kwargs["eval_strategy"] = "no"
if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
training_args_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_args_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
if self.cfg.remove_unused_columns is not None:
training_args_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
else:
training_args_kwargs["remove_unused_columns"] = False
if self.cfg.dataloader_pin_memory is not None:
training_args_kwargs["dataloader_pin_memory"] = (
self.cfg.dataloader_pin_memory
)
if self.cfg.dataloader_num_workers is not None:
training_args_kwargs["dataloader_num_workers"] = (
self.cfg.dataloader_num_workers
)
if self.cfg.dataloader_prefetch_factor is not None:
training_args_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
# set save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
)
training_args_kwargs["undesirable_weight"] = (
self.cfg.kto_undesirable_weight or 1.0
)
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "grpo":
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = (
self.cfg.dpo_use_logits_to_keep
)
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
max_steps = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=max_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
warmup_steps=self.cfg.warmup_steps,
logging_first_step=True,
logging_steps=1,
optim=self.cfg.optimizer,
save_total_limit=self.cfg.save_total_limit or 5,
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args
def build(self, total_num_steps):
"""Build and return an RL trainer instance"""
# Prepare RL-specific training args kwargs
training_args_kwargs = {
"per_device_train_batch_size": self.cfg.micro_batch_size,
"max_steps": self.cfg.max_steps or total_num_steps or -1,
"gradient_accumulation_steps": self.cfg.gradient_accumulation_steps,
"learning_rate": self.cfg.learning_rate,
"warmup_steps": self.cfg.warmup_steps,
"logging_first_step": True,
"logging_steps": 1,
"output_dir": self.cfg.output_dir,
"num_train_epochs": self.cfg.num_epochs,
}
# Handle dataset processes
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
# Handle beta/alpha parameters for different RL algorithms
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter
training_args_kwargs["beta"] = self.cfg.orpo_alpha
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
# Determine training args class and add RL-specific parameters
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
)
training_args_kwargs["undesirable_weight"] = (
self.cfg.kto_undesirable_weight or 1.0
)
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "grpo":
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else: # Default to DPO
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = (
self.cfg.dpo_use_logits_to_keep
)
# Remove any blocklisted arguments
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
# Create training args using the base class method
training_args = self.create_training_args(
args_cls=training_args_cls,
total_num_steps=total_num_steps,
**training_args_kwargs,
)
# Prepare trainer kwargs
trainer_kwargs = {}
if self.cfg.rl == "ipo" and self.cfg.dpo_label_smoothing:
trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
# Determine trainer class and arguments
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_args = [self.model]
trainer_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_args = [self.model]
elif self.cfg.rl in ["kto"]:
trainer_cls = AxolotlKTOTrainer
trainer_args = [self.model]
elif self.cfg.rl in ["simpo"]:
trainer_cls = AxolotlCPOTrainer
trainer_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
# Add tokenizer or processing class
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():
trainer_kwargs["tokenizer"] = self.tokenizer
else:
trainer_kwargs["processing_class"] = self.tokenizer
# Add dataset tags if available
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
# Create the trainer
trainer = self.create_trainer(
trainer_cls=trainer_cls,
training_args=training_args,
trainer_args=trainer_args,
trainer_kwargs=trainer_kwargs,
)
# Handle FSDP specific settings
if self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if (
self.cfg.rl in ["dpo", "ipo"]
and hasattr(trainer, "ref_model")
and trainer.ref_model
):
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
return trainer

View File

@@ -26,7 +26,7 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)

View File

@@ -46,11 +46,11 @@ from axolotl.utils.distributed import (
from axolotl.utils.schemas.config import AxolotlInputConfig
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
from axolotl.core.training_args import AxolotlTrainingArguments
IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks")
LOG = logging.getLogger(__name__)
class EvalFirstStepCallback(

View File

@@ -6,7 +6,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
import logging
import math
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from multiprocessing import cpu_count, get_context
from typing import Iterable, Union
import numba
@@ -78,15 +78,11 @@ def pack_group(
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
# Get sorting indices and sort lengths in descending order
indices = np.argsort(sequence_lengths)[::-1]
sorted_lengths = sequence_lengths[indices]
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
for seq_id, size in enumerate(sorted_lengths):
global_idx = indices[seq_id] + group_offset
for seq_id, size in enumerate(sequence_lengths):
global_idx = seq_id + group_offset
# Try to place sequence in existing bins
add_new_bin = True
@@ -130,6 +126,7 @@ def pack_parallel(
bin_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
mp_start_method: str | None = "spawn",
):
"""
Pack sequences into bins using parallel processing
@@ -141,7 +138,9 @@ def pack_parallel(
bin_size: Maximum number of bins to use
num_processes: Number of parallel processes to use
safe_mode: If True, use a more conservative packing approach
mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver').
'spawn' is often safer with Numba/PyTorch.
Set to None to use system default.
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
@@ -158,9 +157,33 @@ def pack_parallel(
# Process groups in parallel
all_bins = []
with ProcessPoolExecutor(max_workers=num_processes) as executor:
for group_bins in executor.map(_process_group, tasks):
mp_ctx = None
if mp_start_method:
try:
mp_ctx = get_context(mp_start_method)
except ValueError:
LOG.warning(
f"Failed to get multiprocessing context '{mp_start_method}'. "
f"Falling back to default. Available: {get_context().get_all_start_methods()}"
)
mp_ctx = (
None # Fallback to default context if specified one is not available
)
if num_processes == 1:
LOG.debug("Using single process for pack_parallel, running sequentially.")
for task_args in tasks:
group_bins = _process_group(task_args)
all_bins.extend(group_bins)
else:
# Use ProcessPoolExecutor only if num_processes > 1
# Pass mp_context if available
with ProcessPoolExecutor(
max_workers=num_processes, mp_context=mp_ctx
) as executor:
for group_bins in executor.map(_process_group, tasks):
all_bins.extend(group_bins)
return all_bins

View File

@@ -16,7 +16,7 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
@@ -633,8 +633,7 @@ def setup_trainer(
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
Returns:
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
A trainer instance configured based on the provided parameters.
"""
if (
cfg.torch_compile

View File

@@ -1,10 +1,8 @@
"""
unit tests for axolotl.core.trainer_builder
"""
"""Unit tests for axolotl.core.trainers.builders"""
import pytest
from axolotl.core.trainer_builder import HFRLTrainerBuilder
from axolotl.core.trainers.builders import HFRLTrainerBuilder
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -53,9 +51,7 @@ def fixture_model(cfg, tokenizer):
class TestHFRLTrainerBuilder:
"""
TestCase class for DPO trainer builder
"""
"""Test case class for RL trainer builder"""
def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFRLTrainerBuilder(cfg, model, tokenizer)

View File

@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
)

View File

@@ -57,9 +57,9 @@ class Test4dMultipackLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
}
)
@@ -105,9 +105,9 @@ class Test4dMultipackLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
}
)

View File

@@ -57,9 +57,9 @@ class TestMistral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)
@@ -99,9 +99,9 @@ class TestMistral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)

View File

@@ -54,9 +54,9 @@ class TestMixtral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)
@@ -93,9 +93,9 @@ class TestMixtral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)

View File

@@ -56,9 +56,9 @@ class TestPhiMultipack(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"eval_steps": 10,
"save_steps": 10,
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
}
)
@@ -108,9 +108,9 @@ class TestPhiMultipack(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"eval_steps": 10,
"save_steps": 10,
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
}
)

View File

@@ -1,21 +1,21 @@
"""
test module to import various submodules that have historically broken due to dependency issues
"""Test module to import various submodules that have historically broken due to
dependency issues.
"""
import unittest
class TestImports(unittest.TestCase):
"""
Test class to import various submodules that have historically broken due to dependency issues
"""Test class to import various submodules that have historically broken due to
dependency issues.
"""
def test_import_causal_trainer(self):
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
from axolotl.core.trainers.builders import ( # pylint: disable=unused-import # noqa: F401
HFCausalTrainerBuilder,
)
def test_import_rl_trainer(self):
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
from axolotl.core.trainers.builders import ( # pylint: disable=unused-import # noqa: F401
HFRLTrainerBuilder,
)

View File

@@ -106,3 +106,4 @@ class TestBatchedSamplerPacking:
original_idxs = set(range(len(train_dataset)))
assert original_idxs == set(batch_idxs)
assert len(batch_idxs) == len(set(batch_idxs))