Compare commits

..

11 Commits

Author SHA1 Message Date
Wing Lian
c578c8f256 Validation for Muon optimizer with DS/FSDP 2025-04-01 09:29:54 -04:00
NanoCode012
f4ae8816bb Fix: remove the numerous sequential log (#2461)
* fix: remove sequential logs

* feat(doc): add for sample pack sequentially and curriculum sampling
2025-04-01 09:20:00 -04:00
NanoCode012
9b95e06cbb Fix(doc): Minor doc changes for peft and modal (#2462) [skip ci]
* fix(doc): document peft configs

* fix(doc): explain modal env vs secrets difference

* fix(doc): clarify evaluate vs lm-eval

* fix: clarify what is performance
2025-04-01 08:48:36 -04:00
Wing Lian
e0aba74dd0 Release update 20250331 (#2460) [skip ci]
* make torch 2.6.0 the default image

* fix tests against upstream main

* fix attribute access

* use fixture dataset

* fix dataset load

* correct the fixtures + tests

* more fixtures

* add accidentally removed shakespeare fixture

* fix conversion from unittest to pytest class

* nightly main ci caches

* build 12.6.3 cuda base image

* override for fix from huggingface/transformers#37162

* address PR feedback
2025-04-01 08:47:50 -04:00
Wing Lian
328d598114 gemma3 packing fixes (#2449)
* make gemma3 work with packing

* multi-gpu e2e for ci

* update gemma3 model namespace to use mirror

* add gradient checkpointing to multigpu e2e ci

* update gemma3 examples for use_reentrant and fix ddp find unused params

* fix tests for gemma3

* fix import for test utils

* set correct train loss for gemma3 e2e
2025-03-31 17:15:23 -04:00
DreamGenX
4d36ecc724 Sequential sample packing (#2404) [skip ci]
* add sequential sample packing

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-03-31 15:48:20 -04:00
NanoCode012
7acf93b59f Fix(doc): Clarify doc on attention configs and missing pad_token (#2455) [skip ci]
* fix: clarify input type

* fix: handling of error message if data_files not available

* fix: clarify attention handling

* fix: add doc on missing pad token
2025-03-31 15:47:28 -04:00
Wing Lian
b6fc46ada8 Updates for trl 0.16.0 - mostly for GRPO (#2437) [skip ci]
* add grpo scale_rewards config for trl#3135

* options to connect to vllm server directly w grpo trl#3094

* temperature support trl#3029

* sampling/generation kwargs for grpo trl#2989

* make vllm_enable_prefix_caching a config param trl#2900

* grpo multi-step optimizeations trl#2899

* remove overrides for grpo trainer

* bump trl to 0.16.0

* add cli  to start vllm-serve via trl

* call the python module directly

* update to use vllm with 2.6.0 too now and call trl vllm serve from module

* vllm 0.8.1

* use python3

* use sys.executable

* remove context and wait for start

* fixes to make it actually work

* fixes so the grpo tests pass with new vllm paradigm

* explicit host/port and check in start vllm

* make sure that vllm doesn't hang by setting quiet so outouts go to dev null

* also bump bnb to latest release

* add option for wait from cli and nccl debugging for ci

* grpo + vllm test on separate devices for now

* make sure grpo + vllm tests runs single worker since pynccl comms would conflict

* fix cli

* remove wait and add caching for argilla dataset

* refactoring configs

* chore: lint

* add vllm config

* fixup vllm grpo args

* fix one more incorrect schema/config path

* fix another vlllm reference and increase timeout

* make the tests run a bit faster

* change mbsz back so it is correct for grpo

* another change mbsz back so it is correct for grpo

* fixing cli args

* nits

* adding docs

* docs

* include tensor parallel size for vllm in pydantic schema

* moving start_vllm, more docs

* limit output len for grpo vllm

* vllm enable_prefix_caching isn't a bool cli arg

* fix env ordering in tests and also use pid check when looking for vllm

---------

Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-03-31 15:47:11 -04:00
Dan Saunders
b35992262e Ray train bugfix (#2458)
* fix nccl pg destroy warning

* update

* ray bugfix
2025-03-31 15:17:43 -04:00
Dan Saunders
ef6eb77cc8 destroy process group on Ctrl+C / training or eval run (#2457)
* fix nccl pg destroy warning

* update
2025-03-31 12:36:47 -04:00
Dan Saunders
5410195e0b Sequence parallelism quick follow-ups; remove ModelCallback (#2450)
* guard return if ring attn alrady registered

* add docs link, bits in multi-gpu docs, remove save model callback (subsumed by HF trainers)

* configurable heads_k_stride from ring-flash-attn hf adapter
2025-03-31 09:13:42 -04:00
60 changed files with 1703 additions and 695 deletions

View File

@@ -40,6 +40,12 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

View File

@@ -25,12 +25,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras: vllm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -87,12 +87,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -42,8 +42,7 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
# awaiting vllm#12721
axolotl_extras:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]

View File

@@ -33,6 +33,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -46,7 +55,7 @@ jobs:
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
pip3 install torch==${{ matrix.pytorch_version }}
- name: Update requirements.txt
run: |
@@ -58,8 +67,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
@@ -73,10 +81,15 @@ jobs:
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache
run: |

View File

@@ -96,6 +96,10 @@ jobs:
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
@@ -256,7 +260,7 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
axolotl_extras: vllm
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -40,6 +40,7 @@ quartodoc:
- cli.preprocess
- cli.sweeps
- cli.utils
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- title: Trainers
@@ -243,6 +244,7 @@ website:
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- section: "Troubleshooting"
contents:

View File

@@ -2,4 +2,5 @@
set -e
# only run one test at a time so as not to OOM the GPU
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v -n1 /workspace/axolotl/tests/e2e/multigpu/solo/

View File

@@ -170,7 +170,7 @@ axolotl merge-sharded-fsdp-weights config.yml
### evaluate
Evaluates a model's performance using metrics specified in the config.
Evaluates a model's performance (loss etc) on the train and eval datasets.
```bash
# Basic evaluation
@@ -197,6 +197,8 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
## Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
@@ -235,7 +237,7 @@ Create a cloud config YAML with your Modal settings:
```yaml
# cloud_config.yml
provider: modal
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu_count: 1 # Number of GPUs to use
timeout: 86400 # Maximum runtime in seconds (24 hours)
branch: main # Git branch to use (optional)
@@ -248,7 +250,7 @@ volumes: # Persistent storage volumes
- name: axolotl-artifacts
mount: /workspace/artifacts
env: # Environment variables
secrets: # Secrets to inject
- WANDB_API_KEY
- HF_TOKEN
```
@@ -274,15 +276,27 @@ axolotl lm-eval config.yml --cloud cloud_config.yml
### Cloud Configuration Options
```yaml
provider: # compute provider, currently only `modal` is supported
gpu: # GPU type to use
gpu_count: # Number of GPUs (default: 1)
memory: # RAM in GB (default: 128)
timeout: # Maximum runtime in seconds
provider: # compute provider, currently only `modal` is supported
gpu: # GPU type to use
gpu_count: # Number of GPUs (default: 1)
memory: # RAM in GB (default: 128)
timeout: # Maximum runtime in seconds
timeout_preprocess: # Preprocessing timeout
branch: # Git branch to use
docker_tag: # Custom Docker image tag
volumes: # List of persistent storage volumes
env: # Environment variables to pass
secrets: # Secrets to inject
branch: # Git branch to use
docker_tag: # Custom Docker image tag
volumes: # List of persistent storage volumes
# Environment variables to pass. Can be specified in two ways:
# 1. As a string: Will load the value from the host computer's environment variables
# 2. As a key-value pair: Will use the specified value directly
# Example:
# env:
# - CUSTOM_VAR # Loads from host's $CUSTOM_VAR
# - {CUSTOM_VAR: "value"} # Uses "value" directly
env:
# Secrets to inject. Same input format as `env` but for sensitive data.
secrets:
# - HF_TOKEN
# - WANDB_API_KEY
```

View File

@@ -238,10 +238,10 @@ simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo
trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_device: # Optional[str]. Device to use for VLLM.
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
vllm_dtype: # Optional[str]. Data type for VLLM.
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
@@ -320,9 +320,13 @@ total_num_tokens:
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially.
# whether to concatenate samples during pretraining
pretraining_sample_concatenation:
curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning
# Use batch flattening for speedups when not using sample_packing
batch_flattening:
@@ -354,7 +358,27 @@ lora_target_modules:
# - down_proj
# - up_proj
lora_target_linear: # If true, will target all linear modules
peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers
# List[int] | int. # The layer indices to transform, otherwise, apply to all layers
# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform
peft_layers_to_transform:
# Optional[bool]. Whether to use DoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora
peft_use_dora:
# Optional[bool]. Whether to use RSLoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora
peft_use_rslora:
# Optional[list[tuple[int, int]]]. List of layer indices to replicate.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora
peft_layer_replication:
# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
# How to initialize LoRA weights. Default to True which is MS original implementation.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization
peft_init_lora_weights:
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
@@ -587,26 +611,31 @@ max_grad_norm:
# currently only supported on Llama and Mistral
neftune_noise_alpha:
# Whether to bettertransformers
# Optional[bool]. Whether to bettertransformers
flash_optimum:
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
# Note: Only one of the following attention patches can be used at a time.
# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`.
# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention
flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation
# Optional[bool]. Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:
# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Resume from a specific checkpoint dir
# Optional[str]. Resume from a specific checkpoint dir
resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# Optional[bool]. If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# Be careful with this being turned on between different models.
auto_resume_from_checkpoints: false
@@ -658,6 +687,9 @@ ddp_broadcast_buffers:
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -35,12 +35,22 @@ description: Frequently asked questions
**Q: How to call Axolotl via custom python scripts?**
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
> A: Since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
**Q: ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as pad_token**
> A: This is because the tokenizer does not have a padding token. Please add a padding token to the tokenizer via:
> ```yaml
> special_tokens:
> # str. If you're not sure, set to same as `eos_token`.
> pad_token: "..."
> ```
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

View File

@@ -18,6 +18,7 @@ Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
@@ -66,6 +67,28 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Sequence parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```
See our [dedicated guide](sequence_parallelism.qmd) for more details.
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).

View File

@@ -502,9 +502,48 @@ The input format is a simple JSON input with customizable fields based on the ab
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
using 4 GPUs - 2 for training, and 2 for vLLM:
::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
:::
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
dtype: auto
# max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand
rl: grpo
trl:
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_server_timeout: 300
```
```bash
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm_serve grpo.yaml
```
Your `vLLM` instance will now attempt to spin up, and it's time to kick off training utilizing our remaining two GPUs. In another terminal, execute:
```bash
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```
#### Reward functions
GRPO uses custom reward functions and transformations. Please have them ready locally.
For ex, to load OpenAI's GSM8K and use a random reward for completions:
For example, to load OpenAI's GSM8K and use a random reward for completions:
```python
# rewards.py
@@ -530,8 +569,6 @@ trl:
beta: 0.001
max_completion_length: 256
use_vllm: True
vllm_device: auto
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]

View File

@@ -25,6 +25,8 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
@@ -58,11 +60,16 @@ To use sequence parallelism, you need:
## Example
```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2 # Split each sequence into 4 parts
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
...
```

View File

@@ -5,6 +5,9 @@ tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
load_in_8bit: false
load_in_4bit: true
strict: false
@@ -54,6 +57,8 @@ fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -7,6 +7,9 @@ skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
@@ -48,6 +51,8 @@ fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
local_rank:
logging_steps: 1
flash_attention: true

View File

@@ -0,0 +1,80 @@
base_model: meta-llama/Llama-3.2-1B
# optionally might have model_type or tokenizer_type
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/lora-out
test_value: true
sequence_len: 4096
sample_packing: true
sample_packing_sequentially: true
curriculum_sampling: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.3
bitsandbytes==0.45.4
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
@@ -12,12 +12,12 @@ liger-kernel==0.5.5
packaging==23.2
peft==0.15.0
transformers==4.50.0
transformers==4.50.3
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.5.0
deepspeed==0.16.4
trl==0.15.1
trl==0.16.0
optimum==1.16.2
hf_transfer

View File

@@ -10,7 +10,7 @@ from pathlib import Path
from setuptools import find_packages, setup
def parse_requirements():
def parse_requirements(extras_require_map):
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
@@ -67,6 +67,7 @@ def parse_requirements():
if (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2")
extras_require_map["vllm"] = ["vllm==0.8.1"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
@@ -86,7 +87,7 @@ def parse_requirements():
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
return _install_requires, _dependency_links, extras_require_map
def get_package_version():
@@ -103,7 +104,46 @@ def get_package_version():
return version_
install_requires, dependency_links = parse_requirements()
extras_require = {
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
"deepspeed": [
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"galore": [
"galore_torch",
],
"apollo": [
"apollo-torch",
],
"optimizers": [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
"ray": [
"ray[train]",
],
"vllm": [
"vllm==0.7.2",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require
)
setup(
version=get_package_version(),
@@ -116,40 +156,5 @@ setup(
"axolotl=axolotl.cli.main:main",
],
},
extras_require={
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
"deepspeed": [
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"galore": [
"galore_torch",
],
"apollo": [
"apollo-torch",
],
"optimizers": [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
"ray": [
"ray[train]",
],
"vllm": [
"vllm==0.7.2",
],
},
extras_require=extras_require_build,
)

View File

@@ -35,6 +35,55 @@ class TrainerCliArgs:
num_processes: Optional[int] = field(default=None)
@dataclass
class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: int = field(
default=1,
metadata={"help": "Number of tensor parallel workers to use."},
)
host: str = field(
default="0.0.0.0", # nosec B104
metadata={"help": "Host address to run the server on."},
)
port: int = field(
default=8000,
metadata={"help": "Port to run the server on."},
)
gpu_memory_utilization: Optional[float] = field(
default=None,
metadata={
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
"out-of-memory (OOM) errors during initialization."
},
)
dtype: Optional[str] = field(
default=None,
metadata={
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
max_model_len: Optional[int] = field(
default=None,
metadata={
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)
enable_prefix_caching: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
"hardware support this feature."
},
)
@dataclass
class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command."""

View File

@@ -14,7 +14,12 @@ import yaml
from dotenv import load_dotenv
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.args import (
EvaluateCliArgs,
PreprocessCliArgs,
TrainerCliArgs,
VllmServeCliArgs,
)
from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import (
add_options_from_config,
@@ -23,6 +28,7 @@ from axolotl.cli.utils import (
fetch_from_github,
filter_none_kwargs,
)
from axolotl.cli.vllm_serve import do_vllm_serve
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -316,6 +322,14 @@ def fetch(directory: str, dest: Optional[str]) -> None:
fetch_from_github(f"{directory}/", dest)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(VllmServeCliArgs)
@filter_none_kwargs
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)
cli.add_command(lm_eval)

View File

@@ -0,0 +1,55 @@
"""
CLI to start the vllm server for online RL
"""
from pathlib import Path
from typing import Union
from trl.scripts.vllm_serve import ScriptArguments
from trl.scripts.vllm_serve import main as vllm_serve_main
from axolotl.cli.config import load_cfg
def do_vllm_serve(
config: Union[Path, str],
cli_args: dict,
):
"""
Starts the VLLM server for serving LLM models used for online RL
Args
:param cfg: Parsed doct of the YAML config
:param cli_args: dict of additional command-line arguments of type VllmServeCliArgs
Returns:
process_id: the process id of the started VLLM server
"""
cfg = load_cfg(config)
model = cfg.base_model
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
host = cli_args.get("host") or cfg.vllm.host
port = cli_args.get("port") or cfg.vllm.port
gpu_memory_utilization = (
cli_args.get("gpu_memory_utilization") or cfg.vllm.gpu_memory_utilization
)
dtype = cli_args.get("dtype") or cfg.vllm.dtype
max_model_len = cli_args.get("max_model_len") or cfg.vllm.max_model_len
enable_prefix_caching = (
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
)
vllm_script_args = ScriptArguments(
model,
tensor_parallel_size=tensor_parallel_size,
host=host,
port=port,
gpu_memory_utilization=gpu_memory_utilization,
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
)
vllm_serve_main(vllm_script_args)

View File

@@ -69,7 +69,6 @@ from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
SaveModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
@@ -249,7 +248,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
callbacks.append(SaveModelCallback())
return callbacks
@@ -526,9 +524,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
# handle ddp
ddp_find_unused_parameters = None
if self.cfg.ddp:
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
training_arguments_kwargs["ddp_find_unused_parameters"] = (
False if self.cfg.ddp else None
ddp_find_unused_parameters
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
report_to = []
@@ -937,7 +941,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(SaveModelCallback())
return callbacks

View File

@@ -28,6 +28,7 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import (
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
@@ -40,7 +41,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -112,6 +115,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
)

View File

@@ -13,7 +13,7 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.mixins import SchedulerMixin
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""

View File

@@ -40,18 +40,15 @@ class GRPOStrategy:
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_device"] = (
trl.vllm_device if trl.vllm_device else "auto"
)
if trl.vllm_gpu_memory_utilization:
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
trl.vllm_gpu_memory_utilization
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_guided_decoding_regex:
grpo_args_kwargs["vllm_guided_decoding_regex"] = (
trl.vllm_guided_decoding_regex
)
if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
if trl.num_generations:
grpo_args_kwargs["num_generations"] = trl.num_generations
@@ -70,6 +67,25 @@ class GRPOStrategy:
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
if trl.scale_rewards is not None:
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
if trl.temperature is not None:
grpo_args_kwargs["temperature"] = trl.temperature
if trl.top_p is not None:
grpo_args_kwargs["top_p"] = trl.top_p
if trl.top_k is not None:
grpo_args_kwargs["top_k"] = trl.top_k
if trl.min_p is not None:
grpo_args_kwargs["min_p"] = trl.min_p
if trl.repetition_penalty is not None:
grpo_args_kwargs["repetition_penalty"] = trl.repetition_penalty
if trl.num_iterations is not None:
grpo_args_kwargs["num_iterations"] = trl.num_iterations
if trl.epsilon is not None:
grpo_args_kwargs["epsilon"] = trl.epsilon
return grpo_args_kwargs
@classmethod

View File

@@ -2,108 +2,68 @@
Axolotl GRPO trainer
"""
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel
from trl import GRPOConfig, GRPOTrainer
from trl.models import unwrap_model_for_generation
from contextlib import nullcontext
from axolotl.core.trainers.base import SchedulerMixin
from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
if is_deepspeed_available():
import deepspeed
# mypy: ignore-errors
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# pylint: disable=access-member-before-definition
# Enable gradient checkpointing if requested
if kwargs["args"].gradient_checkpointing:
# Ensure use_cache is disabled
if hasattr(self.model, "config"):
self.model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(self.model) and hasattr(
self.model.base_model, "gradient_checkpointing_enable"
):
self.model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
elif hasattr(self.model, "gradient_checkpointing_enable"):
self.model.gradient_checkpointing_enable()
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition
def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# pylint: disable=unused-argument,redefined-builtin
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs
or gradient_checkpointing_kwargs["use_reentrant"]
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
gather_if_zero3 = (
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
)
if use_reentrant:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
if is_peft_model(self.model):
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
# adapters in a sharded manner is not supported.
with gather_if_zero3(list(self.model.parameters())):
self.model.merge_adapter()
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
# Update vLLM weights while parameters are gathered
for name, param in self.model.named_parameters():
# When using PEFT, we need to recover the original parameter name and discard some parameters
name = (
name.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", "")
)
if self.model.prefix in name:
continue
# When module to save, remove its prefix and discard the original module
if "original_module" in name:
continue
name = name.replace("modules_to_save.default.", "")
model.get_input_embeddings().register_forward_hook(
make_inputs_require_grad
)
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
return model
# pylint: enable=unused-argument,redefined-builtin
# Unmerge adapters while parameters are still gathered
self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
else:
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
def _move_model_to_vllm(self):
with unwrap_model_for_generation(
self.model,
self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = (
unwrapped_model._orig_mod # pylint: disable=protected-access
)
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", ""): v
for k, v in state_dict.items()
}
# Remove values with adapter prefix (example: "_lora")
state_dict = {
k: v
for k, v in state_dict.items()
if unwrapped_model.prefix not in k
}
# When module to save, remove its prefix and discard the original module
state_dict = {
k.replace("modules_to_save.default.", ""): v
for k, v in state_dict.items()
if "original_module" not in k
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()

View File

@@ -4,5 +4,6 @@
# flake8: noqa
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -0,0 +1,67 @@
"""
Temporary fix/override for bug in resume from checkpoint
See https://github.com/huggingface/transformers/pull/37162
TODO: Remove when upstream added PR to release
"""
import logging
import os
import random
import numpy as np
import torch
from transformers import Trainer, is_torch_npu_available
from transformers.trainer import safe_globals
from transformers.trainer_pt_utils import set_rng_state_for_device
from transformers.training_args import ParallelMode
LOG = logging.getLogger(__name__)
class RngLoaderMixin(Trainer):
"""
mixin for method override to load RNG states from a checkpoint
"""
def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint`
if checkpoint is None:
return
if self.args.world_size > 1:
process_index = self.args.process_index
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
if not os.path.isfile(rng_file):
LOG.info(
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
)
return
else:
rng_file = os.path.join(checkpoint, "rng_state.pth")
if not os.path.isfile(rng_file):
LOG.info(
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
"fashion, reproducibility is not guaranteed."
)
return
# Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,
# which requires allowlisted classes when loading with weights_only=True.
with safe_globals():
checkpoint_rng_state = torch.load(rng_file) # nosec B614
random.setstate(checkpoint_rng_state["python"])
np.random.set_state(checkpoint_rng_state["numpy"])
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
if torch.cuda.is_available():
set_rng_state_for_device(
"CUDA", torch.cuda, checkpoint_rng_state, is_distributed
)
if is_torch_npu_available():
set_rng_state_for_device(
"NPU", torch.npu, checkpoint_rng_state, is_distributed
)

View File

@@ -13,6 +13,7 @@ from trl import (
RewardTrainer,
)
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
@@ -74,7 +75,7 @@ class TRLPPOTrainer(PPOTrainer):
)
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
@@ -154,7 +155,7 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
return loss, metrics
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
@@ -162,7 +163,7 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
@@ -244,7 +245,7 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
return loss, metrics
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
@@ -252,7 +253,7 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""

View File

@@ -34,6 +34,12 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},

View File

@@ -15,6 +15,7 @@ from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -159,4 +160,6 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
del model
del tokenizer
cleanup_distributed()
return all_metrics

View File

@@ -38,13 +38,19 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_degree: int):
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
"""
Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`.
"""
if get_ring_attn_group() is not None:
LOG.info("Ring attention already registered, exiting early...")
return
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
@@ -84,6 +90,11 @@ def register_ring_attn(sequence_parallel_degree: int):
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if heads_k_stride is None:
heads_k_stride = 1
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)

View File

@@ -1,113 +1,153 @@
"""
Hijack the LlamaAttention forward method to use xformers if available.
Updated for transformers v4.50.0.
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""
from typing import Optional
import logging
import warnings
from typing import Optional, Tuple
import torch
from torch import nn
from transformers.models.llama.modeling_llama import repeat_kv
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
try:
import xformers.ops
XFORMERS_AVAILABLE = True
except ImportError:
XFORMERS_AVAILABLE = False
def xformers_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs, # pylint: disable=unused-argument
):
"""
Implements xformers memory-efficient attention for LlamaAttention with support for GQA.
Args:
module: The LlamaAttention module
query: Query states of shape [batch, num_heads, seq_len, head_dim]
key: Key states of shape [batch, num_kv_heads, seq_len, head_dim]
value: Value states of shape [batch, num_kv_heads, seq_len, head_dim]
attention_mask: Attention mask
scaling: Scaling factor for attention scores
dropout: Dropout probability
Returns:
attn_output: Output of xformers memory-efficient attention
attn_weights: None
"""
# First, handle grouped-query attention (GQA)
# We need to repeat key and value states to match the number of query heads
num_key_value_groups = getattr(module, "num_key_value_groups", 1)
key = repeat_kv(key, num_key_value_groups)
value = repeat_kv(value, num_key_value_groups)
# xformers expects inputs in shape [batch, seq_len, num_heads, head_dim]
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Determine if we need a causal mask
is_causal = getattr(module, "is_causal", True)
# Set up the attention bias for xformers
if is_causal:
# Use xformers built-in causal mask
attn_bias = xformers.ops.LowerTriangularMask()
elif attention_mask is not None:
# For non-causal attention with a mask, we'd need to convert the mask
# This is a simplification - you might need to adapt based on your mask format
attn_bias = attention_mask
else:
# No mask needed
attn_bias = None
# Apply xformers memory-efficient attention
attn_output = xformers.ops.memory_efficient_attention(
query,
key,
value,
attn_bias=attn_bias,
p=dropout if module.training else 0.0,
scale=scaling,
)
# Reshape back to [batch, seq_len, hidden_size]
attn_output = attn_output.transpose(1, 2)
return attn_output, None # Return None for attn_weights to match interface
logging.error("xformers not found! Please install it before trying to use it.")
def hijack_llama_attention():
"""
Patch the LlamaAttention forward method to use xformers if available.
"""
if not XFORMERS_AVAILABLE:
raise ValueError(
"xformers not available. Please install it following axolotl's requirements."
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
cos, sin = self.rotary_emb(value_states)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
import transformers.models.llama.modeling_llama as llama_modeling
#
# xformers-attn start
#
# Add xformers to the available attention implementations
llama_modeling.ALL_ATTENTION_FUNCTIONS["xformers"] = xformers_attention_forward
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# Create a wrapper for the original LlamaAttention forward method
original_forward = llama_modeling.LlamaAttention.forward
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
def patched_forward(self, *args, **kwargs):
# Set the attention implementation to xformers
# pylint: disable=protected-access
self.config._attn_implementation = "xformers"
return original_forward(self, *args, **kwargs)
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
# Apply the patch
llama_modeling.LlamaAttention.forward = patched_forward
#
# xformers-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

View File

@@ -22,6 +22,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"phi3",
"gemma",
"gemma2",
"gemma3",
"gemma3_text",
"cohere",
"cohere2",

View File

@@ -27,6 +27,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -157,6 +158,8 @@ def setup_signal_handler(
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
cleanup_distributed()
sys.exit(0)
_model_weakref = weakref.ref(model)
@@ -478,7 +481,7 @@ def train(
Returns:
Tuple of (model, tokenizer) after training
"""
# Setup model, tokenizer, (causal or RLHF) trainer etc.
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
(
trainer,
model,
@@ -487,34 +490,26 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
# Save initial configs
# Additional setup
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
# Set up signal handler for graceful termination
setup_signal_handler(cfg, model, safe_serialization)
# Set up badges and config info for model card
setup_model_card(cfg)
# Execute the training
resume_from_checkpoint = determine_resume_checkpoint(cfg)
execute_training(cfg, trainer, resume_from_checkpoint)
# Save the trained model
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)
# Create model card
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()
return model, tokenizer, trainer

View File

@@ -816,27 +816,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
return control
class SaveModelCallback(TrainerCallback):
"""Callback to save model on train end"""
def on_step_end( # pylint: disable=unused-argument
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs
):
control.should_save = True
return control
class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache"""

View File

@@ -112,6 +112,7 @@ class DataCollatorForSeq2Seq:
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
has_attn_mask = "attention_mask" in features[0].keys()
labels = None
if return_tensors is None:
return_tensors = self.return_tensors
@@ -164,6 +165,8 @@ class DataCollatorForSeq2Seq:
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
if not has_attn_mask:
del features["attention_mask"]
# prepare decoder_input_ids
if (

View File

@@ -238,7 +238,8 @@ def load_dataset_w_config(
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
else:
elif config_dataset.data_files:
fp: str | list[str] | None = None
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,

View File

@@ -71,8 +71,8 @@ def barrier():
def is_main_process():
"""
Check if the current process is the main process.
If not in distributed mode, always return True.
Check if the current process is the main process. If not in distributed mode,
always return `True`.
"""
if not is_distributed():
return True
@@ -87,6 +87,18 @@ def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
def cleanup_distributed():
"""
Destroy process group if torch distributed is initialized. Called in training early
termination or when training successfully completes.
"""
# Ensure that all operations are completed before destroying the process group
torch.cuda.synchronize()
# Destroy the process group
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@contextmanager
def zero_only():
"""

View File

@@ -609,7 +609,10 @@ class ModelLoader:
# Initialize ring attn for sequence parallelism. This must be done after
# model init but before the first forward pass, since it modifies flash
# attn to use ring comm for SP training across multiple GPUs.
register_ring_attn(self.cfg.sequence_parallel_degree)
register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride,
)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):

View File

@@ -8,11 +8,13 @@ from typing import Any, Iterable, List, Union
import numba
import numpy as np
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
@numba.njit
@@ -103,6 +105,55 @@ def allocate(
return result, s, len(result) * c * n
@numba.njit
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
"""
Sequential allocator that preserves example order
Parameters:
- lengths: The lengths of all examples
- rank: The current rank (for distributed training)
- c: The capacity of each bin (maximum sequence length)
- n: Number of ranks
Returns:
- result: List of batches for the current rank
- total_used: Number of actual example tokens
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
"""
result = []
total_used = 0
# First, do sequential packing into bins
all_bins = []
current_bin = [0 for i in range(0)] # numba hint
remaining_capacity = c
for idx, size in enumerate(lengths):
if size <= remaining_capacity:
# Example fits in current bin
current_bin.append(idx)
remaining_capacity -= size
total_used += size
else:
# Example doesn't fit, start a new bin
if current_bin: # Add non-empty bin to all_bins
all_bins.append(current_bin)
current_bin = [idx]
remaining_capacity = c - size
total_used += size
# Add the last bin if not empty
if current_bin:
all_bins.append(current_bin)
# Assign bins to ranks - each rank gets every n-th bin
for bin_idx in range(rank, len(all_bins), n):
result.append(all_bins[bin_idx])
return result, total_used, len(all_bins) * c
class MultipackBatchSampler(BatchSampler):
"""Batch sampler class for multipack"""
@@ -115,6 +166,7 @@ class MultipackBatchSampler(BatchSampler):
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
sequential: bool = False,
**kwargs,
):
super().__init__(sampler, batch_size, drop_last)
@@ -122,6 +174,7 @@ class MultipackBatchSampler(BatchSampler):
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential
assert isinstance(self.lengths, np.ndarray)
@@ -136,6 +189,11 @@ class MultipackBatchSampler(BatchSampler):
# the minimum packed dataset length across all ranks determined by a gather/broadcast
self.len_across_ranks = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warn(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)
def set_epoch(self, epoch: int):
self.epoch = epoch
@@ -145,13 +203,21 @@ class MultipackBatchSampler(BatchSampler):
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
if self.sequential:
batches, total_used, total_slots = allocate_sequentially(
lengths=lengths,
rank=0,
c=self.batch_max_len,
n=1,
)
else:
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
batches = [
[

View File

@@ -46,6 +46,7 @@ from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig
LOG = logging.getLogger(__name__)
@@ -86,6 +87,9 @@ class AxolotlInputConfig(
trl: TRLConfig | None = Field(
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
)
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
)
reward_model: bool | None = None
process_reward_model: bool | None = None
num_labels: int | None = None
@@ -188,6 +192,7 @@ class AxolotlInputConfig(
sample_packing: bool | None = None
sample_packing_group_size: int | None = 100_000
sample_packing_bin_size: int | None = 200
sample_packing_sequentially: bool | None = None
eval_sample_packing: bool | None = None
pad_to_sequence_len: bool | None = None
curriculum_sampling: bool | None = None
@@ -248,6 +253,7 @@ class AxolotlInputConfig(
val_set_size: float | None = Field(default=0.0)
sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -1108,7 +1114,7 @@ class AxolotlInputConfig(
@field_validator("sequence_parallel_degree", mode="before")
@classmethod
def check_sequence_parallel_config(cls, value, info):
def check_sequence_parallel_degree(cls, value, info):
if not value:
value = 1
@@ -1129,6 +1135,17 @@ class AxolotlInputConfig(
return value
@model_validator(mode="before")
@classmethod
def check_muon_deepspeed_fsdp(cls, data):
if data.get("optimizer") == "muon" and (
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
):
raise ValueError(
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
@@ -1264,3 +1281,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data["beta"] != data["trl"]["beta"]:
raise ValueError("beta and trl.beta must match or one must be removed")
return data
@model_validator(mode="after")
def check_min_torch_version(self):
if self.env_capabilities and self.env_capabilities.torch_version:
torch_version = self.env_capabilities.torch_version
if version.parse(torch_version) < version.parse("2.5.1"):
LOG.warning(
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
)

View File

@@ -20,27 +20,30 @@ class TRLConfig(BaseModel):
)
# GRPO specific args
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
use_vllm: bool | None = Field(
# Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
use_vllm: bool = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
)
vllm_device: str | None = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
vllm_server_host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host of the vLLM server to connect to"},
)
vllm_gpu_memory_utilization: float | None = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
vllm_server_port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to connect to"},
)
vllm_dtype: str | None = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: int | None = Field(
vllm_server_timeout: int | None = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
"description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
"after the timeout, a `ConnectionError` is raised."
},
)
vllm_guided_decoding_regex: str | None = Field(
default=None,
json_schema_extra={
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
},
)
@@ -85,3 +88,48 @@ class TRLConfig(BaseModel):
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
)
scale_rewards: bool = Field(
default=True,
json_schema_extra={
"description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
},
)
temperature: float | None = Field(
default=None,
json_schema_extra={"description": "Sampling temperature for the GRPO policy."},
)
top_p: float | None = Field(
default=None,
json_schema_extra={
"description": "Top-p sampling probability for the generation policy."
},
)
top_k: int | None = Field(
default=None,
json_schema_extra={"description": "Top-k sampling for the generation policy."},
)
min_p: float | None = Field(
default=None,
json_schema_extra={
"description": "Minimum probability for the generation policy."
},
)
repetition_penalty: float | None = Field(
default=None,
json_schema_extra={
"description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
},
)
num_iterations: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
},
)
epsilon: float | None = Field(
default=None,
json_schema_extra={
"description": "Epsilon value for clipping in the GRPO algorithm."
},
)

View File

@@ -0,0 +1,38 @@
"""
Pydantic models for VLLM configuration, used primarily for RL training with TRL + grpo
"""
from pydantic import BaseModel, Field
class VllmConfig(BaseModel):
"""
Configuration for VLLM server
"""
device: str | None = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
tensor_parallel_size: int | None = Field(
default=None,
json_schema_extra={"description": "Tensor parallel size for VLLM"},
)
gpu_memory_utilization: float | None = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
dtype: str | None = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
max_model_len: int | None = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
enable_prefix_caching: bool | None = Field(
default=None,
json_schema_extra={"description": "Enable prefix caching for VLLM"},
)

View File

@@ -13,7 +13,7 @@ import torch
import torch.cuda
from accelerate.logging import get_logger
from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -235,7 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.model_config_type == "mamba":
if cfg.model_config_type in ["mamba", "gemma3"]:
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
@@ -456,13 +456,18 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
else:
sampler_batch_size = cfg.micro_batch_size
batch_max_len = cfg.sequence_len
if cfg.curriculum_sampling:
sampler = SequentialSampler(train_dataset)
else:
sampler = RandomSampler(train_dataset)
sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
sampler=sampler,
lengths=get_dataset_lengths(train_dataset),
batch_size=sampler_batch_size,
batch_max_len=batch_max_len,
group_size=cfg.sample_packing_group_size,
bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True,
)

View File

@@ -8,11 +8,13 @@ import shutil
import sys
import tempfile
import time
from pathlib import Path
import datasets
import pytest
import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download
from tokenizers import AddedToken
from transformers import AutoTokenizer
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
@@ -48,6 +50,14 @@ def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs)
@pytest.fixture(scope="session", autouse=True)
def download_ds_fixture_bundle():
ds_dir = snapshot_download_w_retry(
"axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
)
return Path(ds_dir)
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
# download the model
@@ -101,42 +111,50 @@ def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
@pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
def download_argilla_distilabel_intel_orca_dpo_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_fozzie_alpaca_dpo_dataset():
# download the dataset
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
)
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test",
repo_type="dataset",
revision="ea82cff",
)
# @pytest.fixture(scope="session", autouse=True)
# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# # download the dataset
# snapshot_download_w_retry(
# "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
# )
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
# @pytest.fixture(scope="session", autouse=True)
# def download_fozzie_alpaca_dpo_dataset():
# # download the dataset
# snapshot_download_w_retry(
# "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
# )
# snapshot_download_w_retry(
# "fozziethebeat/alpaca_messages_2k_dpo_test",
# repo_type="dataset",
# revision="ea82cff",
# )
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset(
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
)
# @pytest.fixture(scope="session")
# @disable_hf_offline
# def dataset_fozzie_alpaca_dpo_dataset(
# download_fozzie_alpaca_dpo_dataset,
# ): # pylint: disable=unused-argument,redefined-outer-name
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
#
#
# @pytest.fixture(scope="session")
# @disable_hf_offline
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
# download_fozzie_alpaca_dpo_dataset,
# ): # pylint: disable=unused-argument,redefined-outer-name
# return load_dataset(
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
# )
@pytest.fixture(scope="session", autouse=True)
@@ -263,7 +281,7 @@ def download_mlx_mistral_7b_model_fixture():
)
@pytest.fixture(scope="session", autouse=True)
@pytest.fixture
def download_llama2_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
@@ -273,7 +291,7 @@ def download_llama2_model_fixture():
)
@pytest.fixture(scope="session", autouse=True)
@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama(
download_huggyllama_model_fixture,
@@ -284,6 +302,57 @@ def tokenizer_huggyllama(
return tokenizer
@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama_w_special_tokens(
tokenizer_huggyllama,
): # pylint: disable=redefined-outer-name
tokenizer_huggyllama.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
return tokenizer_huggyllama
@pytest.fixture
@enable_hf_offline
def tokenizer_llama2_7b(
download_llama2_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
return tokenizer
@pytest.fixture
@enable_hf_offline
def tokenizer_mistral_7b_instruct(
download_mlx_mistral_7b_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
@pytest.fixture
def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
tokenizer_mistral_7b_instruct.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
tokenizer_mistral_7b_instruct.add_tokens(
[
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
]
)
return tokenizer_mistral_7b_instruct
@pytest.fixture
def temp_dir():
# Create a temporary directory
@@ -349,6 +418,60 @@ def cleanup_monkeypatches():
globals().pop(module_global, None)
@pytest.fixture
def dataset_winglian_tiny_shakespeare(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
return datasets.load_from_disk(ds_path)
@pytest.fixture
def dataset_tatsu_lab_alpaca(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_mhenrichsen_alpaca_2k_test(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = (
download_ds_fixture_bundle
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
)
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = (
download_ds_fixture_bundle
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
)
return datasets.load_from_disk(ds_path)["train"]
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,

View File

View File

@@ -0,0 +1,294 @@
"""
GRPO test suite
"""
import os
import random
import subprocess # nosec B404
import sys
import time
from pathlib import Path
import pytest
import requests
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_vllm
def start_vllm(
model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
) -> int:
"""
helper function to start the VLLM server in the background, mostly for testing purposes
"""
cmd = [sys.executable, "-m", "trl.scripts.vllm_serve", "--model", model]
if tensor_parallel_size := kwargs.get("tensor_parallel_size"):
cmd.extend(["--tensor-parallel-size", str(tensor_parallel_size)])
if host := kwargs.get("host"):
cmd.extend(["--host", host])
if port := kwargs.get("port"):
cmd.extend(["--port", str(port)])
if gpu_memory_utilization := kwargs.get("gpu_memory_utilization"):
cmd.extend(["--gpu-memory-utilization", str(gpu_memory_utilization)])
if dtype := kwargs.get("dtype"):
cmd.extend(["--dtype", dtype])
if max_model_len := kwargs.get("max_model_len"):
cmd.extend(["--max-model-len", str(max_model_len)])
if kwargs.get("enable_prefix_caching"):
cmd.extend(["--enable-prefix-caching", "True"])
# print out the command to be executed
print(" ".join(cmd))
# start `trl vllm-serve` command in the background and capture the process id
process = subprocess.Popen( # pylint: disable=consider-using-with
cmd,
env=env,
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
) # nosec B603
# print out the process id so the user can easily kill it later
print(f"VLLM server process started (PID: {process.pid})")
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds
started = False
if wait and host and port:
for _ in range(int(wait)):
try:
response = requests.get(f"http://{host}:{port}", timeout=1)
if int(response.status_code) in [200, 404]:
started = True
break
except requests.exceptions.RequestException:
pass
# also check if the process.pid is still running
if not process.poll() is None:
break
time.sleep(1)
if wait and not started:
print(
f"VLLM server process did not start within {wait} seconds. Please check your server logs."
)
process.kill()
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
# return the process id
return process.pid
class TestGRPO:
"""
Test case for GRPO training using multilpe GPUs
"""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_dora(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"peft_use_dora": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process_id = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=120,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
)
finally:
os.kill(vllm_process_id, 9)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_fft(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process_id = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=120,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
)
finally:
os.kill(vllm_process_id, 9)

View File

@@ -52,9 +52,9 @@ class TestMultiGPUEval:
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -121,9 +121,9 @@ class TestMultiGPUEval:
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",

View File

@@ -0,0 +1,100 @@
"""
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@pytest.fixture(scope="session", autouse=True)
def download_model():
# download the model
snapshot_download("axolotl-mirrors/gemma-3-4b-pt", repo_type="model")
class TestMultiGPUGemma3:
"""
Test case for Gemma3 models using LoRA
"""
def test_lora_ddp_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-mirrors/gemma-3-4b-pt",
"sequence_len": 2048,
"ddp_find_unused_parameters": True,
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"chat_template": "gemma3",
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {
"use_reentrant": False,
},
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
)

View File

@@ -1,175 +0,0 @@
"""
GRPO test suite
"""
import random
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_vllm
class TestGRPO:
"""
Test case for GRPO training using multilpe GPUs
"""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_dora(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
"vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"peft_use_dora": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_fft(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
"vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)

View File

@@ -58,6 +58,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -121,6 +122,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -193,6 +195,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
@@ -270,6 +273,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
@@ -330,6 +334,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": gradient_accumulation_steps,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -399,7 +404,8 @@ class TestMultiGPULlama:
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -478,7 +484,8 @@ class TestMultiGPULlama:
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -778,9 +785,10 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",

View File

@@ -46,7 +46,7 @@ class TestMultiGPUQwen2:
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"warmup_steps": 20,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,

View File

@@ -50,7 +50,7 @@ class TestMultiGPURay:
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",

View File

@@ -110,7 +110,7 @@ class TestRingAttention:
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4)
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2

View File

@@ -324,7 +324,7 @@ class TestDatasetPreparation:
@enable_hf_offline
def test_load_hub_with_revision_with_dpo(
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
):
"""Verify that processing dpo data from the hub works with a specific revision"""
@@ -339,12 +339,10 @@ class TestDatasetPreparation:
)
# pylint: disable=duplicate-code
with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset:
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
)
train_dataset, _ = load_prepare_preference_datasets(cfg)
@@ -354,7 +352,9 @@ class TestDatasetPreparation:
@enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub_with_revision(self, tokenizer):
def test_load_local_hub_with_revision(
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer
):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
@@ -386,13 +386,23 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
dataset, _ = load_tokenized_prepared_datasets(
tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
@enable_hf_offline
def test_loading_local_dataset_folder(self, tokenizer):

View File

@@ -238,21 +238,22 @@ class TestDeduplicateRLDataset:
@enable_hf_offline
def test_load_with_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
self,
cfg,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
tokenizer_huggyllama,
):
"""Verify that loading with deduplication removes duplicates."""
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
@@ -263,19 +264,20 @@ class TestDeduplicateRLDataset:
@enable_hf_offline
def test_load_without_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
self,
cfg,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
tokenizer_huggyllama,
):
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama

View File

@@ -1,7 +1,7 @@
"""Module for testing streaming dataset sequence packing"""
import pytest
from datasets import concatenate_datasets, load_dataset
from datasets import concatenate_datasets
from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer
@@ -27,7 +27,6 @@ class TestBatchedSamplerPacking:
Test class for packing streaming dataset sequences
"""
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
@pytest.mark.parametrize(
"batch_size, num_workers",
[
@@ -38,14 +37,20 @@ class TestBatchedSamplerPacking:
],
)
@pytest.mark.parametrize("max_seq_length", [4096, 512])
@pytest.mark.parametrize("sequential", [True, False])
@enable_hf_offline
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
def test_packing(
self,
dataset_winglian_tiny_shakespeare,
batch_size,
num_workers,
tokenizer,
max_seq_length,
sequential,
):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
dataset = load_dataset(
"winglian/tiny-shakespeare",
split="train",
)
dataset = dataset_winglian_tiny_shakespeare["train"]
cfg = DictDefault(
{
@@ -55,7 +60,7 @@ class TestBatchedSamplerPacking:
)
ds_cfg = DictDefault(
{
"field": "Text",
"field": "text",
}
)
completion_strategy = load(tokenizer, cfg, ds_cfg)
@@ -75,6 +80,7 @@ class TestBatchedSamplerPacking:
batch_max_len=max_seq_length,
group_size=100000,
bin_size=200,
sequential=sequential,
)
loader = DataLoader(

View File

@@ -2,13 +2,8 @@
import json
import logging
import unittest
from pathlib import Path
import pytest
from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy,
@@ -61,24 +56,13 @@ test_data = {
}
class TestPromptTokenizationStrategies(unittest.TestCase):
class TestPromptTokenizationStrategies:
"""
Test class for prompt tokenization strategies.
"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
def test_no_sys_prompt(self):
def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens):
"""
tests the interface between the user and assistant parts
"""
@@ -86,7 +70,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
# pylint: disable=duplicate-code
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
tokenizer_huggyllama_w_special_tokens,
False,
2048,
)
@@ -99,7 +83,8 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
assert example["labels"][world_idx] == 3186
assert example["labels"][world_idx - 1] == -100
def test_alpaca(self):
@enable_hf_offline
def test_alpaca(self, tokenizer_huggyllama_w_special_tokens):
"""
tests the interface between the user and assistant parts
"""
@@ -107,7 +92,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
prompter = AlpacaPrompter()
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
tokenizer_huggyllama_w_special_tokens,
False,
2048,
)
@@ -118,28 +103,17 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
assert example["labels"][world_idx - 1] == -100
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
class TestInstructionWSystemPromptTokenizingStrategy:
"""
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
def test_system_alpaca(self):
def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens):
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
strat = InstructionWSystemPromptTokenizingStrategy(
prompter,
self.tokenizer,
tokenizer_huggyllama_w_special_tokens,
False,
2048,
)
@@ -160,18 +134,13 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
assert example["input_ids"][8] == 11889 # USER
class Llama2ChatTokenizationTest(unittest.TestCase):
class Llama2ChatTokenizationTest:
"""
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
# woraround because official Meta repos are not open
def test_llama2_chat_integration(self):
def test_llama2_chat_integration(self, tokenizer_llama2_7b):
with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin:
@@ -186,16 +155,18 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
prompter = Llama2ChatPrompter()
strat = LLama2ChatTokenizingStrategy(
prompter,
self.tokenizer,
tokenizer_llama2_7b,
False,
4096,
)
example = strat.tokenize_prompt(conversation)
for fields in ["input_ids", "attention_mask", "labels"]:
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields])
# pytest assert equals
def compare_with_transformers_integration(self):
assert len(example[fields]) == len(tokenized_conversation[fields])
assert example[fields] == tokenized_conversation[fields]
def compare_with_transformers_integration(self, tokenizer_llama2_7b):
# this needs transformers >= v4.31.0
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
from transformers.pipelines.conversational import Conversation
@@ -234,49 +205,27 @@ If a question does not make any sense, or is not factually coherent, explain why
generated_responses=answers,
)
# pylint: disable=W0212
hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)
self.assertEqual(
hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
)
assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]
class OrpoTokenizationTest(unittest.TestCase):
class OrpoTokenizationTest:
"""test case for the ORPO tokenization"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
tokenizer = LlamaTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)
tokenizer.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
tokenizer.add_tokens(
[
AddedToken(
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
),
]
)
self.tokenizer = tokenizer
self.dataset = load_dataset(
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
).select([0])
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
def test_orpo_integration(self):
def test_orpo_integration(
self,
tokenizer_mistral_7b_instruct_chatml,
dataset_argilla_ultrafeedback_binarized_preferences_cleaned,
):
ds = dataset_argilla_ultrafeedback_binarized_preferences_cleaned.select([0])
strat = load(
self.tokenizer,
tokenizer_mistral_7b_instruct_chatml,
DictDefault({"train_on_inputs": False}),
DictDefault({"chat_template": "chatml"}),
)
res = strat.tokenize_prompt(self.dataset[0])
res = strat.tokenize_prompt(ds[0])
assert "rejected_input_ids" in res
assert "rejected_labels" in res
assert "input_ids" in res
@@ -295,7 +244,3 @@ class OrpoTokenizationTest(unittest.TestCase):
assert res["prompt_attention_mask"][0] == 1
assert res["prompt_attention_mask"][-1] == 0
if __name__ == "__main__":
unittest.main()

View File

@@ -321,3 +321,48 @@ class TestValidationCheckDatasetConfig(BaseValidation):
)
validate_config(cfg)
class TestOptimizerValidation(BaseValidation):
"""
Test muon optimizer validation
"""
def test_muon_deepspeed(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"optimizer": "muon",
"deepspeed": "deepspeed_configs/zero3.json",
}
)
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
validate_config(cfg)
def test_muon_fsdp(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"optimizer": "muon",
"fsdp": ["full_shard"],
"fsdp_config": {
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
}
)
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
validate_config(cfg)