Compare commits

..

5 Commits

Author SHA1 Message Date
Dan Saunders
5c0510a876 review comments 2025-03-03 18:44:16 +00:00
Dan Saunders
e1bc18763a combine like functions 2025-02-28 17:47:39 +00:00
Dan Saunders
ed5178cd3d update 2025-02-26 21:03:44 +00:00
Dan Saunders
a3224c7c3c updates 2025-02-26 20:31:54 +00:00
Dan Saunders
c4104fc10c refactor train.py 2025-02-26 19:37:42 +00:00
22 changed files with 191 additions and 669 deletions

View File

@@ -19,6 +19,9 @@
<br/> <br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly"> <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests"> <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
</p> </p>
Axolotl is a tool designed to streamline post-training for various AI models. Axolotl is a tool designed to streamline post-training for various AI models.

View File

@@ -40,7 +40,6 @@ website:
- section: "Deployments" - section: "Deployments"
contents: contents:
- docs/docker.qmd
- docs/multi-gpu.qmd - docs/multi-gpu.qmd
- docs/multi-node.qmd - docs/multi-node.qmd
- docs/ray-integration.qmd - docs/ray-integration.qmd

View File

@@ -163,12 +163,6 @@ datasets:
system: ["system"] system: ["system"]
tool: ["tool"] tool: ["tool"]
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
# This does not drop the default system message from chat_template if it exists. If you wish to,
# we recommend using a custom jinja template with the default system message removed or
# adding a system turn with empty content.
drop_system_message:
# IMPORTANT: The following fields determine which parts of the conversation to train on. # IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train # Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd` # See examples at `docs/dataset-formats/conversation.qmd`
@@ -228,8 +222,8 @@ process_reward_model:
chat_template: tokenizer_default chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null. # custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null chat_template_jinja: null
# Changes the default system message. Currently only supports chatml. # Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so # Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path # subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared dataset_prepared_path: data/last_run_prepared
@@ -451,7 +445,7 @@ gradient_checkpointing: false
early_stopping_patience: 3 early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer # Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs: lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf) cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -534,8 +528,6 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
sdp_attention: sdp_attention:
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf # Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention: s2_attention:
# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Resume from a specific checkpoint dir # Resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off. # If resume_from_checkpoint isn't set and you simply want it to start where it left off.

View File

@@ -129,7 +129,6 @@ You can mix and match within each approach or across approaches to train a model
We suggest this approach when you want to bring your own tokenized dataset. We suggest this approach when you want to bring your own tokenized dataset.
Axolotl expects the dataset to have three keys: Axolotl expects the dataset to have three keys:
- `input_ids`: from tokenizing formatted prompt - `input_ids`: from tokenizing formatted prompt
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]` - `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`. - `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.

View File

@@ -1,140 +0,0 @@
---
title: "Docker"
format:
html:
toc: true
toc-depth: 4
---
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
#### Image
```
axolotlai/axolotl-base
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
#### Tags format
```bash
main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
```
Tags examples:
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
## Main
The main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.
#### Image
```
axolotlai/axolotl
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
#### Tags format {#sec-main-tags}
```bash
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
main-latest
# nightly build
{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}
# tagged release
{version}
```
:::{.callout-tip}
There may be some extra tags appended to the image, like `-vllm` which installs those packages.
:::
Tags examples:
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu124-2.4.1`
- `0.7.1`
## Cloud
The cloud image is the image that is used to run Axolotl in the cloud. It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.
:::{.callout-tip}
Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.
:::
#### Image
```
axolotlai/axolotl-cloud
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
#### Tags format
This uses the same tags as the [`main` image](#sec-main-tags).
#### Environment variables
- `JUPYTER_DISABLE`: Disable Jupyter lab.
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
- `PUBLIC_KEY`: Add a public key for the SSH service.
- `SSH_KEY`: Add a private key for the SSH service.
#### Volume mounts
:::{.callout-tip}
We recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.
:::
- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.
- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.
## Cloud-no-tmux
This is the same as the [`cloud` image](#sec-cloud) but without tmux.
#### Image
```
axolotlai/axolotl-cloud-term
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)
:::{.callout-note}
The naming may be a bit confusing as it has `-term` appended to the end.
:::
#### Tags format
This uses the same tags as the [`cloud` image](#sec-cloud-tags).

View File

@@ -19,9 +19,7 @@ description: Frequently asked questions
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'** **Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed** > A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.
**Q: The codes is stuck on saving preprocessed datasets.** **Q: The codes is stuck on saving preprocessed datasets.**

View File

@@ -65,8 +65,6 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
``` ```
::: :::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
## Cloud Environments {#sec-cloud} ## Cloud Environments {#sec-cloud}
### Cloud GPU Providers {#sec-cloud-gpu} ### Cloud GPU Providers {#sec-cloud-gpu}

View File

@@ -63,4 +63,3 @@ torchao==0.7.0
schedulefree==1.3.0 schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.3 axolotl-contribs-lgpl==0.0.3
axolotl-contribs-mit==0.0.3

View File

@@ -41,12 +41,11 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta) model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
del model del model
del tokenizer del tokenizer
del trainer
plugin_manager.post_train_unload(cfg) plugin_manager.post_train_unload(cfg)

View File

@@ -35,7 +35,6 @@ from transformers import (
EarlyStoppingCallback, EarlyStoppingCallback,
TrainerCallback, TrainerCallback,
) )
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import ( from axolotl.core.trainers.base import (
@@ -85,7 +84,6 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq,
) )
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
from axolotl.utils.models import ensure_dtype from axolotl.utils.models import ensure_dtype
try: try:
@@ -551,8 +549,30 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else: else:
training_arguments_kwargs["run_name"] = None training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]: if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine" training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[ training_arguments_kwargs[
"alternate_lr_scheduler_type" "alternate_lr_scheduler_type"
@@ -636,114 +656,46 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model: if self.cfg.reward_model:
training_arguments_kwargs["max_length"] = self.cfg.sequence_len training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# Handle custom optimizer # pylint: disable=duplicate-code
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers] if self.cfg.optimizer in [
if self.cfg.optimizer in custom_supported_optimizers: "optimi_adamw",
# Common optimizer kwargs "ao_adamw_4bit",
optimizer_kwargs = { "ao_adamw_8bit",
"lr": training_arguments_kwargs.get("learning_rate"), "ao_adamw_fp8",
"weight_decay": training_arguments_kwargs.get("weight_decay"), "adopt_adamw",
} ]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
# Adam-specific kwargs if self.cfg.optimizer == "lion_pytorch":
adam_kwargs = {} from lion_pytorch import Lion
if training_arguments_kwargs.get(
"adam_beta1"
) and training_arguments_kwargs.get("adam_beta2"):
adam_kwargs["betas"] = (
training_arguments_kwargs.get("adam_beta1"),
training_arguments_kwargs.get("adam_beta2"),
)
if training_arguments_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon": lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module if "weight_decay" in training_arguments_kwargs:
MuonOptimizerFactory, lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
) )
optimizer_cls = MuonOptimizerFactory trainer_kwargs["optimizers"] = (
optimizer_kwargs.update(adam_kwargs) Lion(params=self.model.parameters(), **lion_kwargs),
elif self.cfg.optimizer == "optimi_adamw": None,
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
) )
elif self.cfg.optimizer == "ao_adamw_8bit": # Set default so transformers doesn't throw
from torchao.prototype.low_bit_optim import AdamW8bit training_arguments_kwargs["optim"] = "adamw_hf"
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
)
else:
# Use transformers' optimizer
training_arguments_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optimizer == "adamw_anyprecision": if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists(): if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path) sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx") importlib.import_module("torchdistx")
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.accelerator_config: if self.cfg.accelerator_config:
training_arguments_kwargs[ training_arguments_kwargs[
"accelerator_config" "accelerator_config"

View File

@@ -14,7 +14,6 @@ from typing import Dict, Literal, Optional
import torch import torch
from datasets import Dataset from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer from transformers import Trainer
@@ -23,11 +22,9 @@ from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr, get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup, get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant, get_cosine_schedule_with_warmup_decay_constant,
@@ -118,17 +115,6 @@ class SchedulerMixin(Trainer):
**extra_lr_kwargs, **extra_lr_kwargs,
**self.args.lr_scheduler_kwargs, **self.args.lr_scheduler_kwargs,
) )
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic: elif use_cosine_quadratic:
if use_cosine_min_lr: if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.") LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
@@ -168,18 +154,47 @@ class SchedulerMixin(Trainer):
return self.lr_scheduler return self.lr_scheduler
class OptimizerMixin(Trainer): class AxolotlTrainer(SchedulerMixin, Trainer):
""" """
Mixin class for shared handling of building custom optimizers Extend the base Trainer for axolotl helpers
""" """
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def create_optimizer_grouped_parameters( def __init__(
self, opt_model, optimizer_kwargs self,
) -> list[dict]: *_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model) decay_parameters = self.get_decay_parameter_names(opt_model)
params: dict = { params = {
"to_weight_decay": {}, # LayerNorm and bias "to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens, "embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {}, "no_weight_decay": {},
@@ -266,30 +281,23 @@ class OptimizerMixin(Trainer):
and self.args.embedding_lr_scale is None and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None and self.args.embedding_lr is None
and self.args.lr_groups is None and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
): ):
return super().create_optimizer() return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
if ( optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
not self.optimizer self.args,
and self.optimizer_cls_and_kwargs is not None opt_model,
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
) )
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters( optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs opt_model, optimizer_kwargs
) )
@@ -306,47 +314,50 @@ class OptimizerMixin(Trainer):
loraplus_lr_embedding=loraplus_lr_embedding, loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs, **optimizer_kwargs,
) )
else: elif (
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` self.args.embedding_lr_scale is not None
# e.g. for GaLore optimizer. or self.args.embedding_lr is not None
if "params" in optimizer_kwargs: or self.args.lr_groups is not None
optimizer_grouped_parameters = optimizer_kwargs.pop("params") ):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
) )
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
self.optimizer = optimizer_cls( self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_grouped_parameters, **optimizer_kwargs AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
) )
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
) )
LOG.info(f"skipped {module}: {skipped/2**20}M params") elif self.args.alternate_optimizer == "ao_adamw_4bit":
manager.register_module_override( from torchao.prototype.low_bit_optim import AdamW4bit
module, "weight", {"optim_bits": 32}
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
) )
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -355,45 +366,6 @@ class OptimizerMixin(Trainer):
return self.optimizer return self.optimizer
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining: if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches: if self.args.multipack_real_batches:

View File

@@ -9,7 +9,6 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -32,44 +31,30 @@ class GRPOStrategy:
@classmethod @classmethod
def set_training_args_kwargs(cls, cfg): def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {} grpo_args_kwargs = {}
if cfg.trl and cfg.trl.use_vllm:
if not hasattr(cfg, "trl") or not cfg.trl: grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
return grpo_args_kwargs if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
trl: TRLConfig = cfg.trl # type: ignore else:
grpo_args_kwargs["vllm_device"] = "auto"
if trl.use_vllm: if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_device"] = (
trl.vllm_device if trl.vllm_device else "auto"
)
if trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[ grpo_args_kwargs[
"vllm_gpu_memory_utilization" "vllm_gpu_memory_utilization"
] = trl.vllm_gpu_memory_utilization ] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
if trl.vllm_max_model_len: grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if trl.num_generations: if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["num_generations"] = trl.num_generations grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
if trl.sync_ref_model: grpo_args_kwargs[
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model "ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if trl.ref_model_mixup_alpha: if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
if trl.ref_model_sync_steps: grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
return grpo_args_kwargs return grpo_args_kwargs
@classmethod @classmethod

View File

@@ -23,8 +23,6 @@ import importlib
import logging import logging
from typing import OrderedDict from typing import OrderedDict
import torch
class BasePlugin: class BasePlugin:
""" """
@@ -471,14 +469,3 @@ class PluginManager:
""" """
for plugin in self.plugins.values(): for plugin in self.plugins.values():
plugin.post_train_unload(cfg) plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
"""
Base class for factories to create custom optimizers
"""
def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> "torch.optim.Optimizer":
pass

View File

@@ -4,22 +4,6 @@ Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy o
See https://github.com/apple/ml-cross-entropy See https://github.com/apple/ml-cross-entropy
## Requirements
- PyTorch 2.4.0 or higher
## Installation
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
```bash
# if you are in dev environment
python scripts/cutcrossentropy_install.py | sh
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
```
## Usage ## Usage
```yaml ```yaml

View File

@@ -461,7 +461,7 @@ def setup_model_and_trainer(
def train( def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]: ) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
""" """
Train a model on the given dataset. Train a model on the given dataset.
@@ -510,4 +510,4 @@ def train(
# Create model card # Create model card
create_model_card(cfg, trainer) create_model_card(cfg, trainer)
return model, tokenizer, trainer return model, tokenizer

View File

@@ -64,18 +64,6 @@ class ChatTemplate(str, Enum):
metharme = "metharme" # pylint: disable=invalid-name metharme = "metharme" # pylint: disable=invalid-name
class CustomSupportedOptimizers(str, Enum):
"""Custom supported optimizers"""
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
lion_pytorch = "lion_pytorch" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel): class DeprecatedParameters(BaseModel):
"""configurations that are deprecated""" """configurations that are deprecated"""
@@ -506,7 +494,17 @@ class HyperparametersConfig(BaseModel):
embedding_lr_scale: Optional[float] = None embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0 weight_decay: Optional[float] = 0.0
optimizer: Optional[ optimizer: Optional[
Union[OptimizerNames, CustomSupportedOptimizers] Union[
OptimizerNames,
Literal[
"lion_pytorch",
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
],
]
] = OptimizerNames.ADAMW_HF ] = OptimizerNames.ADAMW_HF
optim_args: Optional[Union[str, Dict[str, Any]]] = Field( optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, default=None,
@@ -520,7 +518,7 @@ class HyperparametersConfig(BaseModel):
) )
torchdistx_path: Optional[str] = None torchdistx_path: Optional[str] = None
lr_scheduler: Optional[ lr_scheduler: Optional[
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]] Union[SchedulerType, Literal["one_cycle"]]
] = SchedulerType.COSINE ] = SchedulerType.COSINE
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None lr_quadratic_warmup: Optional[bool] = None
@@ -1179,13 +1177,6 @@ class AxolotlInputConfig(
LOG.warning("adamw hyperparameters found, but no adamw optimizer set") LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self return self
@model_validator(mode="before")
@classmethod
def check_lr_groups(cls, data):
if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_saves(cls, data): def check_saves(cls, data):

View File

@@ -27,7 +27,6 @@ class TRLConfig(BaseModel):
vllm_dtype: Optional[str] = "auto" vllm_dtype: Optional[str] = "auto"
reward_funcs: Optional[List[str]] = None reward_funcs: Optional[List[str]] = None
reward_weights: Optional[List[float]] = None
num_generations: Optional[int] = None num_generations: Optional[int] = None
log_completions: Optional[bool] = False log_completions: Optional[bool] = False

View File

@@ -6,80 +6,6 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler from torch.optim.lr_scheduler import LambdaLR, LRScheduler
class RexLR(LRScheduler):
"""
Reflected Exponential (REX) learning rate scheduler.
- Original implementation: https://github.com/IvanVassi/REX_LR
- Original license: Apache 2.0
- Based on: https://arxiv.org/abs/2107.04197
Args:
optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for.
max_lr (float): The maximum learning rate.
min_lr (float): The minimum learning rate.
total_steps (int): The total number of training steps.
num_warmup_steps (int): The number of warmup steps.
last_step (int): The index of last step.
"""
def __init__(
self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0
):
if min_lr > max_lr:
raise ValueError(
f'Value of "min_lr" should be less than value of "max_lr". Got min_lr={min_lr} and max_lr={max_lr}'
)
if num_warmup_steps > total_steps:
raise ValueError(
f"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps})."
)
self.min_lr = min_lr
self.max_lr = max_lr
self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps
self.last_step = last_step - 1
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
# Pass self.last_step as last_epoch to the parent.
super().__init__(optimizer, last_epoch=self.last_step)
@property
def last_step(self):
return self.last_epoch
@last_step.setter
def last_step(self, value):
self.last_epoch = value
def get_lr(self):
# Warmup phase: if defined, increase lr linearly from 0 to max_lr.
if 1 <= self.last_step <= self.num_warmup_steps:
return [
base_lr * self.last_step / self.num_warmup_steps
for base_lr in self.base_lrs
]
# Post-warmup phase: adjust step relative to the end of warmup.
step_after = self.last_step - self.num_warmup_steps
remaining_steps = self.total_steps - self.num_warmup_steps
# Avoid LR spiking
if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0:
return [self.min_lr for _ in self.base_lrs]
mod_iter = step_after % remaining_steps
z = (remaining_steps - mod_iter) / remaining_steps
rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * (
z / (0.1 + 0.9 * z)
)
return [base_lr * rex_factor for base_lr in self.base_lrs]
class InterpolatingLogScheduler(LRScheduler): class InterpolatingLogScheduler(LRScheduler):
""" """
A scheduler that interpolates learning rates in a logarithmic fashion A scheduler that interpolates learning rates in a logarithmic fashion

View File

@@ -28,7 +28,7 @@ class TestTrainCommand(BaseCliTest):
config_path.write_text(valid_test_config) config_path.write_text(valid_test_config)
with patch("axolotl.cli.train.train") as mock_train: with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock()) mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke( result = cli_runner.invoke(
cli, cli,
@@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest):
config_path = self._test_cli_overrides(tmp_path, valid_test_config) config_path = self._test_cli_overrides(tmp_path, valid_test_config)
with patch("axolotl.cli.train.train") as mock_train: with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock()) mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke( result = cli_runner.invoke(
cli, cli,

View File

@@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32
@@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32
@@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32

View File

@@ -65,9 +65,8 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW"
@with_temp_dir @with_temp_dir
@require_torch_2_5_1 @require_torch_2_5_1
@@ -112,57 +111,8 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
@require_torch_2_5_1
def test_muon(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "muon",
"lr_scheduler": "cosine",
"weight_decay": 0.01,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir @with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir): def test_fft_schedule_free_adamw(self, temp_dir):

View File

@@ -1,71 +0,0 @@
"""
E2E tests for custom schedulers using Llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomSchedulers(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_rex_scheduler(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_hf",
"max_steps": 20,
"lr_scheduler": "rex",
"warmup_steps": 5,
"cosine_min_lr_ratio": 0.05,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)