Compare commits
25 Commits
train-refa
...
kto_fix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92c217677c | ||
|
|
fbe54be6b8 | ||
|
|
04f6324833 | ||
|
|
f0072f3b9d | ||
|
|
59899b9817 | ||
|
|
4a736986fa | ||
|
|
5d0f110a3b | ||
|
|
83f8698b8a | ||
|
|
60a11a6410 | ||
|
|
46a045e528 | ||
|
|
3b477e08a0 | ||
|
|
16dc6ee68d | ||
|
|
fa7c79b3b9 | ||
|
|
ae66374156 | ||
|
|
5e21b1a9da | ||
|
|
575e5f28ec | ||
|
|
0134093acc | ||
|
|
d4de93a7bb | ||
|
|
c8191394e9 | ||
|
|
f18231c653 | ||
|
|
9ed4f6b3aa | ||
|
|
05dddfc41d | ||
|
|
8e30917440 | ||
|
|
d883b11b6f | ||
|
|
f4910dd2ea |
5
.github/workflows/main.yml
vendored
5
.github/workflows/main.yml
vendored
@@ -88,6 +88,11 @@ jobs:
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
5
.github/workflows/nightlies.yml
vendored
5
.github/workflows/nightlies.yml
vendored
@@ -80,6 +80,11 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
@@ -19,9 +19,6 @@
|
||||
<br/>
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
|
||||
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
||||
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
Axolotl is a tool designed to streamline post-training for various AI models.
|
||||
|
||||
@@ -40,6 +40,7 @@ website:
|
||||
|
||||
- section: "Deployments"
|
||||
contents:
|
||||
- docs/docker.qmd
|
||||
- docs/multi-gpu.qmd
|
||||
- docs/multi-node.qmd
|
||||
- docs/ray-integration.qmd
|
||||
|
||||
@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
|
||||
|
||||
RUN pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
||||
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||
mkdir -p ~/.ssh && \
|
||||
chmod 700 ~/.ssh && \
|
||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||
|
||||
@@ -154,8 +154,6 @@ datasets:
|
||||
content: value
|
||||
# ...
|
||||
|
||||
message_property_mappings:
|
||||
|
||||
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
||||
roles:
|
||||
user: ["human", "user"]
|
||||
@@ -163,6 +161,12 @@ datasets:
|
||||
system: ["system"]
|
||||
tool: ["tool"]
|
||||
|
||||
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
|
||||
# This does not drop the default system message from chat_template if it exists. If you wish to,
|
||||
# we recommend using a custom jinja template with the default system message removed or
|
||||
# adding a system turn with empty content.
|
||||
drop_system_message:
|
||||
|
||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
||||
# See examples at `docs/dataset-formats/conversation.qmd`
|
||||
@@ -222,8 +226,8 @@ process_reward_model:
|
||||
chat_template: tokenizer_default
|
||||
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||
chat_template_jinja: null
|
||||
# Changes the default system message
|
||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
||||
# Changes the default system message. Currently only supports chatml.
|
||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
@@ -445,7 +449,7 @@ gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
|
||||
# Specify a scheduler and kwargs to use with the optimizer
|
||||
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler_kwargs:
|
||||
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
||||
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
|
||||
@@ -528,6 +532,8 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||
sdp_attention:
|
||||
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
||||
s2_attention:
|
||||
# Optional[bool]. Whether to use low_cpu_mem_usage
|
||||
low_cpu_mem_usage:
|
||||
# Resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
@@ -548,6 +554,13 @@ special_tokens:
|
||||
# Add extra tokens.
|
||||
tokens:
|
||||
|
||||
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
||||
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
|
||||
# Can be checked if they exist in tokenizer.json added_tokens.
|
||||
added_tokens_overrides: # Dict[int, str]
|
||||
# 128041: "<|im_start|>"
|
||||
# 128042: "<|im_end|>"
|
||||
|
||||
# FSDP
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
||||
@@ -74,6 +74,10 @@ datasets:
|
||||
train_on_eos:
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
|
||||
:::
|
||||
|
||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -129,6 +129,7 @@ You can mix and match within each approach or across approaches to train a model
|
||||
We suggest this approach when you want to bring your own tokenized dataset.
|
||||
|
||||
Axolotl expects the dataset to have three keys:
|
||||
|
||||
- `input_ids`: from tokenizing formatted prompt
|
||||
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
|
||||
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.
|
||||
|
||||
140
docs/docker.qmd
Normal file
140
docs/docker.qmd
Normal file
@@ -0,0 +1,140 @@
|
||||
---
|
||||
title: "Docker"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||
|
||||
## Base
|
||||
|
||||
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl-base
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
|
||||
|
||||
#### Tags format
|
||||
|
||||
```bash
|
||||
main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
```
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu124-2.6.0`
|
||||
- `main-base-py3.11-cu124-2.5.1`
|
||||
- `main-base-py3.11-cu124-2.4.1`
|
||||
|
||||
## Main
|
||||
|
||||
The main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
||||
|
||||
#### Tags format {#sec-main-tags}
|
||||
|
||||
```bash
|
||||
# on push to main
|
||||
main-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
|
||||
main-latest
|
||||
|
||||
# nightly build
|
||||
{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
# tagged release
|
||||
{version}
|
||||
```
|
||||
|
||||
:::{.callout-tip}
|
||||
|
||||
There may be some extra tags appended to the image, like `-vllm` which installs those packages.
|
||||
|
||||
:::
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-py3.11-cu124-2.6.0`
|
||||
- `main-py3.11-cu124-2.5.1`
|
||||
- `main-py3.11-cu124-2.4.1`
|
||||
- `main-latest`
|
||||
- `main-20250303-py3.11-cu124-2.6.0`
|
||||
- `main-20250303-py3.11-cu124-2.5.1`
|
||||
- `main-20250303-py3.11-cu124-2.4.1`
|
||||
- `0.7.1`
|
||||
|
||||
## Cloud
|
||||
|
||||
The cloud image is the image that is used to run Axolotl in the cloud. It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.
|
||||
|
||||
:::{.callout-tip}
|
||||
|
||||
Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.
|
||||
|
||||
:::
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl-cloud
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
|
||||
|
||||
#### Tags format
|
||||
|
||||
This uses the same tags as the [`main` image](#sec-main-tags).
|
||||
|
||||
#### Environment variables
|
||||
|
||||
- `JUPYTER_DISABLE`: Disable Jupyter lab.
|
||||
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
|
||||
- `PUBLIC_KEY`: Add a public key for the SSH service.
|
||||
- `SSH_KEY`: Add a private key for the SSH service.
|
||||
|
||||
#### Volume mounts
|
||||
|
||||
:::{.callout-tip}
|
||||
|
||||
We recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.
|
||||
|
||||
:::
|
||||
|
||||
- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.
|
||||
- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.
|
||||
|
||||
## Cloud-no-tmux
|
||||
|
||||
This is the same as the [`cloud` image](#sec-cloud) but without tmux.
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl-cloud-term
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)
|
||||
|
||||
:::{.callout-note}
|
||||
|
||||
The naming may be a bit confusing as it has `-term` appended to the end.
|
||||
|
||||
:::
|
||||
|
||||
#### Tags format
|
||||
|
||||
This uses the same tags as the [`cloud` image](#sec-cloud-tags).
|
||||
@@ -19,7 +19,9 @@ description: Frequently asked questions
|
||||
|
||||
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
|
||||
|
||||
> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
||||
**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed**
|
||||
|
||||
> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.
|
||||
|
||||
**Q: The codes is stuck on saving preprocessed datasets.**
|
||||
|
||||
@@ -50,3 +52,7 @@ description: Frequently asked questions
|
||||
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
||||
|
||||
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
||||
|
||||
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
|
||||
|
||||
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
|
||||
|
||||
@@ -65,6 +65,8 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
||||
```
|
||||
:::
|
||||
|
||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||
|
||||
## Cloud Environments {#sec-cloud}
|
||||
|
||||
### Cloud GPU Providers {#sec-cloud-gpu}
|
||||
|
||||
@@ -28,6 +28,17 @@ val_set_size: 0.1
|
||||
eval_steps: 100
|
||||
```
|
||||
|
||||
Bradley-Terry chat templates expect single-turn conversations in the following format:
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"input": "...",
|
||||
"chosen": "...",
|
||||
"rejected": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### Process Reward Models (PRM)
|
||||
|
||||
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
||||
@@ -45,3 +56,5 @@ datasets:
|
||||
val_set_size: 0.1
|
||||
eval_steps: 100
|
||||
```
|
||||
|
||||
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.
|
||||
|
||||
@@ -3,6 +3,7 @@ title: "RLHF (Beta)"
|
||||
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
||||
back-to-top-navigation: true
|
||||
toc: true
|
||||
toc-expand: 2
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
@@ -528,6 +529,7 @@ trl:
|
||||
vllm_gpu_memory_utilization: 0.15
|
||||
num_generations: 4
|
||||
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
||||
reward_weights: [1.0]
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
@@ -536,6 +538,8 @@ datasets:
|
||||
|
||||
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
||||
|
||||
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
||||
|
||||
### Using local dataset files
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -62,4 +62,5 @@ antlr4-python3-runtime==4.13.2
|
||||
torchao==0.7.0
|
||||
schedulefree==1.3.0
|
||||
|
||||
axolotl-contribs-lgpl==0.0.3
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
|
||||
@@ -24,5 +24,5 @@ if cce_spec:
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
||||
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
|
||||
)
|
||||
|
||||
@@ -113,7 +113,7 @@ class ModalCloud(Cloud):
|
||||
[
|
||||
# Random id for cache busting of branch commits
|
||||
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
||||
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
|
||||
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -270,6 +270,7 @@ def _preprocess(config_yaml: str, volumes=None):
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/mounts"
|
||||
@@ -288,6 +289,7 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
|
||||
|
||||
def _lm_eval(config_yaml: str, volumes=None):
|
||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/mounts"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""CLI to run training on a model."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -34,18 +35,20 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
|
||||
if cfg.rl:
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
|
||||
del model
|
||||
del tokenizer
|
||||
del trainer
|
||||
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ class TrainDatasetMeta:
|
||||
"""Dataclass with fields for training and validation datasets and metadata."""
|
||||
|
||||
train_dataset: Dataset
|
||||
eval_dataset: Optional[Dataset] = None
|
||||
total_num_steps: Optional[int] = None
|
||||
eval_dataset: Dataset | None = None
|
||||
total_num_steps: int | None = None
|
||||
|
||||
|
||||
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
||||
|
||||
@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
|
||||
process_or_cpu_count: int = (
|
||||
process_count or os.cpu_count() # type: ignore[assignment]
|
||||
)
|
||||
num_proc = min(64, process_or_cpu_count)
|
||||
num_proc = min(32, process_or_cpu_count)
|
||||
features = data.features.keys()
|
||||
tokenized_data = data.map(
|
||||
map_fn,
|
||||
|
||||
@@ -35,11 +35,11 @@ from transformers import (
|
||||
EarlyStoppingCallback,
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.training_args import OptimizerNames
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
from axolotl.core.trainers.base import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlMambaTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
@@ -50,6 +50,7 @@ from axolotl.core.trainers.base import (
|
||||
from axolotl.core.trainers.dpo import DPOStrategy
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.core.trainers.kto import AxolotlKTOTrainer
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlCPOConfig,
|
||||
AxolotlKTOConfig,
|
||||
@@ -84,6 +85,7 @@ from axolotl.utils.collators import (
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
|
||||
from axolotl.utils.models import ensure_dtype
|
||||
|
||||
try:
|
||||
@@ -91,13 +93,11 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
"""
|
||||
"""Base class for trainer builder."""
|
||||
|
||||
_train_dataset = None
|
||||
_eval_dataset = None
|
||||
@@ -110,9 +110,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
self.tokenizer = tokenizer
|
||||
self.processor = processor
|
||||
|
||||
# in case the model supports tagging, add the axolotl tag.
|
||||
# If the model supports tagging, add the axolotl tag.
|
||||
# This makes sure the tag is correctly pushed even if a user calls
|
||||
# model.push_to_hub instad of trainer.push_to_hub.
|
||||
# model.push_to_hub instead of trainer.push_to_hub.
|
||||
if hasattr(model, "add_model_tags"):
|
||||
model.add_model_tags(["axolotl"])
|
||||
|
||||
@@ -227,8 +227,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Build the HuggingFace training args/trainer for causal models
|
||||
and reward modelling using TRL.
|
||||
Build the HuggingFace training args/trainer for causal models and reward modeling
|
||||
using TRL.
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
@@ -551,30 +551,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||
else:
|
||||
training_arguments_kwargs["run_name"] = None
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
if self.cfg.optim_args:
|
||||
if isinstance(self.cfg.optim_args, dict):
|
||||
optim_args = ",".join(
|
||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
||||
)
|
||||
else:
|
||||
optim_args = self.cfg.optim_args
|
||||
training_arguments_kwargs["optim_args"] = optim_args
|
||||
if self.cfg.optim_target_modules:
|
||||
training_arguments_kwargs[
|
||||
"optim_target_modules"
|
||||
] = self.cfg.optim_target_modules
|
||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||
training_arguments_kwargs[
|
||||
"loraplus_lr_embedding"
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
training_arguments_kwargs[
|
||||
"alternate_lr_scheduler_type"
|
||||
@@ -658,46 +636,114 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.reward_model:
|
||||
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
if self.cfg.optimizer in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]:
|
||||
# Set default so transformers doesn't throw
|
||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
|
||||
# Handle custom optimizer
|
||||
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
|
||||
if self.cfg.optimizer in custom_supported_optimizers:
|
||||
# Common optimizer kwargs
|
||||
optimizer_kwargs = {
|
||||
"lr": training_arguments_kwargs.get("learning_rate"),
|
||||
"weight_decay": training_arguments_kwargs.get("weight_decay"),
|
||||
}
|
||||
|
||||
if self.cfg.optimizer == "lion_pytorch":
|
||||
from lion_pytorch import Lion
|
||||
# Adam-specific kwargs
|
||||
adam_kwargs = {}
|
||||
if training_arguments_kwargs.get(
|
||||
"adam_beta1"
|
||||
) and training_arguments_kwargs.get("adam_beta2"):
|
||||
adam_kwargs["betas"] = (
|
||||
training_arguments_kwargs.get("adam_beta1"),
|
||||
training_arguments_kwargs.get("adam_beta2"),
|
||||
)
|
||||
if training_arguments_kwargs.get("adam_epsilon"):
|
||||
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
|
||||
|
||||
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
|
||||
if "weight_decay" in training_arguments_kwargs:
|
||||
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
|
||||
|
||||
if (
|
||||
"adam_beta1" in training_arguments_kwargs
|
||||
and "adam_beta2" in training_arguments_kwargs
|
||||
):
|
||||
lion_kwargs["betas"] = (
|
||||
training_arguments_kwargs["adam_beta1"],
|
||||
training_arguments_kwargs["adam_beta2"],
|
||||
if self.cfg.optimizer == "muon":
|
||||
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
|
||||
trainer_kwargs["optimizers"] = (
|
||||
Lion(params=self.model.parameters(), **lion_kwargs),
|
||||
None,
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
optimizer_kwargs["foreach"] = False
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||
# TODO remove 20250401
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
optimizer_cls = AdamW4bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
LOG.warning(
|
||||
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||
)
|
||||
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||
|
||||
optimizer_cls = AdamW8bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||
|
||||
optimizer_cls = AdamWFp8
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "adopt_adamw":
|
||||
from axolotl.utils.optimizers.adopt import ADOPT
|
||||
|
||||
optimizer_cls = ADOPT
|
||||
adam_kwargs["decouple"] = True
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
# Parse any additional optimizer args from config
|
||||
if self.cfg.optim_args:
|
||||
if isinstance(self.cfg.optim_args, dict):
|
||||
optimizer_kwargs.update(self.cfg.optim_args)
|
||||
else:
|
||||
# Parse string format "key1=value1,key2=value2"
|
||||
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
|
||||
key, value = mapping.split("=")
|
||||
optimizer_kwargs[key] = value
|
||||
|
||||
trainer_kwargs["optimizer_cls_and_kwargs"] = (
|
||||
optimizer_cls,
|
||||
optimizer_kwargs,
|
||||
)
|
||||
# Set default so transformers doesn't throw
|
||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||
else:
|
||||
# Use transformers' optimizer
|
||||
training_arguments_kwargs["optim"] = self.cfg.optimizer
|
||||
|
||||
# Parse any additional optimizer args from config
|
||||
if self.cfg.optim_args:
|
||||
if isinstance(self.cfg.optim_args, dict):
|
||||
optim_args = ",".join(
|
||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
||||
)
|
||||
else:
|
||||
optim_args = self.cfg.optim_args
|
||||
training_arguments_kwargs["optim_args"] = optim_args
|
||||
|
||||
if self.cfg.optimizer == "adamw_anyprecision":
|
||||
if Path(self.cfg.torchdistx_path).exists():
|
||||
sys.path.append(self.cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
if self.cfg.optim_target_modules:
|
||||
training_arguments_kwargs[
|
||||
"optim_target_modules"
|
||||
] = self.cfg.optim_target_modules
|
||||
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
|
||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||
training_arguments_kwargs[
|
||||
"loraplus_lr_embedding"
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||
|
||||
if self.cfg.accelerator_config:
|
||||
training_arguments_kwargs[
|
||||
"accelerator_config"
|
||||
@@ -872,9 +918,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
|
||||
class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
|
||||
"""
|
||||
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
|
||||
@@ -14,17 +14,21 @@ from typing import Dict, Literal, Optional
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import Trainer
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||
from trl import CPOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||
from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.core.trainers.kto import AxolotlKTOTrainer
|
||||
from axolotl.integrations.base import BaseOptimizerFactory
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.schedulers import (
|
||||
RexLR,
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
get_cosine_schedule_with_warmup_decay_constant,
|
||||
@@ -115,6 +119,17 @@ class SchedulerMixin(Trainer):
|
||||
**extra_lr_kwargs,
|
||||
**self.args.lr_scheduler_kwargs,
|
||||
)
|
||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||
if use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
|
||||
self.lr_scheduler = RexLR(
|
||||
optimizer=optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||
total_steps=num_training_steps,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
)
|
||||
elif use_cosine_quadratic:
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||
@@ -154,47 +169,18 @@ class SchedulerMixin(Trainer):
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
class OptimizerMixin(Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
Mixin class for shared handling of building custom optimizers
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
tag_names = ["axolotl"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
self.dataset_tags = dataset_tags
|
||||
self._signature_columns = None # workaround for pylint
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
model = torch.compile(
|
||||
model,
|
||||
backend=self.args.torch_compile_backend,
|
||||
mode=self.args.torch_compile_mode,
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
|
||||
def create_optimizer_grouped_parameters(
|
||||
self, opt_model, optimizer_kwargs
|
||||
) -> list[dict]:
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
params = {
|
||||
params: dict = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
@@ -281,23 +267,30 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
and self.args.embedding_lr_scale is None
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.lr_groups is None
|
||||
and self.args.alternate_optimizer
|
||||
not in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]
|
||||
and self.optimizer_cls_and_kwargs is None
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
opt_model,
|
||||
|
||||
if (
|
||||
not self.optimizer
|
||||
and self.optimizer_cls_and_kwargs is not None
|
||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||
):
|
||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
self.optimizer = optimizer_factory_cls()(
|
||||
opt_model, self.args, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if not self.optimizer:
|
||||
if self.optimizer_cls_and_kwargs is not None:
|
||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
else:
|
||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||
self.args, opt_model
|
||||
)
|
||||
|
||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||
opt_model, optimizer_kwargs
|
||||
)
|
||||
@@ -314,50 +307,47 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
elif (
|
||||
self.args.embedding_lr_scale is not None
|
||||
or self.args.embedding_lr is not None
|
||||
or self.args.lr_groups is not None
|
||||
):
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
else:
|
||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for GaLore optimizer.
|
||||
if "params" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamW(
|
||||
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
|
||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for LOMO optimizer.
|
||||
if "model" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||
|
||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||
# to avoid arguments conflicts.
|
||||
if "optimizer_dict" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||
"optimizer_dict"
|
||||
)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "ao_adamw_4bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
self.optimizer = optimizer_cls(
|
||||
optimizer_grouped_parameters, **optimizer_kwargs
|
||||
)
|
||||
elif self.args.alternate_optimizer == "ao_adamw_8bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "ao_adamw_fp8":
|
||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||
if optimizer_cls.__name__ == "Adam8bit":
|
||||
import bitsandbytes
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "adopt_adamw":
|
||||
from axolotl.utils.optimizers.adopt import ADOPT
|
||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
ADOPT(
|
||||
optimizer_grouped_parameters,
|
||||
decouple=True,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
)
|
||||
skipped = 0
|
||||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum(
|
||||
{
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
@@ -366,6 +356,45 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
|
||||
return self.optimizer
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
tag_names = ["axolotl"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
self.dataset_tags = dataset_tags
|
||||
self._signature_columns = None # workaround for pylint
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
model = torch.compile(
|
||||
model,
|
||||
backend=self.args.torch_compile_backend,
|
||||
mode=self.args.torch_compile_mode,
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
if self.args.multipack_real_batches:
|
||||
@@ -846,14 +875,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
|
||||
@@ -9,6 +9,7 @@ import logging
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -31,30 +32,44 @@ class GRPOStrategy:
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
grpo_args_kwargs = {}
|
||||
if cfg.trl and cfg.trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
||||
if cfg.trl and cfg.trl.vllm_device:
|
||||
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
||||
else:
|
||||
grpo_args_kwargs["vllm_device"] = "auto"
|
||||
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
||||
|
||||
if not hasattr(cfg, "trl") or not cfg.trl:
|
||||
return grpo_args_kwargs
|
||||
|
||||
trl: TRLConfig = cfg.trl # type: ignore
|
||||
|
||||
if trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||
grpo_args_kwargs["vllm_device"] = (
|
||||
trl.vllm_device if trl.vllm_device else "auto"
|
||||
)
|
||||
|
||||
if trl.vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs[
|
||||
"vllm_gpu_memory_utilization"
|
||||
] = cfg.trl.vllm_gpu_memory_utilization
|
||||
if cfg.trl and cfg.trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
||||
if cfg.trl and cfg.trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
||||
if cfg.trl and cfg.trl.sync_ref_model:
|
||||
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
||||
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
||||
grpo_args_kwargs[
|
||||
"ref_model_mixup_alpha"
|
||||
] = cfg.trl.ref_model_mixup_alpha
|
||||
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
||||
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
||||
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
||||
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
|
||||
] = trl.vllm_gpu_memory_utilization
|
||||
|
||||
if trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
|
||||
|
||||
if trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = trl.num_generations
|
||||
|
||||
if trl.sync_ref_model:
|
||||
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
|
||||
|
||||
if trl.ref_model_mixup_alpha:
|
||||
grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha
|
||||
|
||||
if trl.ref_model_sync_steps:
|
||||
grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
|
||||
|
||||
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
|
||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||
|
||||
if trl.reward_weights:
|
||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
7
src/axolotl/core/trainers/kto/__init__.py
Normal file
7
src/axolotl/core/trainers/kto/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
KTO package initialization.
|
||||
"""
|
||||
|
||||
from axolotl.core.trainers.kto.trainer import AxolotlKTOTrainer
|
||||
|
||||
__all__ = ["AxolotlKTOTrainer"]
|
||||
512
src/axolotl/core/trainers/kto/trainer.py
Normal file
512
src/axolotl/core/trainers/kto/trainer.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""
|
||||
KTO trainer implementation for Axolotl.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Callable, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import DataLoader, SequentialSampler
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
BaseImageProcessor,
|
||||
DataCollator,
|
||||
FeatureExtractionMixin,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
ProcessorMixin,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import EvalLoopOutput
|
||||
from trl import KTOTrainer
|
||||
from trl.trainer.kto_config import KTOConfig
|
||||
from trl.trainer.utils import KTODataCollatorWithPadding, pad_to_length
|
||||
|
||||
from axolotl.core.trainers.base import SchedulerMixin
|
||||
|
||||
# Check if PEFT is available
|
||||
try:
|
||||
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training, peft_module_casting_to_bf16
|
||||
is_peft_available = True
|
||||
except ImportError:
|
||||
is_peft_available = False
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.trainers.kto")
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(SchedulerMixin, Trainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module, str] = None,
|
||||
args: KTOConfig = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||
dataset_tags=None,
|
||||
model_adapter_name: Optional[str] = None,
|
||||
ref_adapter_name: Optional[str] = None,
|
||||
):
|
||||
self.dataset_tags = dataset_tags
|
||||
self._tag_names = ["trl", "kto"]
|
||||
if hasattr(self, "tag_names"):
|
||||
self._tag_names.extend(self.tag_names)
|
||||
|
||||
if type(args) is TrainingArguments:
|
||||
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
|
||||
|
||||
if args.model_init_kwargs is None:
|
||||
model_init_kwargs = {}
|
||||
elif not isinstance(model, str):
|
||||
raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
|
||||
else:
|
||||
model_init_kwargs = args.model_init_kwargs
|
||||
torch_dtype = model_init_kwargs.get("torch_dtype")
|
||||
if torch_dtype is not None:
|
||||
# Convert to `torch.dtype` if an str is passed
|
||||
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
||||
raise ValueError(
|
||||
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
||||
)
|
||||
model_init_kwargs["torch_dtype"] = torch_dtype
|
||||
|
||||
if args.ref_model_init_kwargs is None:
|
||||
ref_model_init_kwargs = {}
|
||||
elif not isinstance(ref_model, str):
|
||||
raise ValueError(
|
||||
"You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
|
||||
)
|
||||
else:
|
||||
ref_model_init_kwargs = args.ref_model_init_kwargs
|
||||
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
|
||||
if torch_dtype is not None:
|
||||
# Convert to `torch.dtype` if an str is passed
|
||||
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
||||
raise ValueError(
|
||||
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
||||
)
|
||||
ref_model_init_kwargs["torch_dtype"] = torch_dtype
|
||||
|
||||
if isinstance(model, str):
|
||||
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
||||
|
||||
if isinstance(ref_model, str):
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
|
||||
|
||||
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
||||
# has been called in order to properly call autocast if needed.
|
||||
self._peft_has_been_casted_to_bf16 = False
|
||||
|
||||
if not is_peft_available() and peft_config is not None:
|
||||
raise ValueError(
|
||||
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
|
||||
)
|
||||
elif is_peft_available() and peft_config is not None:
|
||||
# if model is a peft model and we have a peft_config, we merge and unload it first
|
||||
if isinstance(model, PeftModel):
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
||||
_support_gc_kwargs = hasattr(
|
||||
args, "gradient_checkpointing_kwargs"
|
||||
) and "gradient_checkpointing_kwargs" in list(
|
||||
inspect.signature(prepare_model_for_kbit_training).parameters
|
||||
)
|
||||
|
||||
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
||||
|
||||
if _support_gc_kwargs:
|
||||
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
||||
|
||||
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
||||
elif getattr(args, "gradient_checkpointing", False):
|
||||
# For backward compatibility with older versions of transformers
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
# get peft model with the given config
|
||||
model = get_peft_model(model, peft_config)
|
||||
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
||||
peft_module_casting_to_bf16(model)
|
||||
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
||||
self._peft_has_been_casted_to_bf16 = True
|
||||
|
||||
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
||||
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
||||
# fail or completely fail.
|
||||
elif getattr(args, "gradient_checkpointing", False):
|
||||
# For backward compatibility with older versions of transformers
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
||||
raise ValueError(
|
||||
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
||||
" Please install `wandb` or `comet-ml` to resolve."
|
||||
)
|
||||
|
||||
if model is not None:
|
||||
self.is_encoder_decoder = model.config.is_encoder_decoder
|
||||
elif args.is_encoder_decoder is None:
|
||||
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
||||
else:
|
||||
self.is_encoder_decoder = args.is_encoder_decoder
|
||||
|
||||
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
||||
self.model_adapter_name = model_adapter_name
|
||||
self.ref_adapter_name = ref_adapter_name
|
||||
|
||||
if ref_model:
|
||||
self.ref_model = ref_model
|
||||
elif self.is_peft_model or args.precompute_ref_log_probs:
|
||||
# The `model` with adapters turned off will be used as the reference model
|
||||
self.ref_model = None
|
||||
else:
|
||||
self.ref_model = create_reference_model(model)
|
||||
|
||||
if processing_class is None:
|
||||
raise ValueError(
|
||||
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
|
||||
)
|
||||
if args.max_length is None:
|
||||
warnings.warn(
|
||||
"When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
|
||||
" it will be set to `512` by default, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
max_length = 512
|
||||
if args.max_length is not None:
|
||||
max_length = args.max_length
|
||||
|
||||
if args.max_prompt_length is None:
|
||||
warnings.warn(
|
||||
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
|
||||
" it will be set to `128` by default, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
max_prompt_length = 128
|
||||
if args.max_prompt_length is not None:
|
||||
max_prompt_length = args.max_prompt_length
|
||||
|
||||
max_completion_length = None
|
||||
if args.max_completion_length is None and self.is_encoder_decoder:
|
||||
warnings.warn(
|
||||
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
|
||||
" it will be set to `128` by default, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
max_completion_length = 128
|
||||
if args.max_completion_length is not None and self.is_encoder_decoder:
|
||||
max_completion_length = args.max_completion_length
|
||||
|
||||
if data_collator is None:
|
||||
data_collator = DPODataCollatorWithPadding(
|
||||
pad_token_id=processing_class.pad_token_id,
|
||||
label_pad_token_id=args.label_pad_token_id,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
)
|
||||
|
||||
if args.remove_unused_columns:
|
||||
args.remove_unused_columns = False
|
||||
# warn users
|
||||
warnings.warn(
|
||||
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
|
||||
" we have set it for you, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
self.use_dpo_data_collator = True
|
||||
else:
|
||||
self.use_dpo_data_collator = False
|
||||
|
||||
# Disable dropout in the model and reference model
|
||||
if args.disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
if self.ref_model is not None:
|
||||
disable_dropout_in_model(self.ref_model)
|
||||
|
||||
self.loss_type = args.loss_type
|
||||
self.max_length = max_length
|
||||
self.generate_during_eval = args.generate_during_eval
|
||||
self.label_pad_token_id = args.label_pad_token_id
|
||||
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
||||
self.max_prompt_length = max_prompt_length
|
||||
self.truncation_mode = args.truncation_mode
|
||||
self.max_completion_length = max_completion_length
|
||||
self.processing_class = processing_class
|
||||
self.precompute_ref_log_probs = args.precompute_ref_log_probs
|
||||
|
||||
# Not all losses require a KL calculation
|
||||
self.calculate_KL = True
|
||||
if self.loss_type in ["apo_zero_unpaired"]:
|
||||
self.calculate_KL = False
|
||||
|
||||
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
|
||||
# keep track of first called to avoid computation of future calls
|
||||
self._precomputed_train_ref_log_probs = False
|
||||
self._precomputed_eval_ref_log_probs = False
|
||||
|
||||
# metric
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# KTO parameter
|
||||
self.beta = args.beta
|
||||
self.desirable_weight = args.desirable_weight
|
||||
self.undesirable_weight = args.undesirable_weight
|
||||
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
||||
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
||||
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
||||
warnings.warn(
|
||||
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
||||
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
||||
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
||||
"loss.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
||||
# input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
|
||||
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
|
||||
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
||||
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
||||
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
||||
# issued.
|
||||
model.warnings_issued["estimate_tokens"] = True
|
||||
|
||||
# Compute that only on the main process for faster data processing.
|
||||
# see: https://github.com/huggingface/trl/pull/1255
|
||||
with PartialState().local_main_process_first():
|
||||
# Extract the prompt if needed
|
||||
train_dataset = train_dataset.map(
|
||||
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
|
||||
)
|
||||
# Unpair the dataset if needed
|
||||
train_dataset = maybe_unpair_preference_dataset(
|
||||
train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
|
||||
)
|
||||
# Apply the chat template if needed
|
||||
train_dataset = train_dataset.map(
|
||||
maybe_apply_chat_template,
|
||||
fn_kwargs={"tokenizer": processing_class},
|
||||
num_proc=args.dataset_num_proc,
|
||||
desc="Applying chat template to train dataset",
|
||||
)
|
||||
if eval_dataset is not None:
|
||||
eval_dataset = eval_dataset.map(
|
||||
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
|
||||
)
|
||||
eval_dataset = maybe_unpair_preference_dataset(
|
||||
eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
|
||||
)
|
||||
eval_dataset = eval_dataset.map(
|
||||
maybe_apply_chat_template,
|
||||
fn_kwargs={"tokenizer": processing_class},
|
||||
num_proc=args.dataset_num_proc,
|
||||
desc="Applying chat template to eval dataset",
|
||||
)
|
||||
|
||||
# Tokenize and prepare the training datasets
|
||||
train_dataset = train_dataset.map(
|
||||
_tokenize,
|
||||
batched=True,
|
||||
fn_kwargs={"tokenizer": self.processing_class},
|
||||
num_proc=args.dataset_num_proc,
|
||||
desc="Tokenizing train dataset",
|
||||
)
|
||||
|
||||
fn_kwargs = {
|
||||
"prefix": "",
|
||||
"is_encoder_decoder": self.is_encoder_decoder,
|
||||
"tokenizer": self.processing_class,
|
||||
"max_length": self.max_length,
|
||||
"truncation_mode": self.truncation_mode,
|
||||
"label_pad_token_id": self.label_pad_token_id,
|
||||
"max_prompt_length": self.max_prompt_length,
|
||||
"max_completion_length": self.max_completion_length,
|
||||
}
|
||||
|
||||
train_dataset = train_dataset.map(
|
||||
_process_tokens,
|
||||
fn_kwargs=fn_kwargs,
|
||||
num_proc=args.dataset_num_proc,
|
||||
desc="Processing tokenized train dataset",
|
||||
)
|
||||
|
||||
# Tokenize and prepare the eval datasets
|
||||
if eval_dataset is not None:
|
||||
eval_dataset = eval_dataset.map(
|
||||
_tokenize,
|
||||
fn_kwargs={"tokenizer": self.processing_class},
|
||||
batched=True,
|
||||
num_proc=args.dataset_num_proc,
|
||||
desc="Tokenizing eval dataset",
|
||||
)
|
||||
|
||||
eval_dataset = eval_dataset.map(
|
||||
_process_tokens,
|
||||
fn_kwargs=fn_kwargs,
|
||||
num_proc=args.dataset_num_proc,
|
||||
desc="Processing tokenized eval dataset",
|
||||
)
|
||||
|
||||
# Get KL datasets if needed
|
||||
if self.calculate_KL:
|
||||
if args.per_device_train_batch_size <= 1:
|
||||
raise ValueError(
|
||||
"Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
|
||||
)
|
||||
|
||||
# create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
|
||||
# i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
|
||||
train_kl_dataset = train_dataset.map(
|
||||
_get_kl_dataset,
|
||||
batched=True,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
num_proc=args.dataset_num_proc,
|
||||
desc="Extracting KL train dataset",
|
||||
)
|
||||
|
||||
fn_kwargs["prefix"] = "KL_"
|
||||
train_kl_dataset = train_kl_dataset.map(
|
||||
_process_tokens,
|
||||
fn_kwargs=fn_kwargs,
|
||||
num_proc=args.dataset_num_proc,
|
||||
remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
|
||||
desc="Processing tokenized train KL dataset",
|
||||
)
|
||||
|
||||
# merge the datasets
|
||||
train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
|
||||
|
||||
if eval_dataset is not None:
|
||||
# Get KL dataset
|
||||
eval_kl_dataset = eval_dataset.map(
|
||||
_get_kl_dataset,
|
||||
batched=True,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
num_proc=args.dataset_num_proc,
|
||||
desc="Extracting eval KL dataset",
|
||||
)
|
||||
|
||||
eval_kl_dataset = eval_kl_dataset.map(
|
||||
_process_tokens,
|
||||
fn_kwargs=fn_kwargs,
|
||||
num_proc=args.dataset_num_proc,
|
||||
remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
|
||||
desc="Processing tokenized eval KL dataset",
|
||||
)
|
||||
|
||||
# merge the datasets
|
||||
eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
|
||||
|
||||
# calculate dataset desirability balance
|
||||
num_desirable = max(sum(train_dataset["label"]), 1)
|
||||
num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
|
||||
|
||||
if num_desirable != num_undesirable:
|
||||
# The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
|
||||
des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
|
||||
des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
|
||||
und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
|
||||
und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
|
||||
|
||||
des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
|
||||
und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
|
||||
|
||||
if not (des_weight_in_range or und_weight_in_range):
|
||||
warnings.warn(
|
||||
"You have different amounts of desirable/positive and undesirable/negative examples but the "
|
||||
"weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
|
||||
f"on your data, we recommend EITHER "
|
||||
f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
|
||||
f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
|
||||
"See the documentation on how to optimally set these weights.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
model_init=model_init,
|
||||
compute_metrics=compute_metrics,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
||||
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
||||
# self.model_accepts_loss_kwargs to False to enable scaling.
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
# Add tags for models that have been loaded with the correct transformers version
|
||||
if hasattr(self.model, "add_model_tags"):
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError(
|
||||
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
||||
)
|
||||
|
||||
# Deepspeed Zero-3 does not support precompute_ref_log_probs
|
||||
if self.is_deepspeed_enabled:
|
||||
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
|
||||
raise ValueError(
|
||||
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
|
||||
)
|
||||
|
||||
if self.ref_model is None:
|
||||
if not (self.is_peft_model or self.precompute_ref_log_probs):
|
||||
raise ValueError(
|
||||
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
|
||||
)
|
||||
else:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
@@ -23,6 +23,8 @@ import importlib
|
||||
import logging
|
||||
from typing import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
"""
|
||||
@@ -469,3 +471,14 @@ class PluginManager:
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_train_unload(cfg)
|
||||
|
||||
|
||||
class BaseOptimizerFactory:
|
||||
"""
|
||||
Base class for factories to create custom optimizers
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, opt_model, training_args, **optimizer_kwargs
|
||||
) -> "torch.optim.Optimizer":
|
||||
pass
|
||||
|
||||
@@ -4,6 +4,22 @@ Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy o
|
||||
|
||||
See https://github.com/apple/ml-cross-entropy
|
||||
|
||||
## Requirements
|
||||
|
||||
- PyTorch 2.4.0 or higher
|
||||
|
||||
## Installation
|
||||
|
||||
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
|
||||
|
||||
```bash
|
||||
# if you are in dev environment
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
# if you are not in dev environment
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class SpectrumArgs(BaseModel):
|
||||
@@ -27,3 +27,20 @@ class SpectrumArgs(BaseModel):
|
||||
|
||||
spectrum_top_fraction: Optional[float] = 0.5
|
||||
spectrum_model_name: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_use_orig_params(cls, data):
|
||||
if (
|
||||
data.get("fsdp")
|
||||
and data.get("fsdp_config")
|
||||
and not data["fsdp_config"].get("use_orig_params")
|
||||
and data.get("plugins")
|
||||
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
|
||||
):
|
||||
# would otherwise raise
|
||||
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
|
||||
raise ValueError(
|
||||
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -1,26 +1,29 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import save_fsdp_model
|
||||
from peft import PeftModel
|
||||
from pkg_resources import get_distribution # type: ignore
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from datasets import Dataset
|
||||
from peft import PeftConfig, PeftModel
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.common.datasets import TrainDatasetMeta
|
||||
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
)
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
@@ -32,17 +35,25 @@ try:
|
||||
except ImportError:
|
||||
BetterTransformer = None
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def train(
|
||||
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
||||
def setup_model_and_tokenizer(
|
||||
cfg: DictDefault,
|
||||
) -> tuple[
|
||||
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
|
||||
]:
|
||||
"""
|
||||
Load the tokenizer, processor (for multimodal models), and model based on configuration.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
|
||||
`None`), and processor (if multimodal, else `None`).
|
||||
"""
|
||||
# Load tokenizer
|
||||
LOG.debug(
|
||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||
@@ -55,11 +66,58 @@ def train(
|
||||
if cfg.is_multimodal:
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
# Get datasets
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
eval_dataset = dataset_meta.eval_dataset
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
# Load the model and peft_config
|
||||
msg = "loading model"
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
|
||||
model, peft_config = load_model(cfg, tokenizer, processor=processor)
|
||||
if model.generation_config is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
# Apply freezing if specified
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||
|
||||
return model, tokenizer, peft_config, processor
|
||||
|
||||
|
||||
def setup_reference_model(
|
||||
cfg: DictDefault, tokenizer: PreTrainedTokenizer
|
||||
) -> PreTrainedModel | None:
|
||||
"""
|
||||
Set up the reference model for RL training if needed.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
tokenizer: The tokenizer to use for the reference model.
|
||||
|
||||
Returns:
|
||||
Reference model if needed for RL training, `None` otherwise.
|
||||
"""
|
||||
model_ref = None
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
# use built-in trl autounwrap
|
||||
LOG.debug("Passing model_ref: None to RL trainer")
|
||||
model_ref = None # explicit setting to None
|
||||
else:
|
||||
# load the model again for model_ref/baseline
|
||||
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
|
||||
return model_ref
|
||||
|
||||
|
||||
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
|
||||
"""
|
||||
Determine the checkpoint to resume from based on configuration.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
Path to the checkpoint to resume from, or `None` if not resuming.
|
||||
"""
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
@@ -73,77 +131,22 @@ def train(
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
)
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
return cfg.resume_from_checkpoint
|
||||
|
||||
# Load the model and tokenizer
|
||||
msg = "loading model"
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
model, peft_config = load_model(cfg, tokenizer, processor=processor)
|
||||
if model.generation_config is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
model_ref = None
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
# use built-in trl autounwrap
|
||||
LOG.debug("Passing model_ref: None to RL trainer")
|
||||
model_ref = None # explicit setting to None
|
||||
else:
|
||||
# load the model again for model_ref/baseline
|
||||
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
|
||||
def setup_signal_handler(
|
||||
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
|
||||
):
|
||||
"""
|
||||
Set up signal handler for graceful termination.
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||
|
||||
trainer = setup_trainer(
|
||||
cfg,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
(model, model_ref, peft_config),
|
||||
tokenizer,
|
||||
processor,
|
||||
total_num_steps,
|
||||
)
|
||||
|
||||
if cfg.fix_untrained_tokens:
|
||||
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||
sig = inspect.signature(fix_untrained_tokens)
|
||||
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||
cfg.fix_untrained_tokens, list
|
||||
):
|
||||
fix_untrained_tokens(
|
||||
model,
|
||||
tokenizer,
|
||||
train_dataset,
|
||||
token_ids_to_fix=cfg.fix_untrained_tokens,
|
||||
)
|
||||
else:
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||
if cfg.local_rank == 0:
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
# go ahead and presave, so we have the adapter config available to inspect
|
||||
if peft_config:
|
||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||
peft_config.save_pretrained(cfg.output_dir)
|
||||
# additionally presave the tokenizer and model configs
|
||||
if not Path(cfg.output_dir).is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
||||
if hasattr(model, "config"):
|
||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if (
|
||||
cfg.local_rank == 0 and not cfg.use_ray
|
||||
): # ray workers don't have access to this signal
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
model: The model to save on termination
|
||||
safe_serialization: Whether to use safe serialization when saving
|
||||
"""
|
||||
# ray workers don't have access to this signal
|
||||
if cfg.local_rank == 0 and not cfg.use_ray:
|
||||
|
||||
def terminate_handler(_, __, model_weakref):
|
||||
if model_weakref() is not None:
|
||||
@@ -161,21 +164,22 @@ def train(
|
||||
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
||||
)
|
||||
|
||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||
|
||||
if getattr(cfg, "axolotl_config_path"):
|
||||
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
|
||||
version = get_distribution("axolotl").version
|
||||
if raw_axolotl_cfg.is_file():
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
|
||||
def execute_training(
|
||||
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||
):
|
||||
"""
|
||||
Execute the training process with appropriate backend configurations.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
trainer: The configured trainer object.
|
||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||
"""
|
||||
LOG.info("Starting trainer...")
|
||||
if cfg.group_by_length:
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
|
||||
pretrain_hooks(cfg, trainer)
|
||||
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
@@ -187,15 +191,30 @@ def train(
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
post_train_hooks(cfg, trainer)
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
def save_trained_model(
|
||||
cfg: DictDefault,
|
||||
trainer: Any,
|
||||
model: PreTrainedModel,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
"""
|
||||
Save the trained model according to configuration and training setup.
|
||||
|
||||
# post training
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
trainer: The trainer object.
|
||||
model: The trained model to save.
|
||||
safe_serialization: Whether to use safe serialization.
|
||||
"""
|
||||
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
|
||||
|
||||
# Post training module hooks
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, "_post_training"):
|
||||
module._post_training(model, name) # pylint: disable=protected-access
|
||||
|
||||
# Handle FSDP state dict type
|
||||
state_dict_type = "FULL_STATE_DICT"
|
||||
if trainer.is_fsdp_enabled:
|
||||
if cfg.fsdp_final_state_dict_type:
|
||||
@@ -203,16 +222,18 @@ def train(
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
|
||||
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
|
||||
|
||||
# Handle ReLoRA early return case
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
||||
model = model.merge_and_unload()
|
||||
else:
|
||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||
return model, tokenizer
|
||||
return
|
||||
|
||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.fsdp:
|
||||
# TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple
|
||||
# processes attempt to write the same file
|
||||
if (
|
||||
state_dict_type == "SHARDED_STATE_DICT"
|
||||
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
|
||||
@@ -244,7 +265,6 @@ def train(
|
||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
@@ -255,58 +275,241 @@ def train(
|
||||
)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
|
||||
def create_model_card(cfg: DictDefault, trainer: Trainer):
|
||||
"""
|
||||
Create a model card for the trained model if needed.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
trainer: The trainer object with model card creation capabilities.
|
||||
"""
|
||||
if not cfg.hub_model_id:
|
||||
# Guard since create_model_card may fail if dataset_tags is empty list
|
||||
try:
|
||||
model_card_kwarg = {
|
||||
"model_name": cfg.output_dir.lstrip("./")
|
||||
.encode("utf-8")
|
||||
.decode("utf-8")
|
||||
}
|
||||
if cfg.datasets is not None:
|
||||
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
|
||||
dataset_tags = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
dataset_tags = [
|
||||
d for d in dataset_tags if not d.startswith("https://")
|
||||
]
|
||||
if dataset_tags:
|
||||
# guard as create_model_card may fail if dataset_tags is empty list
|
||||
model_card_kwarg["dataset_name"] = dataset_tags
|
||||
else:
|
||||
dataset_tags = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
dataset_tags = [
|
||||
d for d in dataset_tags if not d.startswith("https://")
|
||||
]
|
||||
if dataset_tags:
|
||||
# guard as create_model_card may fail if dataset_tags is empty list
|
||||
model_card_kwarg["dataset_tags"] = dataset_tags
|
||||
|
||||
# We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
|
||||
rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
|
||||
if cfg.datasets is not None and not rl:
|
||||
dataset_tags = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
|
||||
|
||||
if dataset_tags:
|
||||
model_card_kwarg["dataset_tags"] = dataset_tags
|
||||
|
||||
trainer.create_model_card(**model_card_kwarg)
|
||||
except (AttributeError, UnicodeDecodeError):
|
||||
pass
|
||||
elif cfg.hub_model_id:
|
||||
# defensively push to the hub to ensure the model card is updated
|
||||
# Defensively push to the hub to ensure the model card is updated
|
||||
trainer.push_to_hub()
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
def save_initial_configs(
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
model: PreTrainedModel,
|
||||
peft_config: PeftConfig | None,
|
||||
):
|
||||
"""
|
||||
Save initial configurations before training.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
tokenizer: The tokenizer to save.
|
||||
model: The model to save configuration for.
|
||||
peft_config: The PEFT configuration to save if applicable.
|
||||
"""
|
||||
# Create output_dir if it doesn't already exist
|
||||
output_dir = Path(cfg.output_dir)
|
||||
if not output_dir.is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
|
||||
# Pre-save adapter config so it's available to inspect
|
||||
if peft_config:
|
||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
|
||||
peft_config.save_pretrained(cfg.output_dir)
|
||||
|
||||
# Pre-save the tokenizer and model configs
|
||||
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
|
||||
tokenizer.save_pretrained(str(output_dir))
|
||||
if hasattr(model, "config"):
|
||||
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||
model.config.save_pretrained(str(output_dir))
|
||||
|
||||
|
||||
def pretrain_hooks(_cfg, _trainer):
|
||||
def setup_model_card(cfg: DictDefault):
|
||||
"""
|
||||
Run hooks right before kicking off the training
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
Set up the Axolotl badge and add the Axolotl config to the model card if available.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
"""
|
||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||
|
||||
if getattr(cfg, "axolotl_config_path"):
|
||||
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
|
||||
version = importlib.metadata.version("axolotl")
|
||||
if raw_axolotl_cfg.is_file():
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
|
||||
|
||||
|
||||
def post_train_hooks(_cfg, _trainer):
|
||||
def handle_untrained_tokens_fix(
|
||||
cfg: DictDefault,
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
train_dataset: Dataset,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
"""
|
||||
Run hooks right after training completes
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
Apply fixes for untrained tokens if configured.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
model: The model to apply fixes to.
|
||||
tokenizer: The tokenizer for token identification.
|
||||
train_dataset: The training dataset to use.
|
||||
safe_serialization: Whether to use safe serialization when saving.
|
||||
"""
|
||||
if not cfg.fix_untrained_tokens:
|
||||
return
|
||||
|
||||
is_ds_zero3: bool = False
|
||||
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
|
||||
is_ds_zero3 = True
|
||||
|
||||
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||
sig = inspect.signature(fix_untrained_tokens)
|
||||
|
||||
fix_kwargs: Dict[str, Any] = {}
|
||||
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||
cfg.fix_untrained_tokens, list
|
||||
):
|
||||
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
|
||||
if "is_ds_zero3" in sig.parameters:
|
||||
fix_kwargs["is_ds_zero3"] = is_ds_zero3
|
||||
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
|
||||
def setup_model_and_trainer(
|
||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> tuple[
|
||||
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
||||
PeftModel | PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PeftConfig | None,
|
||||
]:
|
||||
"""
|
||||
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
|
||||
trainer setup.
|
||||
|
||||
Args:
|
||||
cfg: The configuration dictionary with training parameters.
|
||||
dataset_meta: Object with training, validation datasets and metadata.
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Trainer (Causal or RLHF)
|
||||
- Model
|
||||
- Tokenizer
|
||||
- PEFT config
|
||||
"""
|
||||
# Load tokenizer, processor and model
|
||||
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
|
||||
|
||||
# Set up reference model for RL if needed
|
||||
model_ref = setup_reference_model(cfg, tokenizer)
|
||||
|
||||
# Get datasets from metadata
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
eval_dataset = dataset_meta.eval_dataset
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
# Set up trainer
|
||||
trainer = setup_trainer(
|
||||
cfg=cfg,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
total_num_steps=total_num_steps,
|
||||
model_ref=model_ref,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
return (
|
||||
trainer,
|
||||
model,
|
||||
tokenizer,
|
||||
peft_config,
|
||||
)
|
||||
|
||||
|
||||
def train(
|
||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
|
||||
"""
|
||||
Train a model on the given dataset.
|
||||
|
||||
Args:
|
||||
cfg: The configuration dictionary with training parameters
|
||||
dataset_meta: Object with training, validation datasets and metadata
|
||||
|
||||
Returns:
|
||||
Tuple of (model, tokenizer) after training
|
||||
"""
|
||||
# Setup model, tokenizer, (causal or RLHF) trainer etc.
|
||||
(
|
||||
trainer,
|
||||
model,
|
||||
tokenizer,
|
||||
peft_config,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
# Determine if we need to resume from a checkpoint
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
|
||||
# Configuration for saving
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
handle_untrained_tokens_fix(
|
||||
cfg, model, tokenizer, train_dataset, safe_serialization
|
||||
)
|
||||
|
||||
# Save initial configs
|
||||
save_initial_configs(cfg, tokenizer, model, peft_config)
|
||||
|
||||
# Set up signal handler for graceful termination
|
||||
setup_signal_handler(cfg, model, safe_serialization)
|
||||
|
||||
# Set up badges and config info for model card
|
||||
setup_model_card(cfg)
|
||||
|
||||
# Execute the training
|
||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||
|
||||
# Save the trained model
|
||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||
|
||||
# Create model card
|
||||
create_model_card(cfg, trainer)
|
||||
|
||||
return model, tokenizer, trainer
|
||||
|
||||
@@ -64,6 +64,17 @@ class ChatTemplate(str, Enum):
|
||||
metharme = "metharme" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CustomSupportedOptimizers(str, Enum):
|
||||
"""Custom supported optimizers"""
|
||||
|
||||
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
|
||||
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
|
||||
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
||||
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||
muon = "muon" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DeprecatedParameters(BaseModel):
|
||||
"""configurations that are deprecated"""
|
||||
|
||||
@@ -494,17 +505,7 @@ class HyperparametersConfig(BaseModel):
|
||||
embedding_lr_scale: Optional[float] = None
|
||||
weight_decay: Optional[float] = 0.0
|
||||
optimizer: Optional[
|
||||
Union[
|
||||
OptimizerNames,
|
||||
Literal[
|
||||
"lion_pytorch",
|
||||
"optimi_adamw",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
],
|
||||
]
|
||||
Union[OptimizerNames, CustomSupportedOptimizers]
|
||||
] = OptimizerNames.ADAMW_HF
|
||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||
default=None,
|
||||
@@ -518,7 +519,7 @@ class HyperparametersConfig(BaseModel):
|
||||
)
|
||||
torchdistx_path: Optional[str] = None
|
||||
lr_scheduler: Optional[
|
||||
Union[SchedulerType, Literal["one_cycle"]]
|
||||
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
|
||||
] = SchedulerType.COSINE
|
||||
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
||||
lr_quadratic_warmup: Optional[bool] = None
|
||||
@@ -727,7 +728,7 @@ class AxolotlInputConfig(
|
||||
default=None,
|
||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||
)
|
||||
dataset_processes: Optional[int] = Field(default=os.cpu_count())
|
||||
dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
|
||||
dataset_exact_deduplication: Optional[bool] = None
|
||||
dataset_keep_in_memory: Optional[bool] = None
|
||||
dataloader_pin_memory: Optional[bool] = None
|
||||
@@ -778,9 +779,9 @@ class AxolotlInputConfig(
|
||||
|
||||
# torch_dtype: Optional[torch.dtype]
|
||||
|
||||
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
|
||||
default=False
|
||||
)
|
||||
gradient_checkpointing: Optional[
|
||||
Union[Literal["unsloth", "offload"], bool]
|
||||
] = Field(default=False)
|
||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||
|
||||
unfrozen_parameters: Optional[List[str]] = None
|
||||
@@ -855,6 +856,7 @@ class AxolotlInputConfig(
|
||||
|
||||
special_tokens: Optional[SpecialTokensConfig] = None
|
||||
tokens: Optional[List[str]] = None
|
||||
added_tokens_overrides: Optional[Dict[int, str]] = None
|
||||
|
||||
torch_compile: Optional[Union[Literal["auto"], bool]] = None
|
||||
torch_compile_backend: Optional[str] = None
|
||||
@@ -1153,6 +1155,15 @@ class AxolotlInputConfig(
|
||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_offload_grad_checkpointing(self):
|
||||
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
|
||||
LOG.warning(
|
||||
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
|
||||
)
|
||||
self.gradient_checkpointing = "offload"
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_better_transformers(self):
|
||||
if self.flash_optimum is True:
|
||||
@@ -1177,6 +1188,13 @@ class AxolotlInputConfig(
|
||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lr_groups(cls, data):
|
||||
if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
|
||||
raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_saves(cls, data):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""
|
||||
GRPO specific configuration args
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -11,7 +12,10 @@ class TRLConfig(BaseModel):
|
||||
Input args for TRL.
|
||||
"""
|
||||
|
||||
beta: Optional[float] = None
|
||||
beta: Optional[float] = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Beta for RL training"},
|
||||
)
|
||||
max_completion_length: Optional[int] = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -20,16 +24,68 @@ class TRLConfig(BaseModel):
|
||||
)
|
||||
|
||||
# GRPO specific args
|
||||
use_vllm: Optional[bool] = False
|
||||
vllm_device: Optional[str] = "auto"
|
||||
vllm_gpu_memory_utilization: Optional[float] = 0.9
|
||||
vllm_max_model_len: Optional[int] = None
|
||||
vllm_dtype: Optional[str] = "auto"
|
||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||
use_vllm: Optional[bool] = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||
)
|
||||
vllm_device: Optional[str] = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Device to use for VLLM"},
|
||||
)
|
||||
vllm_gpu_memory_utilization: Optional[float] = Field(
|
||||
default=0.9,
|
||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||
)
|
||||
vllm_dtype: Optional[str] = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Data type for VLLM"},
|
||||
)
|
||||
vllm_max_model_len: Optional[int] = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Maximum length of the model context for VLLM"
|
||||
},
|
||||
)
|
||||
|
||||
reward_funcs: Optional[List[str]] = None
|
||||
num_generations: Optional[int] = None
|
||||
log_completions: Optional[bool] = False
|
||||
|
||||
sync_ref_model: Optional[bool] = False
|
||||
ref_model_mixup_alpha: Optional[float] = 0.9
|
||||
ref_model_sync_steps: Optional[int] = 64
|
||||
reward_funcs: Optional[list[str]] = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "List of reward functions to load"},
|
||||
)
|
||||
reward_weights: Optional[list[float]] = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Weights for each reward function. Must match the number of reward functions."
|
||||
},
|
||||
)
|
||||
num_generations: Optional[int] = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
||||
},
|
||||
)
|
||||
log_completions: Optional[bool] = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Whether to log completions"},
|
||||
)
|
||||
sync_ref_model: Optional[bool] = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Whether to sync the reference model every `ref_model_sync_steps` "
|
||||
"steps, using the `ref_model_mixup_alpha` parameter."
|
||||
)
|
||||
},
|
||||
)
|
||||
ref_model_mixup_alpha: Optional[float] = Field(
|
||||
default=0.9,
|
||||
json_schema_extra={
|
||||
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
||||
},
|
||||
)
|
||||
ref_model_sync_steps: Optional[int] = Field(
|
||||
default=64,
|
||||
json_schema_extra={
|
||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||
},
|
||||
)
|
||||
|
||||
@@ -121,6 +121,7 @@ def drop_long_rl_seq(
|
||||
|
||||
|
||||
def load_prepare_preference_datasets(cfg):
|
||||
import pdb; pdb.set_trace()
|
||||
def load_split(dataset_cfgs, _cfg):
|
||||
split_datasets: List[Any] = []
|
||||
use_auth_token = _cfg.hf_use_auth_token
|
||||
|
||||
@@ -79,7 +79,7 @@ def is_main_process():
|
||||
|
||||
|
||||
def is_local_main_process():
|
||||
return PartialState().is_main_process
|
||||
return PartialState().is_local_main_process
|
||||
|
||||
|
||||
def get_world_size():
|
||||
|
||||
@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
|
||||
)
|
||||
|
||||
|
||||
def hf_grad_checkpoint_unsloth_wrapper(
|
||||
def hf_grad_checkpoint_offload_wrapper(
|
||||
decoder_layer, *args, use_reentrant=None
|
||||
): # pylint: disable=unused-argument
|
||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||
|
||||
@@ -57,8 +57,14 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
|
||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
|
||||
from axolotl.utils.distributed import (
|
||||
barrier,
|
||||
get_device_count,
|
||||
get_device_type,
|
||||
is_local_main_process,
|
||||
zero_only,
|
||||
)
|
||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||
|
||||
@@ -165,7 +171,95 @@ def load_model_config(cfg):
|
||||
return model_config
|
||||
|
||||
|
||||
def modify_tokenizer_files(
|
||||
tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
|
||||
) -> str:
|
||||
"""
|
||||
Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
|
||||
|
||||
This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
|
||||
|
||||
Args:
|
||||
tokenizer_path: Path or name of the original tokenizer
|
||||
token_mappings: Dict mapping {token_id (int): new_token_string}
|
||||
output_dir: Directory to save the modified tokenizer
|
||||
|
||||
Returns:
|
||||
Path to the modified tokenizer directory
|
||||
|
||||
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
# Create the tokenizer directory in output_dir if it doesn't exist
|
||||
tokenizer_dir = os.path.join(output_dir, "tokenizer")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
|
||||
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
|
||||
# Load the tokenizer
|
||||
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
|
||||
|
||||
# Save the tokenizer to the output directory
|
||||
temp_tokenizer.save_pretrained(tokenizer_dir)
|
||||
|
||||
# Get the token IDs and map them to their new values
|
||||
token_id_mappings = {
|
||||
int(token_id): new_value for token_id, new_value in token_mappings.items()
|
||||
}
|
||||
|
||||
# 1. Update tokenizer_config.json - added_tokens_decoder
|
||||
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
# Update added_tokens_decoder
|
||||
if "added_tokens_decoder" in config_data:
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
token_id_str = str(token_id)
|
||||
if token_id_str in config_data["added_tokens_decoder"]:
|
||||
config_data["added_tokens_decoder"][token_id_str][
|
||||
"content"
|
||||
] = new_value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Token ID {token_id_str} not found in added_tokens_decoder"
|
||||
)
|
||||
|
||||
# Write the updated config back
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
|
||||
# 2. Update tokenizer.json - added_tokens
|
||||
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
|
||||
if os.path.exists(tokenizer_path):
|
||||
with open(tokenizer_path, "r", encoding="utf-8") as f:
|
||||
tokenizer_data = json.load(f)
|
||||
|
||||
# Update added_tokens
|
||||
if "added_tokens" in tokenizer_data:
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
|
||||
if token_entry["id"] == token_id:
|
||||
tokenizer_data["added_tokens"][i]["content"] = new_value
|
||||
break
|
||||
else:
|
||||
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
|
||||
raise ValueError(
|
||||
f"Token ID {token_id} not found in added_tokens"
|
||||
)
|
||||
|
||||
# Write the updated tokenizer data back
|
||||
with open(tokenizer_path, "w", encoding="utf-8") as f:
|
||||
json.dump(tokenizer_data, f, indent=2)
|
||||
|
||||
barrier()
|
||||
return tokenizer_dir
|
||||
|
||||
|
||||
def load_tokenizer(cfg):
|
||||
"""Load and configure the tokenizer based on the provided config."""
|
||||
model_config = load_model_config(cfg)
|
||||
tokenizer_kwargs = {}
|
||||
use_fast = True # this is the default
|
||||
@@ -180,8 +274,18 @@ def load_tokenizer(cfg):
|
||||
if cfg.tokenizer_type:
|
||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||
|
||||
# Set base tokenizer path
|
||||
tokenizer_path = cfg.tokenizer_config
|
||||
|
||||
# Apply token string overrides if specified
|
||||
if cfg.added_tokens_overrides:
|
||||
# Modify tokenizer files and get path to modified tokenizer
|
||||
tokenizer_path = modify_tokenizer_files(
|
||||
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
|
||||
)
|
||||
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
cfg.tokenizer_config,
|
||||
tokenizer_path,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
@@ -389,8 +493,8 @@ class ModelLoader:
|
||||
|
||||
patch_fa_peft_integration()
|
||||
|
||||
if self.cfg.gradient_checkpointing == "unsloth":
|
||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
||||
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
self.patch_attention()
|
||||
|
||||
@@ -6,6 +6,80 @@ from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
|
||||
class RexLR(LRScheduler):
|
||||
"""
|
||||
Reflected Exponential (REX) learning rate scheduler.
|
||||
|
||||
- Original implementation: https://github.com/IvanVassi/REX_LR
|
||||
- Original license: Apache 2.0
|
||||
- Based on: https://arxiv.org/abs/2107.04197
|
||||
|
||||
Args:
|
||||
optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for.
|
||||
max_lr (float): The maximum learning rate.
|
||||
min_lr (float): The minimum learning rate.
|
||||
total_steps (int): The total number of training steps.
|
||||
num_warmup_steps (int): The number of warmup steps.
|
||||
last_step (int): The index of last step.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0
|
||||
):
|
||||
if min_lr > max_lr:
|
||||
raise ValueError(
|
||||
f'Value of "min_lr" should be less than value of "max_lr". Got min_lr={min_lr} and max_lr={max_lr}'
|
||||
)
|
||||
if num_warmup_steps > total_steps:
|
||||
raise ValueError(
|
||||
f"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps})."
|
||||
)
|
||||
|
||||
self.min_lr = min_lr
|
||||
self.max_lr = max_lr
|
||||
self.total_steps = total_steps
|
||||
self.num_warmup_steps = num_warmup_steps
|
||||
self.last_step = last_step - 1
|
||||
|
||||
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("initial_lr", group["lr"])
|
||||
|
||||
# Pass self.last_step as last_epoch to the parent.
|
||||
super().__init__(optimizer, last_epoch=self.last_step)
|
||||
|
||||
@property
|
||||
def last_step(self):
|
||||
return self.last_epoch
|
||||
|
||||
@last_step.setter
|
||||
def last_step(self, value):
|
||||
self.last_epoch = value
|
||||
|
||||
def get_lr(self):
|
||||
# Warmup phase: if defined, increase lr linearly from 0 to max_lr.
|
||||
if 1 <= self.last_step <= self.num_warmup_steps:
|
||||
return [
|
||||
base_lr * self.last_step / self.num_warmup_steps
|
||||
for base_lr in self.base_lrs
|
||||
]
|
||||
|
||||
# Post-warmup phase: adjust step relative to the end of warmup.
|
||||
step_after = self.last_step - self.num_warmup_steps
|
||||
remaining_steps = self.total_steps - self.num_warmup_steps
|
||||
|
||||
# Avoid LR spiking
|
||||
if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
mod_iter = step_after % remaining_steps
|
||||
z = (remaining_steps - mod_iter) / remaining_steps
|
||||
rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * (
|
||||
z / (0.1 + 0.9 * z)
|
||||
)
|
||||
return [base_lr * rex_factor for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
class InterpolatingLogScheduler(LRScheduler):
|
||||
"""
|
||||
A scheduler that interpolates learning rates in a logarithmic fashion
|
||||
|
||||
@@ -574,14 +574,40 @@ def prepare_opinionated_env(cfg):
|
||||
|
||||
|
||||
def setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||
cfg,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
model,
|
||||
tokenizer,
|
||||
processor,
|
||||
total_num_steps,
|
||||
model_ref=None,
|
||||
peft_config=None,
|
||||
):
|
||||
"""
|
||||
Helper method for instantiating and building a (causal or RLHF) trainer.
|
||||
|
||||
Args:
|
||||
cfg: Axolotl config object containing training parameters.
|
||||
train_dataset: Dataset to use for training.
|
||||
eval_dataset: Dataset to use for evaluation.
|
||||
model: The model to train.
|
||||
tokenizer: Tokenizer for processing text input.
|
||||
processor: Processor for data preparation.
|
||||
total_num_steps: The total number of training steps.
|
||||
model_ref: Optional reference model for RLHF training. Default is None.
|
||||
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
|
||||
|
||||
Returns:
|
||||
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
|
||||
on the provided parameters.
|
||||
"""
|
||||
if cfg.rl:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||
trainer_builder.model_ref = model[1]
|
||||
trainer_builder.peft_config = model[2]
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
|
||||
trainer_builder.model_ref = model_ref
|
||||
trainer_builder.peft_config = peft_config
|
||||
else:
|
||||
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
|
||||
|
||||
trainer_builder.train_dataset = train_dataset
|
||||
trainer_builder.eval_dataset = eval_dataset
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
h1 {
|
||||
font-family: var(--font-title);
|
||||
font-weight: 400;
|
||||
font-size: 6rem;
|
||||
font-size: 5rem;
|
||||
line-height: 1.1;
|
||||
letter-spacing: -0.05em;
|
||||
font-feature-settings: "ss01" on;
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("axolotl.cli.train.train") as mock_train:
|
||||
mock_train.return_value = (MagicMock(), MagicMock())
|
||||
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
|
||||
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
@@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
|
||||
|
||||
with patch("axolotl.cli.train.train") as mock_train:
|
||||
mock_train.return_value = (MagicMock(), MagicMock())
|
||||
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
|
||||
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
|
||||
@@ -69,6 +69,51 @@ class TestCutCrossEntropyIntegration:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def test_qwen2_w_cce(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"plugins": [
|
||||
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
|
||||
],
|
||||
"cut_cross_entropy": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"output_dir": temp_dir,
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"max_steps": 10,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
with pytest.raises(ImportError):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
else:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"attention_type",
|
||||
[
|
||||
|
||||
@@ -750,3 +750,66 @@ class TestMultiGPULlama:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
def test_fix_untrained_tokens(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"fix_untrained_tokens": True,
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.0,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
"bos_token": "<|custom_im_start|>",
|
||||
"eos_token": "<|custom_im_end|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"chat_template": "jinja",
|
||||
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"split": "train[:10%]",
|
||||
"field_messages": "conversations",
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@@ -66,6 +66,54 @@ class TestLlama:
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_fix_untrained_tokens(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"fix_untrained_tokens": True,
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.0,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
"bos_token": "<|custom_im_start|>",
|
||||
"eos_token": "<|custom_im_end|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"chat_template": "jinja",
|
||||
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"split": "train[:10%]",
|
||||
"field_messages": "conversations",
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_fix_untrained_tokens_already_trained(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
|
||||
@@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
@@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
@@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
|
||||
@@ -65,8 +65,9 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW"
|
||||
|
||||
@with_temp_dir
|
||||
@require_torch_2_5_1
|
||||
@@ -111,8 +112,57 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__
|
||||
|
||||
@with_temp_dir
|
||||
@require_torch_2_5_1
|
||||
def test_muon(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "muon",
|
||||
"lr_scheduler": "cosine",
|
||||
"weight_decay": 0.01,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||
|
||||
71
tests/e2e/test_schedulers.py
Normal file
71
tests/e2e/test_schedulers.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
E2E tests for custom schedulers using Llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestCustomSchedulers(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_rex_scheduler(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_hf",
|
||||
"max_steps": 20,
|
||||
"lr_scheduler": "rex",
|
||||
"warmup_steps": 5,
|
||||
"cosine_min_lr_ratio": 0.05,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Test cases for the tokenizer loading
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
@@ -9,7 +10,7 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
|
||||
|
||||
class TestTokenizers(unittest.TestCase):
|
||||
class TestTokenizers:
|
||||
"""
|
||||
test class for the load_tokenizer fn
|
||||
"""
|
||||
@@ -75,12 +76,48 @@ class TestTokenizers(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
|
||||
self.assertEqual(len(tokenizer), 32001)
|
||||
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
|
||||
assert len(tokenizer) == 32001
|
||||
|
||||
# ensure reloading the tokenizer again from cfg results in same vocab length
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
self.assertEqual(len(tokenizer), 32001)
|
||||
assert len(tokenizer) == 32001
|
||||
|
||||
def test_added_tokens_overrides(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
# use with tokenizer that has reserved_tokens in added_tokens
|
||||
"tokenizer_config": "NousResearch/Llama-3.2-1B",
|
||||
"added_tokens_overrides": {
|
||||
128041: "RANDOM_OVERRIDE_1",
|
||||
128042: "RANDOM_OVERRIDE_2",
|
||||
},
|
||||
"output_dir": temp_dir,
|
||||
}
|
||||
)
|
||||
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
|
||||
128041
|
||||
]
|
||||
assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
|
||||
128042
|
||||
]
|
||||
|
||||
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
# use with tokenizer that has reserved_tokens in added_tokens
|
||||
"tokenizer_config": "NousResearch/Llama-3.2-1B",
|
||||
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
|
||||
"output_dir": temp_dir,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
|
||||
):
|
||||
load_tokenizer(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user