Compare commits

..

5 Commits

Author SHA1 Message Date
Dan Saunders
5c0510a876 review comments 2025-03-03 18:44:16 +00:00
Dan Saunders
e1bc18763a combine like functions 2025-02-28 17:47:39 +00:00
Dan Saunders
ed5178cd3d update 2025-02-26 21:03:44 +00:00
Dan Saunders
a3224c7c3c updates 2025-02-26 20:31:54 +00:00
Dan Saunders
c4104fc10c refactor train.py 2025-02-26 19:37:42 +00:00
36 changed files with 228 additions and 1005 deletions

View File

@@ -19,6 +19,9 @@
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
</p>
Axolotl is a tool designed to streamline post-training for various AI models.

View File

@@ -40,7 +40,6 @@ website:
- section: "Deployments"
contents:
- docs/docker.qmd
- docs/multi-gpu.qmd
- docs/multi-node.qmd
- docs/ray-integration.qmd

View File

@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

View File

@@ -154,6 +154,8 @@ datasets:
content: value
# ...
message_property_mappings:
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
user: ["human", "user"]
@@ -161,12 +163,6 @@ datasets:
system: ["system"]
tool: ["tool"]
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
# This does not drop the default system message from chat_template if it exists. If you wish to,
# we recommend using a custom jinja template with the default system message removed or
# adding a system turn with empty content.
drop_system_message:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
@@ -226,8 +222,8 @@ process_reward_model:
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Changes the default system message. Currently only supports chatml.
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
@@ -449,7 +445,7 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -532,8 +528,6 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
sdp_attention:
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:
# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Resume from a specific checkpoint dir
resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
@@ -554,13 +548,6 @@ special_tokens:
# Add extra tokens.
tokens:
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
# Can be checked if they exist in tokenizer.json added_tokens.
added_tokens_overrides: # Dict[int, str]
# 128041: "<|im_start|>"
# 128042: "<|im_end|>"
# FSDP
fsdp:
fsdp_config:

View File

@@ -74,10 +74,6 @@ datasets:
train_on_eos:
```
::: {.callout-tip}
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml

View File

@@ -129,7 +129,6 @@ You can mix and match within each approach or across approaches to train a model
We suggest this approach when you want to bring your own tokenized dataset.
Axolotl expects the dataset to have three keys:
- `input_ids`: from tokenizing formatted prompt
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.

View File

@@ -1,140 +0,0 @@
---
title: "Docker"
format:
html:
toc: true
toc-depth: 4
---
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
#### Image
```
axolotlai/axolotl-base
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
#### Tags format
```bash
main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
```
Tags examples:
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
## Main
The main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.
#### Image
```
axolotlai/axolotl
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
#### Tags format {#sec-main-tags}
```bash
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
main-latest
# nightly build
{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}
# tagged release
{version}
```
:::{.callout-tip}
There may be some extra tags appended to the image, like `-vllm` which installs those packages.
:::
Tags examples:
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu124-2.4.1`
- `0.7.1`
## Cloud
The cloud image is the image that is used to run Axolotl in the cloud. It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.
:::{.callout-tip}
Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.
:::
#### Image
```
axolotlai/axolotl-cloud
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
#### Tags format
This uses the same tags as the [`main` image](#sec-main-tags).
#### Environment variables
- `JUPYTER_DISABLE`: Disable Jupyter lab.
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
- `PUBLIC_KEY`: Add a public key for the SSH service.
- `SSH_KEY`: Add a private key for the SSH service.
#### Volume mounts
:::{.callout-tip}
We recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.
:::
- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.
- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.
## Cloud-no-tmux
This is the same as the [`cloud` image](#sec-cloud) but without tmux.
#### Image
```
axolotlai/axolotl-cloud-term
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)
:::{.callout-note}
The naming may be a bit confusing as it has `-term` appended to the end.
:::
#### Tags format
This uses the same tags as the [`cloud` image](#sec-cloud-tags).

View File

@@ -19,9 +19,7 @@ description: Frequently asked questions
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed**
> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.
> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
**Q: The codes is stuck on saving preprocessed datasets.**
@@ -52,7 +50,3 @@ description: Frequently asked questions
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.

View File

@@ -65,8 +65,6 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
```
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
## Cloud Environments {#sec-cloud}
### Cloud GPU Providers {#sec-cloud-gpu}

View File

@@ -28,17 +28,6 @@ val_set_size: 0.1
eval_steps: 100
```
Bradley-Terry chat templates expect single-turn conversations in the following format:
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### Process Reward Models (PRM)
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
@@ -56,5 +45,3 @@ datasets:
val_set_size: 0.1
eval_steps: 100
```
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.

View File

@@ -3,7 +3,6 @@ title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
@@ -529,7 +528,6 @@ trl:
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
@@ -538,8 +536,6 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### Using local dataset files
```yaml

View File

@@ -62,5 +62,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl @ git+https://github.com/axolotl-ai-cloud/axolotl-contribs-lgpl.git@import-issues-v2
axolotl-contribs-mit==0.0.3
axolotl-contribs-lgpl==0.0.3

View File

@@ -24,5 +24,5 @@ if cce_spec:
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
)

View File

@@ -113,7 +113,7 @@ class ModalCloud(Cloud):
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
]
)
@@ -270,7 +270,6 @@ def _preprocess(config_yaml: str, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
@@ -289,7 +288,6 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _lm_eval(config_yaml: str, volumes=None):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"

View File

@@ -41,12 +41,11 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
del trainer
plugin_manager.post_train_unload(cfg)

View File

@@ -35,7 +35,6 @@ from transformers import (
EarlyStoppingCallback,
TrainerCallback,
)
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
@@ -85,7 +84,6 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
from axolotl.utils.models import ensure_dtype
try:
@@ -551,8 +549,30 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
"alternate_lr_scheduler_type"
@@ -636,114 +656,46 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# Handle custom optimizer
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
# Common optimizer kwargs
optimizer_kwargs = {
"lr": training_arguments_kwargs.get("learning_rate"),
"weight_decay": training_arguments_kwargs.get("weight_decay"),
}
# pylint: disable=duplicate-code
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
# Adam-specific kwargs
adam_kwargs = {}
if training_arguments_kwargs.get(
"adam_beta1"
) and training_arguments_kwargs.get("adam_beta2"):
adam_kwargs["betas"] = (
training_arguments_kwargs.get("adam_beta1"),
training_arguments_kwargs.get("adam_beta2"),
)
if training_arguments_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
MuonOptimizerFactory,
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
if "weight_decay" in training_arguments_kwargs:
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
)
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
trainer_kwargs["optimizers"] = (
Lion(params=self.model.parameters(), **lion_kwargs),
None,
)
else:
# Use transformers' optimizer
training_arguments_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.accelerator_config:
training_arguments_kwargs[
"accelerator_config"

View File

@@ -14,7 +14,6 @@ from typing import Dict, Literal, Optional
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
@@ -23,11 +22,9 @@ from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
@@ -118,17 +115,6 @@ class SchedulerMixin(Trainer):
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
@@ -168,18 +154,47 @@ class SchedulerMixin(Trainer):
return self.lr_scheduler
class OptimizerMixin(Trainer):
class AxolotlTrainer(SchedulerMixin, Trainer):
"""
Mixin class for shared handling of building custom optimizers
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def create_optimizer_grouped_parameters(
self, opt_model, optimizer_kwargs
) -> list[dict]:
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
params: dict = {
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
@@ -266,30 +281,23 @@ class OptimizerMixin(Trainer):
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer
and self.optimizer_cls_and_kwargs is not None
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
@@ -306,47 +314,50 @@ class OptimizerMixin(Trainer):
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
else:
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
)
self.optimizer = optimizer_cls(
optimizer_grouped_parameters, **optimizer_kwargs
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -355,45 +366,6 @@ class OptimizerMixin(Trainer):
return self.optimizer
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:

View File

@@ -9,7 +9,6 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
LOG = logging.getLogger("axolotl")
@@ -32,44 +31,30 @@ class GRPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if not hasattr(cfg, "trl") or not cfg.trl:
return grpo_args_kwargs
trl: TRLConfig = cfg.trl # type: ignore
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_device"] = (
trl.vllm_device if trl.vllm_device else "auto"
)
if trl.vllm_gpu_memory_utilization:
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = trl.vllm_gpu_memory_utilization
if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
if trl.num_generations:
grpo_args_kwargs["num_generations"] = trl.num_generations
if trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
if trl.ref_model_mixup_alpha:
grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha
if trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
return grpo_args_kwargs
@classmethod

View File

@@ -23,8 +23,6 @@ import importlib
import logging
from typing import OrderedDict
import torch
class BasePlugin:
"""
@@ -471,14 +469,3 @@ class PluginManager:
"""
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
"""
Base class for factories to create custom optimizers
"""
def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> "torch.optim.Optimizer":
pass

View File

@@ -4,22 +4,6 @@ Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy o
See https://github.com/apple/ml-cross-entropy
## Requirements
- PyTorch 2.4.0 or higher
## Installation
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
```bash
# if you are in dev environment
python scripts/cutcrossentropy_install.py | sh
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
```
## Usage
```yaml

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
)

View File

@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator
from pydantic import BaseModel
class SpectrumArgs(BaseModel):
@@ -27,20 +27,3 @@ class SpectrumArgs(BaseModel):
spectrum_top_fraction: Optional[float] = 0.5
spectrum_model_name: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_fsdp_use_orig_params(cls, data):
if (
data.get("fsdp")
and data.get("fsdp_config")
and not data["fsdp_config"].get("use_orig_params")
and data.get("plugins")
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
):
# would otherwise raise
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
raise ValueError(
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
)
return data

View File

@@ -461,7 +461,7 @@ def setup_model_and_trainer(
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
"""
Train a model on the given dataset.
@@ -510,4 +510,4 @@ def train(
# Create model card
create_model_card(cfg, trainer)
return model, tokenizer, trainer
return model, tokenizer

View File

@@ -64,17 +64,6 @@ class ChatTemplate(str, Enum):
metharme = "metharme" # pylint: disable=invalid-name
class CustomSupportedOptimizers(str, Enum):
"""Custom supported optimizers"""
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -505,7 +494,17 @@ class HyperparametersConfig(BaseModel):
embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[OptimizerNames, CustomSupportedOptimizers]
Union[
OptimizerNames,
Literal[
"lion_pytorch",
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
],
]
] = OptimizerNames.ADAMW_HF
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None,
@@ -519,7 +518,7 @@ class HyperparametersConfig(BaseModel):
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
Union[SchedulerType, Literal["one_cycle"]]
] = SchedulerType.COSINE
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
@@ -779,9 +778,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[
Union[Literal["unsloth", "offload"], bool]
] = Field(default=False)
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
default=False
)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
@@ -856,7 +855,6 @@ class AxolotlInputConfig(
special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None
added_tokens_overrides: Optional[Dict[int, str]] = None
torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None
@@ -1155,15 +1153,6 @@ class AxolotlInputConfig(
raise ValueError("gradient_checkpointing is not supported for MPT models")
return self
@model_validator(mode="after")
def check_offload_grad_checkpointing(self):
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
LOG.warning(
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
)
self.gradient_checkpointing = "offload"
return self
@model_validator(mode="after")
def check_better_transformers(self):
if self.flash_optimum is True:
@@ -1188,13 +1177,6 @@ class AxolotlInputConfig(
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@model_validator(mode="before")
@classmethod
def check_lr_groups(cls, data):
if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
return data
@model_validator(mode="before")
@classmethod
def check_saves(cls, data):

View File

@@ -1,8 +1,7 @@
"""
GRPO specific configuration args
"""
from typing import Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -12,10 +11,7 @@ class TRLConfig(BaseModel):
Input args for TRL.
"""
beta: Optional[float] = Field(
default=None,
json_schema_extra={"description": "Beta for RL training"},
)
beta: Optional[float] = None
max_completion_length: Optional[int] = Field(
default=None,
json_schema_extra={
@@ -24,68 +20,16 @@ class TRLConfig(BaseModel):
)
# GRPO specific args
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
use_vllm: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
)
vllm_device: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
vllm_gpu_memory_utilization: Optional[float] = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
vllm_dtype: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
use_vllm: Optional[bool] = False
vllm_device: Optional[str] = "auto"
vllm_gpu_memory_utilization: Optional[float] = 0.9
vllm_max_model_len: Optional[int] = None
vllm_dtype: Optional[str] = "auto"
reward_funcs: Optional[list[str]] = Field(
default=None,
json_schema_extra={"description": "List of reward functions to load"},
)
reward_weights: Optional[list[float]] = Field(
default=None,
json_schema_extra={
"description": "Weights for each reward function. Must match the number of reward functions."
},
)
num_generations: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
},
)
log_completions: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
sync_ref_model: Optional[bool] = Field(
default=False,
json_schema_extra={
"description": (
"Whether to sync the reference model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
)
},
)
ref_model_mixup_alpha: Optional[float] = Field(
default=0.9,
json_schema_extra={
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
},
)
ref_model_sync_steps: Optional[int] = Field(
default=64,
json_schema_extra={
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
)
reward_funcs: Optional[List[str]] = None
num_generations: Optional[int] = None
log_completions: Optional[bool] = False
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64

View File

@@ -79,7 +79,7 @@ def is_main_process():
def is_local_main_process():
return PartialState().is_local_main_process
return PartialState().is_main_process
def get_world_size():

View File

@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
)
def hf_grad_checkpoint_offload_wrapper(
def hf_grad_checkpoint_unsloth_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(

View File

@@ -57,14 +57,8 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
barrier,
get_device_count,
get_device_type,
is_local_main_process,
zero_only,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -171,95 +165,7 @@ def load_model_config(cfg):
return model_config
def modify_tokenizer_files(
tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
) -> str:
"""
Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
Args:
tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer
Returns:
Path to the modified tokenizer directory
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
"""
import json
# Create the tokenizer directory in output_dir if it doesn't exist
tokenizer_dir = os.path.join(output_dir, "tokenizer")
os.makedirs(tokenizer_dir, exist_ok=True)
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
# Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
# Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir)
# Get the token IDs and map them to their new values
token_id_mappings = {
int(token_id): new_value for token_id, new_value in token_mappings.items()
}
# 1. Update tokenizer_config.json - added_tokens_decoder
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
if os.path.exists(config_path):
with open(config_path, "r", encoding="utf-8") as f:
config_data = json.load(f)
# Update added_tokens_decoder
if "added_tokens_decoder" in config_data:
for token_id, new_value in token_id_mappings.items():
token_id_str = str(token_id)
if token_id_str in config_data["added_tokens_decoder"]:
config_data["added_tokens_decoder"][token_id_str][
"content"
] = new_value
else:
raise ValueError(
f"Token ID {token_id_str} not found in added_tokens_decoder"
)
# Write the updated config back
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
# 2. Update tokenizer.json - added_tokens
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
if os.path.exists(tokenizer_path):
with open(tokenizer_path, "r", encoding="utf-8") as f:
tokenizer_data = json.load(f)
# Update added_tokens
if "added_tokens" in tokenizer_data:
for token_id, new_value in token_id_mappings.items():
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
if token_entry["id"] == token_id:
tokenizer_data["added_tokens"][i]["content"] = new_value
break
else:
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
raise ValueError(
f"Token ID {token_id} not found in added_tokens"
)
# Write the updated tokenizer data back
with open(tokenizer_path, "w", encoding="utf-8") as f:
json.dump(tokenizer_data, f, indent=2)
barrier()
return tokenizer_dir
def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default
@@ -274,18 +180,8 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
# Set base tokenizer path
tokenizer_path = cfg.tokenizer_config
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
)
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_path,
cfg.tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
@@ -493,8 +389,8 @@ class ModelLoader:
patch_fa_peft_integration()
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.flash_attention:
self.patch_attention()

View File

@@ -6,80 +6,6 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
class RexLR(LRScheduler):
"""
Reflected Exponential (REX) learning rate scheduler.
- Original implementation: https://github.com/IvanVassi/REX_LR
- Original license: Apache 2.0
- Based on: https://arxiv.org/abs/2107.04197
Args:
optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for.
max_lr (float): The maximum learning rate.
min_lr (float): The minimum learning rate.
total_steps (int): The total number of training steps.
num_warmup_steps (int): The number of warmup steps.
last_step (int): The index of last step.
"""
def __init__(
self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0
):
if min_lr > max_lr:
raise ValueError(
f'Value of "min_lr" should be less than value of "max_lr". Got min_lr={min_lr} and max_lr={max_lr}'
)
if num_warmup_steps > total_steps:
raise ValueError(
f"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps})."
)
self.min_lr = min_lr
self.max_lr = max_lr
self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps
self.last_step = last_step - 1
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
# Pass self.last_step as last_epoch to the parent.
super().__init__(optimizer, last_epoch=self.last_step)
@property
def last_step(self):
return self.last_epoch
@last_step.setter
def last_step(self, value):
self.last_epoch = value
def get_lr(self):
# Warmup phase: if defined, increase lr linearly from 0 to max_lr.
if 1 <= self.last_step <= self.num_warmup_steps:
return [
base_lr * self.last_step / self.num_warmup_steps
for base_lr in self.base_lrs
]
# Post-warmup phase: adjust step relative to the end of warmup.
step_after = self.last_step - self.num_warmup_steps
remaining_steps = self.total_steps - self.num_warmup_steps
# Avoid LR spiking
if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0:
return [self.min_lr for _ in self.base_lrs]
mod_iter = step_after % remaining_steps
z = (remaining_steps - mod_iter) / remaining_steps
rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * (
z / (0.1 + 0.9 * z)
)
return [base_lr * rex_factor for base_lr in self.base_lrs]
class InterpolatingLogScheduler(LRScheduler):
"""
A scheduler that interpolates learning rates in a logarithmic fashion

View File

@@ -14,7 +14,7 @@
h1 {
font-family: var(--font-title);
font-weight: 400;
font-size: 5rem;
font-size: 6rem;
line-height: 1.1;
letter-spacing: -0.05em;
font-feature-settings: "ss01" on;

View File

@@ -28,7 +28,7 @@ class TestTrainCommand(BaseCliTest):
config_path.write_text(valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
@@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest):
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,

View File

@@ -69,51 +69,6 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# pylint: disable=redefined-outer-name
def test_qwen2_w_cce(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"plugins": [
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
],
"cut_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
}
)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize(
"attention_type",
[

View File

@@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32

View File

@@ -65,9 +65,8 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW"
@with_temp_dir
@require_torch_2_5_1
@@ -112,57 +111,8 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
@require_torch_2_5_1
def test_muon(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "muon",
"lr_scheduler": "cosine",
"weight_decay": 0.01,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):

View File

@@ -1,71 +0,0 @@
"""
E2E tests for custom schedulers using Llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomSchedulers(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_rex_scheduler(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_hf",
"max_steps": 20,
"lr_scheduler": "rex",
"warmup_steps": 5,
"cosine_min_lr_ratio": 0.05,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -1,7 +1,6 @@
"""
Test cases for the tokenizer loading
"""
import unittest
import pytest
@@ -10,7 +9,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
class TestTokenizers:
class TestTokenizers(unittest.TestCase):
"""
test class for the load_tokenizer fn
"""
@@ -76,48 +75,12 @@ class TestTokenizers:
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
assert len(tokenizer) == 32001
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
self.assertEqual(len(tokenizer), 32001)
# ensure reloading the tokenizer again from cfg results in same vocab length
tokenizer = load_tokenizer(cfg)
assert len(tokenizer) == 32001
def test_added_tokens_overrides(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {
128041: "RANDOM_OVERRIDE_1",
128042: "RANDOM_OVERRIDE_2",
},
"output_dir": temp_dir,
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
128041
]
assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
128042
]
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
"output_dir": temp_dir,
}
)
with pytest.raises(
ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
):
load_tokenizer(cfg)
self.assertEqual(len(tokenizer), 32001)
if __name__ == "__main__":