Compare commits
29 Commits
optimizers
...
fix_kto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c36a6fea6 | ||
|
|
64aca3c23c | ||
|
|
22abfd6170 | ||
|
|
0658c458b7 | ||
|
|
690908cf2f | ||
|
|
b9378e9b39 | ||
|
|
57b0ad1467 | ||
|
|
ec4ead6e3e | ||
|
|
a319ac7d3e | ||
|
|
09d3f2cffa | ||
|
|
85147ec430 | ||
|
|
51cd409488 | ||
|
|
7235123d44 | ||
|
|
4f5eb42a73 | ||
|
|
fbe54be6b8 | ||
|
|
04f6324833 | ||
|
|
f0072f3b9d | ||
|
|
59899b9817 | ||
|
|
4a736986fa | ||
|
|
5d0f110a3b | ||
|
|
83f8698b8a | ||
|
|
60a11a6410 | ||
|
|
46a045e528 | ||
|
|
3b477e08a0 | ||
|
|
16dc6ee68d | ||
|
|
fa7c79b3b9 | ||
|
|
ae66374156 | ||
|
|
5e21b1a9da | ||
|
|
575e5f28ec |
5
.github/workflows/main.yml
vendored
5
.github/workflows/main.yml
vendored
@@ -88,6 +88,11 @@ jobs:
|
|||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
5
.github/workflows/nightlies.yml
vendored
5
.github/workflows/nightlies.yml
vendored
@@ -80,6 +80,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ Features:
|
|||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
|
|
||||||
# Download example axolotl configs, deepspeed configs
|
# Download example axolotl configs, deepspeed configs
|
||||||
|
|||||||
@@ -32,8 +32,9 @@ website:
|
|||||||
contents:
|
contents:
|
||||||
- docs/getting-started.qmd
|
- docs/getting-started.qmd
|
||||||
- docs/installation.qmd
|
- docs/installation.qmd
|
||||||
- docs/cli.qmd
|
|
||||||
- docs/inference.qmd
|
- docs/inference.qmd
|
||||||
|
- docs/cli.qmd
|
||||||
|
- docs/config.qmd
|
||||||
|
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
@@ -74,10 +75,6 @@ website:
|
|||||||
- docs/debugging.qmd
|
- docs/debugging.qmd
|
||||||
- docs/nccl.qmd
|
- docs/nccl.qmd
|
||||||
|
|
||||||
- section: "Reference"
|
|
||||||
contents:
|
|
||||||
- docs/config.qmd
|
|
||||||
|
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
theme: darkly
|
theme: darkly
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
|
|||||||
|
|
||||||
RUN pip install jupyterlab notebook ipywidgets && \
|
RUN pip install jupyterlab notebook ipywidgets && \
|
||||||
jupyter lab clean
|
jupyter lab clean
|
||||||
RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||||
mkdir -p ~/.ssh && \
|
mkdir -p ~/.ssh && \
|
||||||
chmod 700 ~/.ssh && \
|
chmod 700 ~/.ssh && \
|
||||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: Config options
|
title: Config Reference
|
||||||
description: A complete list of all configuration options.
|
description: A complete list of all configuration options.
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -30,6 +30,8 @@ tokenizer_legacy:
|
|||||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||||
# This is reported to improve training speed on some models
|
# This is reported to improve training speed on some models
|
||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
|
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||||
|
shrink_embeddings:
|
||||||
|
|
||||||
# (Internal use only)
|
# (Internal use only)
|
||||||
# Used to identify which the model is based on
|
# Used to identify which the model is based on
|
||||||
@@ -154,8 +156,6 @@ datasets:
|
|||||||
content: value
|
content: value
|
||||||
# ...
|
# ...
|
||||||
|
|
||||||
message_property_mappings:
|
|
||||||
|
|
||||||
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
||||||
roles:
|
roles:
|
||||||
user: ["human", "user"]
|
user: ["human", "user"]
|
||||||
@@ -207,10 +207,46 @@ test_datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- /workspace/data/eval.jsonl
|
- /workspace/data/eval.jsonl
|
||||||
|
|
||||||
# use RL training: 'dpo', 'ipo', 'kto'
|
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
|
||||||
rl:
|
rl:
|
||||||
# whether to perform weighting if doing DPO training. Boolean.
|
rl_beta: # Optional[float]. The beta parameter for the RL training.
|
||||||
dpo_use_weighting:
|
|
||||||
|
# dpo
|
||||||
|
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
|
||||||
|
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
|
||||||
|
|
||||||
|
# orpo
|
||||||
|
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
|
||||||
|
|
||||||
|
# kto
|
||||||
|
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
|
||||||
|
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
|
||||||
|
|
||||||
|
# simpo
|
||||||
|
cpo_alpha: 1.0 # Weight of the BC regularizer
|
||||||
|
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
|
||||||
|
|
||||||
|
# grpo
|
||||||
|
trl:
|
||||||
|
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
|
||||||
|
vllm_device: # Optional[str]. Device to use for VLLM.
|
||||||
|
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
|
||||||
|
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
|
||||||
|
vllm_dtype: # Optional[str]. Data type for VLLM.
|
||||||
|
|
||||||
|
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
|
||||||
|
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
|
||||||
|
|
||||||
|
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
|
||||||
|
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
|
||||||
|
|
||||||
|
num_generations: # Optional[int]. Number of generations to sample.
|
||||||
|
log_completions: # Optional[bool]. Whether to log completions.
|
||||||
|
|
||||||
|
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
|
||||||
|
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
|
||||||
|
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
|
||||||
|
|
||||||
|
|
||||||
# reward modelling: `True` or `False`
|
# reward modelling: `True` or `False`
|
||||||
reward_model:
|
reward_model:
|
||||||
@@ -234,7 +270,7 @@ default_system_message: You are a helpful assistant. Please give a long and deta
|
|||||||
# subsequent training attempts load faster, relative path
|
# subsequent training attempts load faster, relative path
|
||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
# Push prepared dataset to hub
|
# Push prepared dataset to hub
|
||||||
push_dataset_to_hub: # repo path
|
push_dataset_to_hub: # Optional[str] repo_org/repo_name
|
||||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||||
# if not set.
|
# if not set.
|
||||||
dataset_processes: # defaults to os.cpu_count() if not set
|
dataset_processes: # defaults to os.cpu_count() if not set
|
||||||
@@ -556,6 +592,13 @@ special_tokens:
|
|||||||
# Add extra tokens.
|
# Add extra tokens.
|
||||||
tokens:
|
tokens:
|
||||||
|
|
||||||
|
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
||||||
|
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
|
||||||
|
# Can be checked if they exist in tokenizer.json added_tokens.
|
||||||
|
added_tokens_overrides: # Dict[int, str]
|
||||||
|
# 128041: "<|im_start|>"
|
||||||
|
# 128042: "<|im_end|>"
|
||||||
|
|
||||||
# FSDP
|
# FSDP
|
||||||
fsdp:
|
fsdp:
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
|
|||||||
@@ -74,6 +74,10 @@ datasets:
|
|||||||
train_on_eos:
|
train_on_eos:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
|
||||||
|
:::
|
||||||
|
|
||||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
14
docs/faq.qmd
14
docs/faq.qmd
@@ -27,6 +27,16 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
||||||
|
|
||||||
|
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
|
||||||
|
|
||||||
|
> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
|
||||||
|
|
||||||
|
> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
|
||||||
|
|
||||||
|
**Q: How to call Axolotl via custom python scripts?**
|
||||||
|
|
||||||
|
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
||||||
|
|
||||||
### Chat templates
|
### Chat templates
|
||||||
|
|
||||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||||
@@ -52,3 +62,7 @@ description: Frequently asked questions
|
|||||||
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
||||||
|
|
||||||
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
||||||
|
|
||||||
|
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
|
||||||
|
|
||||||
|
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
|
||||||
|
|||||||
@@ -36,7 +36,9 @@ The YAML configuration file controls everything about your training. Here's what
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: NousResearch/Llama-3.2-1B
|
base_model: NousResearch/Llama-3.2-1B
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
load_in_8bit: true
|
||||||
|
adapter: lora
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
@@ -44,11 +46,15 @@ datasets:
|
|||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
|
||||||
|
|
||||||
|
- To perform Full finetuning, remove these two lines.
|
||||||
|
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
|
||||||
|
:::
|
||||||
|
|
||||||
See our [Config options](config.qmd) for more details.
|
See our [Config options](config.qmd) for more details.
|
||||||
|
|
||||||
### Training {#sec-training}
|
### Training {#sec-training}
|
||||||
@@ -56,7 +62,7 @@ See our [Config options](config.qmd) for more details.
|
|||||||
When you run `axolotl train`, Axolotl:
|
When you run `axolotl train`, Axolotl:
|
||||||
|
|
||||||
1. Downloads the base model
|
1. Downloads the base model
|
||||||
2. (If specified) applies LoRA adapter layers
|
2. (If specified) applies QLoRA/LoRA adapter layers
|
||||||
3. Loads and processes the dataset
|
3. Loads and processes the dataset
|
||||||
4. Runs the training loop
|
4. Runs the training loop
|
||||||
5. Saves the trained model and / or LoRA weights
|
5. Saves the trained model and / or LoRA weights
|
||||||
@@ -69,6 +75,8 @@ Let's modify the example for your own data:
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: NousResearch/Nous-Hermes-llama-1b-v1
|
base_model: NousResearch/Nous-Hermes-llama-1b-v1
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
adapter: lora
|
adapter: lora
|
||||||
|
|
||||||
# Training settings
|
# Training settings
|
||||||
@@ -104,8 +112,6 @@ format):
|
|||||||
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
|
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
|
||||||
```
|
```
|
||||||
|
|
||||||
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
|
|
||||||
|
|
||||||
3. Run the training:
|
3. Run the training:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: "Inference"
|
title: "Inference and Merging"
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
toc: true
|
toc: true
|
||||||
@@ -9,10 +9,14 @@ execute:
|
|||||||
enabled: false
|
enabled: false
|
||||||
---
|
---
|
||||||
|
|
||||||
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
|
This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.
|
||||||
|
|
||||||
## Quick Start {#sec-quickstart}
|
## Quick Start {#sec-quickstart}
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Use the same config used for training on inference/merging.
|
||||||
|
:::
|
||||||
|
|
||||||
### Basic Inference {#sec-basic}
|
### Basic Inference {#sec-basic}
|
||||||
|
|
||||||
::: {.panel-tabset}
|
::: {.panel-tabset}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
|||||||
### PyPI Installation (Recommended) {#sec-pypi}
|
### PyPI Installation (Recommended) {#sec-pypi}
|
||||||
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -37,7 +38,7 @@ For the latest features between releases:
|
|||||||
```{.bash}
|
```{.bash}
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
pip3 install packaging ninja
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -107,7 +108,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
|||||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
||||||
3. Install Axolotl:
|
3. Install Axolotl:
|
||||||
```{.bash}
|
```{.bash}
|
||||||
pip3 install packaging
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
4. (Optional) Login to Hugging Face:
|
4. (Optional) Login to Hugging Face:
|
||||||
|
|||||||
@@ -66,6 +66,10 @@ logic to be compatible with more of them.
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
|
||||||
|
:::
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
These optimizations can be enabled in your Axolotl config YAML file. The
|
These optimizations can be enabled in your Axolotl config YAML file. The
|
||||||
|
|||||||
@@ -28,8 +28,23 @@ val_set_size: 0.1
|
|||||||
eval_steps: 100
|
eval_steps: 100
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Bradley-Terry chat templates expect single-turn conversations in the following format:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"system": "...", // optional
|
||||||
|
"input": "...",
|
||||||
|
"chosen": "...",
|
||||||
|
"rejected": "..."
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
### Process Reward Models (PRM)
|
### Process Reward Models (PRM)
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
|
||||||
|
:::
|
||||||
|
|
||||||
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
||||||
```yaml
|
```yaml
|
||||||
base_model: Qwen/Qwen2.5-3B
|
base_model: Qwen/Qwen2.5-3B
|
||||||
@@ -45,3 +60,5 @@ datasets:
|
|||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
eval_steps: 100
|
eval_steps: 100
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ title: "RLHF (Beta)"
|
|||||||
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
||||||
back-to-top-navigation: true
|
back-to-top-navigation: true
|
||||||
toc: true
|
toc: true
|
||||||
|
toc-expand: 2
|
||||||
toc-depth: 4
|
toc-depth: 4
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -297,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
|
|
||||||
### IPO
|
### IPO
|
||||||
|
|
||||||
As IPO is just DPO with a different loss function, all supported options for DPO works here.
|
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
rl: ipo
|
rl: ipo
|
||||||
@@ -343,8 +344,9 @@ ORPO supports the following types with the following dataset format:
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
rl: kto
|
rl: kto
|
||||||
rl_beta: 0.5
|
rl_beta: 0.1 # default
|
||||||
kto_desirable_weight: 0.2
|
kto_desirable_weight: 1.0 # default
|
||||||
|
kto_undesirable_weight: 1.0 # default
|
||||||
|
|
||||||
remove_unused_columns: false
|
remove_unused_columns: false
|
||||||
|
|
||||||
@@ -496,6 +498,10 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
|
|
||||||
### GRPO
|
### GRPO
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
|
||||||
|
:::
|
||||||
|
|
||||||
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
||||||
|
|
||||||
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
||||||
@@ -528,6 +534,7 @@ trl:
|
|||||||
vllm_gpu_memory_utilization: 0.15
|
vllm_gpu_memory_utilization: 0.15
|
||||||
num_generations: 4
|
num_generations: 4
|
||||||
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
||||||
|
reward_weights: [1.0]
|
||||||
datasets:
|
datasets:
|
||||||
- path: openai/gsm8k
|
- path: openai/gsm8k
|
||||||
name: main
|
name: main
|
||||||
@@ -536,6 +543,21 @@ datasets:
|
|||||||
|
|
||||||
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
||||||
|
|
||||||
|
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
||||||
|
|
||||||
|
### SimPO
|
||||||
|
|
||||||
|
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
rl: simpo
|
||||||
|
rl_beta: 0.1 # default in CPOTrainer
|
||||||
|
cpo_alpha: 1.0 # default in CPOTrainer
|
||||||
|
simpo_gamma: 0.5 # default in CPOTrainer
|
||||||
|
```
|
||||||
|
|
||||||
|
This method uses the same dataset format as [DPO](#dpo).
|
||||||
|
|
||||||
### Using local dataset files
|
### Using local dataset files
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ tf32: true
|
|||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: true
|
use_reentrant: false
|
||||||
early_stopping_patience:
|
early_stopping_patience:
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|||||||
@@ -62,4 +62,5 @@ antlr4-python3-runtime==4.13.2
|
|||||||
torchao==0.7.0
|
torchao==0.7.0
|
||||||
schedulefree==1.3.0
|
schedulefree==1.3.0
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.3
|
axolotl-contribs-lgpl==0.0.6
|
||||||
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|||||||
@@ -24,5 +24,5 @@ if cce_spec:
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class ModalCloud(Cloud):
|
|||||||
[
|
[
|
||||||
# Random id for cache busting of branch commits
|
# Random id for cache busting of branch commits
|
||||||
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
||||||
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
|
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -270,6 +270,7 @@ def _preprocess(config_yaml: str, volumes=None):
|
|||||||
|
|
||||||
|
|
||||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||||
|
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||||
f_out.write(config_yaml)
|
f_out.write(config_yaml)
|
||||||
run_folder = "/workspace/mounts"
|
run_folder = "/workspace/mounts"
|
||||||
@@ -288,6 +289,7 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def _lm_eval(config_yaml: str, volumes=None):
|
def _lm_eval(config_yaml: str, volumes=None):
|
||||||
|
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||||
f_out.write(config_yaml)
|
f_out.write(config_yaml)
|
||||||
run_folder = "/workspace/mounts"
|
run_folder = "/workspace/mounts"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""CLI to run training on a model."""
|
"""CLI to run training on a model."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -34,18 +35,20 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
|
check_user_token()
|
||||||
|
|
||||||
if cfg.rl:
|
if cfg.rl:
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|
||||||
del model
|
del model
|
||||||
del tokenizer
|
del tokenizer
|
||||||
|
del trainer
|
||||||
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
|
|||||||
process_or_cpu_count: int = (
|
process_or_cpu_count: int = (
|
||||||
process_count or os.cpu_count() # type: ignore[assignment]
|
process_count or os.cpu_count() # type: ignore[assignment]
|
||||||
)
|
)
|
||||||
num_proc = min(64, process_or_cpu_count)
|
num_proc = min(32, process_or_cpu_count)
|
||||||
features = data.features.keys()
|
features = data.features.keys()
|
||||||
tokenized_data = data.map(
|
tokenized_data = data.map(
|
||||||
map_fn,
|
map_fn,
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from transformers import (
|
|||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
|
from transformers.training_args import OptimizerNames
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.base import (
|
||||||
@@ -84,6 +85,7 @@ from axolotl.utils.collators import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
|
||||||
from axolotl.utils.models import ensure_dtype
|
from axolotl.utils.models import ensure_dtype
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -549,28 +551,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||||
else:
|
else:
|
||||||
training_arguments_kwargs["run_name"] = None
|
training_arguments_kwargs["run_name"] = None
|
||||||
training_arguments_kwargs["optim"] = (
|
|
||||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
|
||||||
)
|
|
||||||
if self.cfg.optim_args:
|
|
||||||
if isinstance(self.cfg.optim_args, dict):
|
|
||||||
optim_args = ",".join(
|
|
||||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
optim_args = self.cfg.optim_args
|
|
||||||
training_arguments_kwargs["optim_args"] = optim_args
|
|
||||||
if self.cfg.optim_target_modules:
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"optim_target_modules"
|
|
||||||
] = self.cfg.optim_target_modules
|
|
||||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"loraplus_lr_embedding"
|
|
||||||
] = self.cfg.loraplus_lr_embedding
|
|
||||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
|
||||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
|
||||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
|
||||||
|
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
||||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||||
@@ -656,46 +636,114 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# Handle custom optimizer
|
||||||
if self.cfg.optimizer in [
|
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
|
||||||
"optimi_adamw",
|
if self.cfg.optimizer in custom_supported_optimizers:
|
||||||
"ao_adamw_4bit",
|
# Common optimizer kwargs
|
||||||
"ao_adamw_8bit",
|
optimizer_kwargs = {
|
||||||
"ao_adamw_fp8",
|
"lr": training_arguments_kwargs.get("learning_rate"),
|
||||||
"adopt_adamw",
|
"weight_decay": training_arguments_kwargs.get("weight_decay"),
|
||||||
]:
|
}
|
||||||
# Set default so transformers doesn't throw
|
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
|
||||||
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
|
|
||||||
|
|
||||||
if self.cfg.optimizer == "lion_pytorch":
|
# Adam-specific kwargs
|
||||||
from lion_pytorch import Lion
|
adam_kwargs = {}
|
||||||
|
if training_arguments_kwargs.get(
|
||||||
|
"adam_beta1"
|
||||||
|
) and training_arguments_kwargs.get("adam_beta2"):
|
||||||
|
adam_kwargs["betas"] = (
|
||||||
|
training_arguments_kwargs.get("adam_beta1"),
|
||||||
|
training_arguments_kwargs.get("adam_beta2"),
|
||||||
|
)
|
||||||
|
if training_arguments_kwargs.get("adam_epsilon"):
|
||||||
|
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
|
||||||
|
|
||||||
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
|
if self.cfg.optimizer == "muon":
|
||||||
if "weight_decay" in training_arguments_kwargs:
|
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||||
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
|
MuonOptimizerFactory,
|
||||||
|
|
||||||
if (
|
|
||||||
"adam_beta1" in training_arguments_kwargs
|
|
||||||
and "adam_beta2" in training_arguments_kwargs
|
|
||||||
):
|
|
||||||
lion_kwargs["betas"] = (
|
|
||||||
training_arguments_kwargs["adam_beta1"],
|
|
||||||
training_arguments_kwargs["adam_beta2"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer_kwargs["optimizers"] = (
|
optimizer_cls = MuonOptimizerFactory
|
||||||
Lion(params=self.model.parameters(), **lion_kwargs),
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
None,
|
elif self.cfg.optimizer == "optimi_adamw":
|
||||||
|
from optimi import AdamW
|
||||||
|
|
||||||
|
optimizer_kwargs["foreach"] = False
|
||||||
|
optimizer_cls = AdamW
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||||
|
# TODO remove 20250401
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||||
|
|
||||||
|
optimizer_cls = AdamW4bit
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
|
||||||
|
LOG.warning(
|
||||||
|
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||||
|
)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||||
|
|
||||||
|
optimizer_cls = AdamW8bit
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||||
|
|
||||||
|
optimizer_cls = AdamWFp8
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "adopt_adamw":
|
||||||
|
from axolotl.utils.optimizers.adopt import ADOPT
|
||||||
|
|
||||||
|
optimizer_cls = ADOPT
|
||||||
|
adam_kwargs["decouple"] = True
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
|
||||||
|
# Parse any additional optimizer args from config
|
||||||
|
if self.cfg.optim_args:
|
||||||
|
if isinstance(self.cfg.optim_args, dict):
|
||||||
|
optimizer_kwargs.update(self.cfg.optim_args)
|
||||||
|
else:
|
||||||
|
# Parse string format "key1=value1,key2=value2"
|
||||||
|
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
|
||||||
|
key, value = mapping.split("=")
|
||||||
|
optimizer_kwargs[key] = value
|
||||||
|
|
||||||
|
trainer_kwargs["optimizer_cls_and_kwargs"] = (
|
||||||
|
optimizer_cls,
|
||||||
|
optimizer_kwargs,
|
||||||
)
|
)
|
||||||
# Set default so transformers doesn't throw
|
else:
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
# Use transformers' optimizer
|
||||||
|
training_arguments_kwargs["optim"] = self.cfg.optimizer
|
||||||
|
|
||||||
|
# Parse any additional optimizer args from config
|
||||||
|
if self.cfg.optim_args:
|
||||||
|
if isinstance(self.cfg.optim_args, dict):
|
||||||
|
optim_args = ",".join(
|
||||||
|
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optim_args = self.cfg.optim_args
|
||||||
|
training_arguments_kwargs["optim_args"] = optim_args
|
||||||
|
|
||||||
if self.cfg.optimizer == "adamw_anyprecision":
|
if self.cfg.optimizer == "adamw_anyprecision":
|
||||||
if Path(self.cfg.torchdistx_path).exists():
|
if Path(self.cfg.torchdistx_path).exists():
|
||||||
sys.path.append(self.cfg.torchdistx_path)
|
sys.path.append(self.cfg.torchdistx_path)
|
||||||
importlib.import_module("torchdistx")
|
importlib.import_module("torchdistx")
|
||||||
|
|
||||||
|
if self.cfg.optim_target_modules:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"optim_target_modules"
|
||||||
|
] = self.cfg.optim_target_modules
|
||||||
|
|
||||||
|
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||||
|
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||||
|
|
||||||
|
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"loraplus_lr_embedding"
|
||||||
|
] = self.cfg.loraplus_lr_embedding
|
||||||
|
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||||
|
|
||||||
if self.cfg.accelerator_config:
|
if self.cfg.accelerator_config:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"accelerator_config"
|
"accelerator_config"
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from typing import Dict, Literal, Optional
|
|||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
@@ -22,6 +23,7 @@ from transformers.utils import is_sagemaker_mp_enabled
|
|||||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
@@ -166,47 +168,18 @@ class SchedulerMixin(Trainer):
|
|||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, Trainer):
|
class OptimizerMixin(Trainer):
|
||||||
"""
|
"""
|
||||||
Extend the base Trainer for axolotl helpers
|
Mixin class for shared handling of building custom optimizers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
|
||||||
|
|
||||||
def __init__(
|
def create_optimizer_grouped_parameters(
|
||||||
self,
|
self, opt_model, optimizer_kwargs
|
||||||
*_args,
|
) -> list[dict]:
|
||||||
bench_data_collator=None,
|
|
||||||
eval_data_collator=None,
|
|
||||||
dataset_tags=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.bench_data_collator = bench_data_collator
|
|
||||||
self.eval_data_collator = eval_data_collator
|
|
||||||
self.dataset_tags = dataset_tags
|
|
||||||
self._signature_columns = None # workaround for pylint
|
|
||||||
super().__init__(*_args, **kwargs)
|
|
||||||
self.train_data_collator = self.data_collator
|
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
|
||||||
if self.args.orpo_alpha:
|
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
|
||||||
if self.args.torch_compile:
|
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
|
||||||
256
|
|
||||||
)
|
|
||||||
model = torch.compile(
|
|
||||||
model,
|
|
||||||
backend=self.args.torch_compile_backend,
|
|
||||||
mode=self.args.torch_compile_mode,
|
|
||||||
)
|
|
||||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
|
||||||
|
|
||||||
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
|
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
params = {
|
params: dict = {
|
||||||
"to_weight_decay": {}, # LayerNorm and bias
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
"embeddings": {}, # lm_head, embed_tokens,
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
"no_weight_decay": {},
|
"no_weight_decay": {},
|
||||||
@@ -293,23 +266,30 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
and self.args.embedding_lr_scale is None
|
and self.args.embedding_lr_scale is None
|
||||||
and self.args.embedding_lr is None
|
and self.args.embedding_lr is None
|
||||||
and self.args.lr_groups is None
|
and self.args.lr_groups is None
|
||||||
and self.args.alternate_optimizer
|
and self.optimizer_cls_and_kwargs is None
|
||||||
not in [
|
|
||||||
"optimi_adamw",
|
|
||||||
"ao_adamw_8bit",
|
|
||||||
"ao_adamw_4bit",
|
|
||||||
"ao_adamw_fp8",
|
|
||||||
"adopt_adamw",
|
|
||||||
]
|
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
if (
|
||||||
self.args,
|
not self.optimizer
|
||||||
opt_model,
|
and self.optimizer_cls_and_kwargs is not None
|
||||||
|
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||||
|
):
|
||||||
|
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
self.optimizer = optimizer_factory_cls()(
|
||||||
|
opt_model, self.args, **optimizer_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.optimizer:
|
||||||
|
if self.optimizer_cls_and_kwargs is not None:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
else:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args, opt_model
|
||||||
|
)
|
||||||
|
|
||||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||||
opt_model, optimizer_kwargs
|
opt_model, optimizer_kwargs
|
||||||
)
|
)
|
||||||
@@ -326,50 +306,47 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
**optimizer_kwargs,
|
**optimizer_kwargs,
|
||||||
)
|
)
|
||||||
elif (
|
else:
|
||||||
self.args.embedding_lr_scale is not None
|
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
or self.args.embedding_lr is not None
|
# e.g. for GaLore optimizer.
|
||||||
or self.args.lr_groups is not None
|
if "params" in optimizer_kwargs:
|
||||||
):
|
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
|
||||||
from optimi import AdamW
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
AdamW(
|
# e.g. for LOMO optimizer.
|
||||||
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
|
if "model" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||||
|
|
||||||
|
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||||
|
# to avoid arguments conflicts.
|
||||||
|
if "optimizer_dict" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||||
|
"optimizer_dict"
|
||||||
)
|
)
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "ao_adamw_4bit":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = optimizer_cls(
|
||||||
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
optimizer_grouped_parameters, **optimizer_kwargs
|
||||||
)
|
)
|
||||||
elif self.args.alternate_optimizer == "ao_adamw_8bit":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
if optimizer_cls.__name__ == "Adam8bit":
|
||||||
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
import bitsandbytes
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "ao_adamw_fp8":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "adopt_adamw":
|
|
||||||
from axolotl.utils.optimizers.adopt import ADOPT
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
skipped = 0
|
||||||
ADOPT(
|
for module in opt_model.modules():
|
||||||
optimizer_grouped_parameters,
|
if isinstance(module, nn.Embedding):
|
||||||
decouple=True,
|
skipped += sum(
|
||||||
**optimizer_kwargs,
|
{
|
||||||
)
|
p.data_ptr(): p.numel() for p in module.parameters()
|
||||||
)
|
}.values()
|
||||||
|
)
|
||||||
|
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||||
|
manager.register_module_override(
|
||||||
|
module, "weight", {"optim_bits": 32}
|
||||||
|
)
|
||||||
|
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
|
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
@@ -378,6 +355,45 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
|
|
||||||
return self.optimizer
|
return self.optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||||
|
"""
|
||||||
|
Extend the base Trainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
tag_names = ["axolotl"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*_args,
|
||||||
|
bench_data_collator=None,
|
||||||
|
eval_data_collator=None,
|
||||||
|
dataset_tags=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.bench_data_collator = bench_data_collator
|
||||||
|
self.eval_data_collator = eval_data_collator
|
||||||
|
self.dataset_tags = dataset_tags
|
||||||
|
self._signature_columns = None # workaround for pylint
|
||||||
|
super().__init__(*_args, **kwargs)
|
||||||
|
self.train_data_collator = self.data_collator
|
||||||
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
|
if self.args.orpo_alpha:
|
||||||
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
|
if self.args.torch_compile:
|
||||||
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
|
256
|
||||||
|
)
|
||||||
|
model = torch.compile(
|
||||||
|
model,
|
||||||
|
backend=self.args.torch_compile_backend,
|
||||||
|
mode=self.args.torch_compile_mode,
|
||||||
|
)
|
||||||
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
if self.args.multipack_real_batches:
|
if self.args.multipack_real_batches:
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import importlib
|
|||||||
import logging
|
import logging
|
||||||
from typing import OrderedDict
|
from typing import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin:
|
class BasePlugin:
|
||||||
"""
|
"""
|
||||||
@@ -469,3 +471,14 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
plugin.post_train_unload(cfg)
|
plugin.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOptimizerFactory:
|
||||||
|
"""
|
||||||
|
Base class for factories to create custom optimizers
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, opt_model, training_args, **optimizer_kwargs
|
||||||
|
) -> "torch.optim.Optimizer":
|
||||||
|
pass
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ Run the following command to install `cut_cross_entropy[transformers]` if you do
|
|||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|
||||||
# if you are not in dev environment
|
# if you are not in dev environment
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install cut_cross_entropy with transformers support using "
|
"Please install cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
|
|||||||
"""
|
"""
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
|
||||||
class SpectrumArgs(BaseModel):
|
class SpectrumArgs(BaseModel):
|
||||||
@@ -27,3 +27,20 @@ class SpectrumArgs(BaseModel):
|
|||||||
|
|
||||||
spectrum_top_fraction: Optional[float] = 0.5
|
spectrum_top_fraction: Optional[float] = 0.5
|
||||||
spectrum_model_name: Optional[str] = None
|
spectrum_model_name: Optional[str] = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_fsdp_use_orig_params(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("fsdp")
|
||||||
|
and data.get("fsdp_config")
|
||||||
|
and not data["fsdp_config"].get("use_orig_params")
|
||||||
|
and data.get("plugins")
|
||||||
|
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
|
||||||
|
):
|
||||||
|
# would otherwise raise
|
||||||
|
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
|
||||||
|
raise ValueError(
|
||||||
|
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import weakref
|
import weakref
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
|||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
@@ -382,21 +382,23 @@ def handle_untrained_tokens_fix(
|
|||||||
if not cfg.fix_untrained_tokens:
|
if not cfg.fix_untrained_tokens:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
is_ds_zero3: bool = False
|
||||||
|
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
|
||||||
|
is_ds_zero3 = True
|
||||||
|
|
||||||
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||||
sig = inspect.signature(fix_untrained_tokens)
|
sig = inspect.signature(fix_untrained_tokens)
|
||||||
|
|
||||||
|
fix_kwargs: Dict[str, Any] = {}
|
||||||
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||||
cfg.fix_untrained_tokens, list
|
cfg.fix_untrained_tokens, list
|
||||||
):
|
):
|
||||||
fix_untrained_tokens(
|
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
|
||||||
model,
|
if "is_ds_zero3" in sig.parameters:
|
||||||
tokenizer,
|
fix_kwargs["is_ds_zero3"] = is_ds_zero3
|
||||||
train_dataset,
|
|
||||||
token_ids_to_fix=cfg.fix_untrained_tokens,
|
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
||||||
)
|
|
||||||
else:
|
|
||||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
model.save_pretrained(
|
model.save_pretrained(
|
||||||
@@ -461,7 +463,7 @@ def setup_model_and_trainer(
|
|||||||
|
|
||||||
def train(
|
def train(
|
||||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
|
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
|
||||||
"""
|
"""
|
||||||
Train a model on the given dataset.
|
Train a model on the given dataset.
|
||||||
|
|
||||||
@@ -510,4 +512,4 @@ def train(
|
|||||||
# Create model card
|
# Create model card
|
||||||
create_model_card(cfg, trainer)
|
create_model_card(cfg, trainer)
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer, trainer
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module with Pydantic models for configuration."""
|
"""Module with Pydantic models for configuration."""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -64,6 +65,17 @@ class ChatTemplate(str, Enum):
|
|||||||
metharme = "metharme" # pylint: disable=invalid-name
|
metharme = "metharme" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class CustomSupportedOptimizers(str, Enum):
|
||||||
|
"""Custom supported optimizers"""
|
||||||
|
|
||||||
|
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||||
|
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||||
|
muon = "muon" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedParameters(BaseModel):
|
class DeprecatedParameters(BaseModel):
|
||||||
"""configurations that are deprecated"""
|
"""configurations that are deprecated"""
|
||||||
|
|
||||||
@@ -494,17 +506,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
embedding_lr_scale: Optional[float] = None
|
embedding_lr_scale: Optional[float] = None
|
||||||
weight_decay: Optional[float] = 0.0
|
weight_decay: Optional[float] = 0.0
|
||||||
optimizer: Optional[
|
optimizer: Optional[
|
||||||
Union[
|
Union[OptimizerNames, CustomSupportedOptimizers]
|
||||||
OptimizerNames,
|
|
||||||
Literal[
|
|
||||||
"lion_pytorch",
|
|
||||||
"optimi_adamw",
|
|
||||||
"ao_adamw_4bit",
|
|
||||||
"ao_adamw_8bit",
|
|
||||||
"ao_adamw_fp8",
|
|
||||||
"adopt_adamw",
|
|
||||||
],
|
|
||||||
]
|
|
||||||
] = OptimizerNames.ADAMW_HF
|
] = OptimizerNames.ADAMW_HF
|
||||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -727,7 +729,7 @@ class AxolotlInputConfig(
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||||
)
|
)
|
||||||
dataset_processes: Optional[int] = Field(default=os.cpu_count())
|
dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
|
||||||
dataset_exact_deduplication: Optional[bool] = None
|
dataset_exact_deduplication: Optional[bool] = None
|
||||||
dataset_keep_in_memory: Optional[bool] = None
|
dataset_keep_in_memory: Optional[bool] = None
|
||||||
dataloader_pin_memory: Optional[bool] = None
|
dataloader_pin_memory: Optional[bool] = None
|
||||||
@@ -778,9 +780,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
# torch_dtype: Optional[torch.dtype]
|
# torch_dtype: Optional[torch.dtype]
|
||||||
|
|
||||||
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
|
gradient_checkpointing: Optional[
|
||||||
default=False
|
Union[Literal["unsloth", "offload"], bool]
|
||||||
)
|
] = Field(default=False)
|
||||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
unfrozen_parameters: Optional[List[str]] = None
|
unfrozen_parameters: Optional[List[str]] = None
|
||||||
@@ -855,6 +857,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
special_tokens: Optional[SpecialTokensConfig] = None
|
special_tokens: Optional[SpecialTokensConfig] = None
|
||||||
tokens: Optional[List[str]] = None
|
tokens: Optional[List[str]] = None
|
||||||
|
added_tokens_overrides: Optional[Dict[int, str]] = None
|
||||||
|
|
||||||
torch_compile: Optional[Union[Literal["auto"], bool]] = None
|
torch_compile: Optional[Union[Literal["auto"], bool]] = None
|
||||||
torch_compile_backend: Optional[str] = None
|
torch_compile_backend: Optional[str] = None
|
||||||
@@ -1153,6 +1156,15 @@ class AxolotlInputConfig(
|
|||||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_offload_grad_checkpointing(self):
|
||||||
|
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
|
||||||
|
LOG.warning(
|
||||||
|
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
|
||||||
|
)
|
||||||
|
self.gradient_checkpointing = "offload"
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_better_transformers(self):
|
def check_better_transformers(self):
|
||||||
if self.flash_optimum is True:
|
if self.flash_optimum is True:
|
||||||
@@ -1177,6 +1189,13 @@ class AxolotlInputConfig(
|
|||||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_lr_groups(cls, data):
|
||||||
|
if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
|
||||||
|
raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_saves(cls, data):
|
def check_saves(cls, data):
|
||||||
@@ -1660,6 +1679,30 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_rl_config_gradient_checkpointing(cls, data):
|
||||||
|
# TODO: SalmanMohammadi
|
||||||
|
# Distributed RL with QLoRA + gradient checkpointing
|
||||||
|
# and use_reentrant = True is broken upstream in TRL
|
||||||
|
# pylint: disable=too-many-boolean-expressions
|
||||||
|
if (
|
||||||
|
data.get("rl")
|
||||||
|
and data.get("gradient_checkpointing")
|
||||||
|
and data.get("gradient_checkpointing_kwargs")
|
||||||
|
and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
|
||||||
|
and data.get("load_in_4bit")
|
||||||
|
and data.get("adapter") == "qlora"
|
||||||
|
and data.get("capabilities")
|
||||||
|
and data.get("capabilities").get("n_gpu", 1) > 1
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"The `use_reentrant: True` implementation of gradient checkpointing "
|
||||||
|
"is not supported for distributed RL training with QLoRA. Please set "
|
||||||
|
"`use_reentrant: False` in `gradient_checkpointing_kwargs`."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_kto_config(cls, data):
|
def check_kto_config(cls, data):
|
||||||
@@ -1670,15 +1713,6 @@ class AxolotlInputConfig(
|
|||||||
if data.get("remove_unused_columns") is not False:
|
if data.get("remove_unused_columns") is not False:
|
||||||
raise ValueError("Set `remove_unused_columns: False` when using kto")
|
raise ValueError("Set `remove_unused_columns: False` when using kto")
|
||||||
|
|
||||||
if data.get("gradient_checkpointing") and not (
|
|
||||||
data.get("gradient_checkpointing_kwargs")
|
|
||||||
and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
|
|
||||||
and data["gradient_checkpointing_kwargs"].get("use_reentrant")
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@@ -1809,6 +1843,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
data["torch_compile"] = False
|
data["torch_compile"] = False
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_beta_and_trl_beta_match(cls, data):
|
||||||
|
if data.get("beta") and data.get("trl", {}).get("beta"):
|
||||||
|
if data["beta"] != data["trl"]["beta"]:
|
||||||
|
raise ValueError("beta and trl.beta must match or one must be removed")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
GRPO specific configuration args
|
GRPO specific configuration args
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -11,7 +12,10 @@ class TRLConfig(BaseModel):
|
|||||||
Input args for TRL.
|
Input args for TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
beta: Optional[float] = None
|
beta: Optional[float] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Beta for RL training"},
|
||||||
|
)
|
||||||
max_completion_length: Optional[int] = Field(
|
max_completion_length: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -20,17 +24,68 @@ class TRLConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# GRPO specific args
|
# GRPO specific args
|
||||||
use_vllm: Optional[bool] = False
|
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||||
vllm_device: Optional[str] = "auto"
|
use_vllm: Optional[bool] = Field(
|
||||||
vllm_gpu_memory_utilization: Optional[float] = 0.9
|
default=False,
|
||||||
vllm_max_model_len: Optional[int] = None
|
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||||
vllm_dtype: Optional[str] = "auto"
|
)
|
||||||
|
vllm_device: Optional[str] = Field(
|
||||||
|
default="auto",
|
||||||
|
json_schema_extra={"description": "Device to use for VLLM"},
|
||||||
|
)
|
||||||
|
vllm_gpu_memory_utilization: Optional[float] = Field(
|
||||||
|
default=0.9,
|
||||||
|
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||||
|
)
|
||||||
|
vllm_dtype: Optional[str] = Field(
|
||||||
|
default="auto",
|
||||||
|
json_schema_extra={"description": "Data type for VLLM"},
|
||||||
|
)
|
||||||
|
vllm_max_model_len: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Maximum length of the model context for VLLM"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
reward_funcs: Optional[List[str]] = None
|
reward_funcs: Optional[list[str]] = Field(
|
||||||
reward_weights: Optional[List[float]] = None
|
default=None,
|
||||||
num_generations: Optional[int] = None
|
json_schema_extra={"description": "List of reward functions to load"},
|
||||||
log_completions: Optional[bool] = False
|
)
|
||||||
|
reward_weights: Optional[list[float]] = Field(
|
||||||
sync_ref_model: Optional[bool] = False
|
default=None,
|
||||||
ref_model_mixup_alpha: Optional[float] = 0.9
|
json_schema_extra={
|
||||||
ref_model_sync_steps: Optional[int] = 64
|
"description": "Weights for each reward function. Must match the number of reward functions."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
num_generations: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
log_completions: Optional[bool] = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={"description": "Whether to log completions"},
|
||||||
|
)
|
||||||
|
sync_ref_model: Optional[bool] = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": (
|
||||||
|
"Whether to sync the reference model every `ref_model_sync_steps` "
|
||||||
|
"steps, using the `ref_model_mixup_alpha` parameter."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ref_model_mixup_alpha: Optional[float] = Field(
|
||||||
|
default=0.9,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ref_model_sync_steps: Optional[int] = Field(
|
||||||
|
default=64,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ def is_main_process():
|
|||||||
|
|
||||||
|
|
||||||
def is_local_main_process():
|
def is_local_main_process():
|
||||||
return PartialState().is_main_process
|
return PartialState().is_local_main_process
|
||||||
|
|
||||||
|
|
||||||
def get_world_size():
|
def get_world_size():
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def hf_grad_checkpoint_unsloth_wrapper(
|
def hf_grad_checkpoint_offload_wrapper(
|
||||||
decoder_layer, *args, use_reentrant=None
|
decoder_layer, *args, use_reentrant=None
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from peft import (
|
|||||||
PeftModelForCausalLM,
|
PeftModelForCausalLM,
|
||||||
prepare_model_for_kbit_training,
|
prepare_model_for_kbit_training,
|
||||||
)
|
)
|
||||||
from peft.tuners.lora import QuantLinear
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import ( # noqa: F401
|
from transformers import ( # noqa: F401
|
||||||
AddedToken,
|
AddedToken,
|
||||||
@@ -57,8 +56,14 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
|||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
|
from axolotl.utils.distributed import (
|
||||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
|
barrier,
|
||||||
|
get_device_count,
|
||||||
|
get_device_type,
|
||||||
|
is_local_main_process,
|
||||||
|
zero_only,
|
||||||
|
)
|
||||||
|
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
@@ -165,7 +170,95 @@ def load_model_config(cfg):
|
|||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
|
||||||
|
def modify_tokenizer_files(
|
||||||
|
tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
|
||||||
|
|
||||||
|
This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer_path: Path or name of the original tokenizer
|
||||||
|
token_mappings: Dict mapping {token_id (int): new_token_string}
|
||||||
|
output_dir: Directory to save the modified tokenizer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the modified tokenizer directory
|
||||||
|
|
||||||
|
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Create the tokenizer directory in output_dir if it doesn't exist
|
||||||
|
tokenizer_dir = os.path.join(output_dir, "tokenizer")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
|
||||||
|
# Load the tokenizer
|
||||||
|
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
|
||||||
|
|
||||||
|
# Save the tokenizer to the output directory
|
||||||
|
temp_tokenizer.save_pretrained(tokenizer_dir)
|
||||||
|
|
||||||
|
# Get the token IDs and map them to their new values
|
||||||
|
token_id_mappings = {
|
||||||
|
int(token_id): new_value for token_id, new_value in token_mappings.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# 1. Update tokenizer_config.json - added_tokens_decoder
|
||||||
|
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
|
||||||
|
if os.path.exists(config_path):
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config_data = json.load(f)
|
||||||
|
|
||||||
|
# Update added_tokens_decoder
|
||||||
|
if "added_tokens_decoder" in config_data:
|
||||||
|
for token_id, new_value in token_id_mappings.items():
|
||||||
|
token_id_str = str(token_id)
|
||||||
|
if token_id_str in config_data["added_tokens_decoder"]:
|
||||||
|
config_data["added_tokens_decoder"][token_id_str][
|
||||||
|
"content"
|
||||||
|
] = new_value
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Token ID {token_id_str} not found in added_tokens_decoder"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Write the updated config back
|
||||||
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(config_data, f, indent=2)
|
||||||
|
|
||||||
|
# 2. Update tokenizer.json - added_tokens
|
||||||
|
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
|
||||||
|
if os.path.exists(tokenizer_path):
|
||||||
|
with open(tokenizer_path, "r", encoding="utf-8") as f:
|
||||||
|
tokenizer_data = json.load(f)
|
||||||
|
|
||||||
|
# Update added_tokens
|
||||||
|
if "added_tokens" in tokenizer_data:
|
||||||
|
for token_id, new_value in token_id_mappings.items():
|
||||||
|
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
|
||||||
|
if token_entry["id"] == token_id:
|
||||||
|
tokenizer_data["added_tokens"][i]["content"] = new_value
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
|
||||||
|
raise ValueError(
|
||||||
|
f"Token ID {token_id} not found in added_tokens"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Write the updated tokenizer data back
|
||||||
|
with open(tokenizer_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(tokenizer_data, f, indent=2)
|
||||||
|
|
||||||
|
barrier()
|
||||||
|
return tokenizer_dir
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
|
"""Load and configure the tokenizer based on the provided config."""
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
use_fast = True # this is the default
|
use_fast = True # this is the default
|
||||||
@@ -180,8 +273,18 @@ def load_tokenizer(cfg):
|
|||||||
if cfg.tokenizer_type:
|
if cfg.tokenizer_type:
|
||||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||||
|
|
||||||
|
# Set base tokenizer path
|
||||||
|
tokenizer_path = cfg.tokenizer_config
|
||||||
|
|
||||||
|
# Apply token string overrides if specified
|
||||||
|
if cfg.added_tokens_overrides:
|
||||||
|
# Modify tokenizer files and get path to modified tokenizer
|
||||||
|
tokenizer_path = modify_tokenizer_files(
|
||||||
|
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
|
||||||
|
)
|
||||||
|
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
cfg.tokenizer_config,
|
tokenizer_path,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
use_fast=use_fast,
|
use_fast=use_fast,
|
||||||
**tokenizer_kwargs,
|
**tokenizer_kwargs,
|
||||||
@@ -389,8 +492,8 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_fa_peft_integration()
|
patch_fa_peft_integration()
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing == "unsloth":
|
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
@@ -1256,7 +1359,7 @@ def load_llama_adapter(model, cfg):
|
|||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
h1 {
|
h1 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 6rem;
|
font-size: 5rem;
|
||||||
line-height: 1.1;
|
line-height: 1.1;
|
||||||
letter-spacing: -0.05em;
|
letter-spacing: -0.05em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("axolotl.cli.train.train") as mock_train:
|
with patch("axolotl.cli.train.train") as mock_train:
|
||||||
mock_train.return_value = (MagicMock(), MagicMock())
|
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
|
||||||
|
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
@@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
|
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
|
||||||
|
|
||||||
with patch("axolotl.cli.train.train") as mock_train:
|
with patch("axolotl.cli.train.train") as mock_train:
|
||||||
mock_train.return_value = (MagicMock(), MagicMock())
|
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
|
||||||
|
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
|
|||||||
@@ -69,6 +69,51 @@ class TestCutCrossEntropyIntegration:
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
def test_qwen2_w_cce(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||||
|
"plugins": [
|
||||||
|
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
|
||||||
|
],
|
||||||
|
"cut_cross_entropy": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"max_steps": 10,
|
||||||
|
"bf16": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
major, minor, _ = get_pytorch_version()
|
||||||
|
if (major, minor) < (2, 4):
|
||||||
|
with pytest.raises(ImportError):
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
else:
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"attention_type",
|
"attention_type",
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -750,3 +750,66 @@ class TestMultiGPULlama:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_fix_untrained_tokens(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"fix_untrained_tokens": True,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
"bos_token": "<|custom_im_start|>",
|
||||||
|
"eos_token": "<|custom_im_end|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"chat_template": "jinja",
|
||||||
|
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
|
||||||
|
"path": "mlabonne/FineTome-100k",
|
||||||
|
"type": "chat_template",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"bf16": True,
|
||||||
|
"save_safetensors": True,
|
||||||
|
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
|
||||||
|
)
|
||||||
|
|||||||
@@ -66,6 +66,54 @@ class TestLlama:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
def test_fix_untrained_tokens(self, temp_dir):
|
def test_fix_untrained_tokens(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"fix_untrained_tokens": True,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
"bos_token": "<|custom_im_start|>",
|
||||||
|
"eos_token": "<|custom_im_end|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"chat_template": "jinja",
|
||||||
|
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
|
||||||
|
"path": "mlabonne/FineTome-100k",
|
||||||
|
"type": "chat_template",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"bf16": True,
|
||||||
|
"save_safetensors": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
|
def test_fix_untrained_tokens_already_trained(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
@@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
@@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
|
|||||||
@@ -65,8 +65,9 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW"
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@require_torch_2_5_1
|
@require_torch_2_5_1
|
||||||
@@ -111,8 +112,57 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
@require_torch_2_5_1
|
||||||
|
def test_muon(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "muon",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"weight_decay": 0.01,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Test cases for the tokenizer loading
|
Test cases for the tokenizer loading
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -9,7 +10,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
class TestTokenizers(unittest.TestCase):
|
class TestTokenizers:
|
||||||
"""
|
"""
|
||||||
test class for the load_tokenizer fn
|
test class for the load_tokenizer fn
|
||||||
"""
|
"""
|
||||||
@@ -75,12 +76,48 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
|
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
|
||||||
self.assertEqual(len(tokenizer), 32001)
|
assert len(tokenizer) == 32001
|
||||||
|
|
||||||
# ensure reloading the tokenizer again from cfg results in same vocab length
|
# ensure reloading the tokenizer again from cfg results in same vocab length
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
self.assertEqual(len(tokenizer), 32001)
|
assert len(tokenizer) == 32001
|
||||||
|
|
||||||
|
def test_added_tokens_overrides(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
# use with tokenizer that has reserved_tokens in added_tokens
|
||||||
|
"tokenizer_config": "NousResearch/Llama-3.2-1B",
|
||||||
|
"added_tokens_overrides": {
|
||||||
|
128041: "RANDOM_OVERRIDE_1",
|
||||||
|
128042: "RANDOM_OVERRIDE_2",
|
||||||
|
},
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
|
||||||
|
128041
|
||||||
|
]
|
||||||
|
assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
|
||||||
|
128042
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
# use with tokenizer that has reserved_tokens in added_tokens
|
||||||
|
"tokenizer_config": "NousResearch/Llama-3.2-1B",
|
||||||
|
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
|
||||||
|
):
|
||||||
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user