Compare commits

..

3 Commits

Author SHA1 Message Date
Wing Lian
1a7f048c6b add SOAP optimizer 2025-03-31 08:33:19 -04:00
Wing Lian
76d26366ad upstream updates for momentum change 2025-03-31 08:33:19 -04:00
Wing Lian
64fe284765 add soap optimize 2025-03-31 08:33:19 -04:00
70 changed files with 1303 additions and 1951 deletions

View File

@@ -40,12 +40,6 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

View File

@@ -25,12 +25,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras: vllm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -87,12 +87,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -42,7 +42,8 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras: vllm
# awaiting vllm#12721
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]

View File

@@ -33,15 +33,6 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -55,7 +46,7 @@ jobs:
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
- name: Update requirements.txt
run: |
@@ -67,7 +58,8 @@ jobs:
- name: Install dependencies
run: |
pip3 show torch
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
@@ -81,15 +73,10 @@ jobs:
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
- name: cleanup pip cache
run: |

View File

@@ -96,10 +96,6 @@ jobs:
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
@@ -260,7 +256,7 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -40,7 +40,6 @@ quartodoc:
- cli.preprocess
- cli.sweeps
- cli.utils
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- title: Trainers
@@ -244,7 +243,6 @@ website:
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- section: "Troubleshooting"
contents:

View File

@@ -2,5 +2,4 @@
set -e
# only run one test at a time so as not to OOM the GPU
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/

View File

@@ -170,7 +170,7 @@ axolotl merge-sharded-fsdp-weights config.yml
### evaluate
Evaluates a model's performance (loss etc) on the train and eval datasets.
Evaluates a model's performance using metrics specified in the config.
```bash
# Basic evaluation
@@ -197,8 +197,6 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
## Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
@@ -237,7 +235,7 @@ Create a cloud config YAML with your Modal settings:
```yaml
# cloud_config.yml
provider: modal
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu_count: 1 # Number of GPUs to use
timeout: 86400 # Maximum runtime in seconds (24 hours)
branch: main # Git branch to use (optional)
@@ -250,7 +248,7 @@ volumes: # Persistent storage volumes
- name: axolotl-artifacts
mount: /workspace/artifacts
secrets: # Secrets to inject
env: # Environment variables
- WANDB_API_KEY
- HF_TOKEN
```
@@ -276,27 +274,15 @@ axolotl lm-eval config.yml --cloud cloud_config.yml
### Cloud Configuration Options
```yaml
provider: # compute provider, currently only `modal` is supported
gpu: # GPU type to use
gpu_count: # Number of GPUs (default: 1)
memory: # RAM in GB (default: 128)
timeout: # Maximum runtime in seconds
provider: # compute provider, currently only `modal` is supported
gpu: # GPU type to use
gpu_count: # Number of GPUs (default: 1)
memory: # RAM in GB (default: 128)
timeout: # Maximum runtime in seconds
timeout_preprocess: # Preprocessing timeout
branch: # Git branch to use
docker_tag: # Custom Docker image tag
volumes: # List of persistent storage volumes
# Environment variables to pass. Can be specified in two ways:
# 1. As a string: Will load the value from the host computer's environment variables
# 2. As a key-value pair: Will use the specified value directly
# Example:
# env:
# - CUSTOM_VAR # Loads from host's $CUSTOM_VAR
# - {CUSTOM_VAR: "value"} # Uses "value" directly
env:
# Secrets to inject. Same input format as `env` but for sensitive data.
secrets:
# - HF_TOKEN
# - WANDB_API_KEY
branch: # Git branch to use
docker_tag: # Custom Docker image tag
volumes: # List of persistent storage volumes
env: # Environment variables to pass
secrets: # Secrets to inject
```

View File

@@ -238,10 +238,10 @@ simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo
trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
vllm_device: # Optional[str]. Device to use for VLLM.
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
vllm_dtype: # Optional[str]. Data type for VLLM.
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
@@ -320,13 +320,9 @@ total_num_tokens:
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially.
# whether to concatenate samples during pretraining
pretraining_sample_concatenation:
curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning
# Use batch flattening for speedups when not using sample_packing
batch_flattening:
@@ -358,27 +354,7 @@ lora_target_modules:
# - down_proj
# - up_proj
lora_target_linear: # If true, will target all linear modules
# List[int] | int. # The layer indices to transform, otherwise, apply to all layers
# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform
peft_layers_to_transform:
# Optional[bool]. Whether to use DoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora
peft_use_dora:
# Optional[bool]. Whether to use RSLoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora
peft_use_rslora:
# Optional[list[tuple[int, int]]]. List of layer indices to replicate.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora
peft_layer_replication:
# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
# How to initialize LoRA weights. Default to True which is MS original implementation.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization
peft_init_lora_weights:
peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
@@ -611,31 +587,26 @@ max_grad_norm:
# currently only supported on Llama and Mistral
neftune_noise_alpha:
# Optional[bool]. Whether to bettertransformers
# Whether to bettertransformers
flash_optimum:
# Note: Only one of the following attention patches can be used at a time.
# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`.
# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation
# Optional[bool]. Whether to use scaled-dot-product attention
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:
# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Optional[str]. Resume from a specific checkpoint dir
# Resume from a specific checkpoint dir
resume_from_checkpoint:
# Optional[bool]. If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# Be careful with this being turned on between different models.
auto_resume_from_checkpoints: false
@@ -686,11 +657,7 @@ ddp_broadcast_buffers:
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1
sequence_parallel_degree:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -35,22 +35,12 @@ description: Frequently asked questions
**Q: How to call Axolotl via custom python scripts?**
> A: Since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
**Q: ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as pad_token**
> A: This is because the tokenizer does not have a padding token. Please add a padding token to the tokenizer via:
> ```yaml
> special_tokens:
> # str. If you're not sure, set to same as `eos_token`.
> pad_token: "..."
> ```
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

View File

@@ -18,7 +18,6 @@ Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
@@ -67,28 +66,6 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Sequence parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```
See our [dedicated guide](sequence_parallelism.qmd) for more details.
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).

View File

@@ -502,48 +502,9 @@ The input format is a simple JSON input with customizable fields based on the ab
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
using 4 GPUs - 2 for training, and 2 for vLLM:
::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
:::
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
dtype: auto
# max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand
rl: grpo
trl:
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_server_timeout: 300
```
```bash
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm_serve grpo.yaml
```
Your `vLLM` instance will now attempt to spin up, and it's time to kick off training utilizing our remaining two GPUs. In another terminal, execute:
```bash
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```
#### Reward functions
GRPO uses custom reward functions and transformations. Please have them ready locally.
For example, to load OpenAI's GSM8K and use a random reward for completions:
For ex, to load OpenAI's GSM8K and use a random reward for completions:
```python
# rewards.py
@@ -569,6 +530,8 @@ trl:
beta: 0.001
max_completion_length: 256
use_vllm: True
vllm_device: auto
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]

View File

@@ -23,11 +23,8 @@ Use sequence parallelism when:
To enable sequence parallelism, add the following to your configuration file:
```yaml
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
@@ -61,22 +58,16 @@ To use sequence parallelism, you need:
## Example
```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1
sequence_parallel_degree: 2 # Split each sequence into 4 parts
flash_attention: true # Required with sequence parallelism
...
```
This will train the Llama 3 8B model with 8192 context length, with each sequence split
into 4 subsequences of length 2048 across 4 GPUs.
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.
## Sample Packing with Sequence Parallelism
@@ -88,14 +79,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
## Effect on Batch Size
First, note that sequence parallelism supports only the case where `micro_batch_size: 1`.
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 1, the global batch size decreases from 8 to 2
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

View File

@@ -5,9 +5,6 @@ tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
load_in_8bit: false
load_in_4bit: true
strict: false
@@ -57,8 +54,6 @@ fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -7,9 +7,6 @@ skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
@@ -51,8 +48,6 @@ fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
local_rank:
logging_steps: 1
flash_attention: true

View File

@@ -82,6 +82,3 @@ deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -1,80 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
# optionally might have model_type or tokenizer_type
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/lora-out
test_value: true
sequence_len: 4096
sample_packing: true
sample_packing_sequentially: true
curriculum_sampling: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.4
bitsandbytes==0.45.3
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
@@ -12,12 +12,12 @@ liger-kernel==0.5.5
packaging==23.2
peft==0.15.0
transformers==4.50.3
transformers==4.50.0
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.5.0
deepspeed==0.16.4
trl==0.16.0
trl==0.15.1
optimum==1.16.2
hf_transfer

View File

@@ -10,7 +10,7 @@ from pathlib import Path
from setuptools import find_packages, setup
def parse_requirements(extras_require_map):
def parse_requirements():
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
@@ -67,7 +67,6 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2")
extras_require_map["vllm"] = ["vllm==0.8.1"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
@@ -87,7 +86,7 @@ def parse_requirements(extras_require_map):
except PackageNotFoundError:
pass
return _install_requires, _dependency_links, extras_require_map
return _install_requires, _dependency_links
def get_package_version():
@@ -104,50 +103,7 @@ def get_package_version():
return version_
extras_require = {
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": [
"flash-attn==2.7.4.post1",
"ring-flash-attn>=0.1.4",
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"galore": [
"galore_torch",
],
"apollo": [
"apollo-torch",
],
"optimizers": [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
"ray": [
"ray[train]",
],
"vllm": [
"vllm==0.7.2",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require
)
install_requires, dependency_links = parse_requirements()
setup(
version=get_package_version(),
@@ -160,5 +116,40 @@ setup(
"axolotl=axolotl.cli.main:main",
],
},
extras_require=extras_require_build,
extras_require={
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
"deepspeed": [
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"galore": [
"galore_torch",
],
"apollo": [
"apollo-torch",
],
"optimizers": [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
"ray": [
"ray[train]",
],
"vllm": [
"vllm==0.7.2",
],
},
)

View File

@@ -35,55 +35,6 @@ class TrainerCliArgs:
num_processes: Optional[int] = field(default=None)
@dataclass
class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: int = field(
default=1,
metadata={"help": "Number of tensor parallel workers to use."},
)
host: str = field(
default="0.0.0.0", # nosec B104
metadata={"help": "Host address to run the server on."},
)
port: int = field(
default=8000,
metadata={"help": "Port to run the server on."},
)
gpu_memory_utilization: Optional[float] = field(
default=None,
metadata={
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
"out-of-memory (OOM) errors during initialization."
},
)
dtype: Optional[str] = field(
default=None,
metadata={
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
max_model_len: Optional[int] = field(
default=None,
metadata={
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)
enable_prefix_caching: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
"hardware support this feature."
},
)
@dataclass
class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command."""

View File

@@ -14,12 +14,7 @@ import yaml
from dotenv import load_dotenv
import axolotl
from axolotl.cli.args import (
EvaluateCliArgs,
PreprocessCliArgs,
TrainerCliArgs,
VllmServeCliArgs,
)
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import (
add_options_from_config,
@@ -28,7 +23,6 @@ from axolotl.cli.utils import (
fetch_from_github,
filter_none_kwargs,
)
from axolotl.cli.vllm_serve import do_vllm_serve
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -322,14 +316,6 @@ def fetch(directory: str, dest: Optional[str]) -> None:
fetch_from_github(f"{directory}/", dest)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(VllmServeCliArgs)
@filter_none_kwargs
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)
cli.add_command(lm_eval)

View File

@@ -1,55 +0,0 @@
"""
CLI to start the vllm server for online RL
"""
from pathlib import Path
from typing import Union
from trl.scripts.vllm_serve import ScriptArguments
from trl.scripts.vllm_serve import main as vllm_serve_main
from axolotl.cli.config import load_cfg
def do_vllm_serve(
config: Union[Path, str],
cli_args: dict,
):
"""
Starts the VLLM server for serving LLM models used for online RL
Args
:param cfg: Parsed doct of the YAML config
:param cli_args: dict of additional command-line arguments of type VllmServeCliArgs
Returns:
process_id: the process id of the started VLLM server
"""
cfg = load_cfg(config)
model = cfg.base_model
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
host = cli_args.get("host") or cfg.vllm.host
port = cli_args.get("port") or cfg.vllm.port
gpu_memory_utilization = (
cli_args.get("gpu_memory_utilization") or cfg.vllm.gpu_memory_utilization
)
dtype = cli_args.get("dtype") or cfg.vllm.dtype
max_model_len = cli_args.get("max_model_len") or cfg.vllm.max_model_len
enable_prefix_caching = (
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
)
vllm_script_args = ScriptArguments(
model,
tensor_parallel_size=tensor_parallel_size,
host=host,
port=port,
gpu_memory_utilization=gpu_memory_utilization,
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
)
vllm_serve_main(vllm_script_args)

View File

@@ -69,6 +69,7 @@ from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
SaveModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
@@ -248,6 +249,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
callbacks.append(SaveModelCallback())
return callbacks
@@ -524,15 +526,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
# handle ddp
ddp_find_unused_parameters = None
if self.cfg.ddp:
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
training_arguments_kwargs["ddp_find_unused_parameters"] = (
ddp_find_unused_parameters
False if self.cfg.ddp else None
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
report_to = []
@@ -667,6 +663,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "soap":
from axolotl.utils.optimizers.soap import SOAP
optimizer_cls = SOAP
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
@@ -941,6 +942,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(SaveModelCallback())
return callbacks
@@ -1043,10 +1045,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
@@ -1165,7 +1163,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
@@ -1183,3 +1180,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer.add_callback(callback)
return dpo_trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):
# build PPOConfig
pass

View File

@@ -3,16 +3,16 @@
# pylint: disable=unused-import
# flake8: noqa
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
from axolotl.core.trainers.relora import ReLoRATrainer
from axolotl.core.trainers.trl import (
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
AxolotlPPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
TRLPPOTrainer,
)

View File

@@ -12,8 +12,8 @@ from typing import Any, Literal
import datasets
import torch
import torch.nn as nn
from datasets import Dataset
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
@@ -26,8 +26,11 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.handlers import SequenceParallelHandler
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.mixins import (
OptimizerMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -37,7 +40,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
class AxolotlTrainer(TrainerMixins, Trainer):
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -63,7 +66,9 @@ class AxolotlTrainer(TrainerMixins, Trainer):
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
self.sequence_parallel_handler = SequenceParallelHandler(self.args)
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
@@ -107,7 +112,6 @@ class AxolotlTrainer(TrainerMixins, Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
)
@@ -124,7 +128,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
# Determine the base sampler first
if self.args.sequence_parallel_degree > 1:
base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset)
base_sampler = self._sp_get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
@@ -160,7 +164,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
# Determine the base sampler
if self.args.sequence_parallel_degree > 1:
base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset)
base_sampler = self._sp_get_eval_sampler(eval_dataset)
elif use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
@@ -232,10 +236,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
return dataloader
# Otherwise prepare with accelerator
dataloader = self.accelerator.prepare_data_loader(dataloader)
return dataloader
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
@@ -344,57 +345,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping of inputs.
num_items_in_batch: The number of items in the batch.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""
Perform a prediction step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform prediction step for.
inputs: Dictionary mapping of inputs.
prediction_loss_only: Whether to return only the loss.
ignore_keys: Keys to ignore in the inputs.
Returns:
Tuple of (loss, logits, labels).
"""
# Set up sequence parallelism for this prediction step if enabled
if self.args.sequence_parallel_degree > 1:
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
# Proceed with normal prediction step
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
@override
def compute_loss(
@@ -638,3 +589,27 @@ class AxolotlTrainer(TrainerMixins, Trainer):
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
loss = super().training_step(model, inputs, num_items_in_batch)
return loss

View File

@@ -1,10 +1,14 @@
"""DPO Specific Strategy for training"""
"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
class DPOStrategy:
"""Strategy for DPO training"""
"""
Strategy for DPO training
"""
@classmethod
def get_trainer_class(cls):

View File

@@ -1,4 +1,6 @@
"""Axolotl specific DPO args"""
"""
Axolotl specific DPO args
"""
from dataclasses import dataclass
@@ -9,4 +11,6 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""DPO config for DPO training"""
"""
DPO config for DPO training
"""

View File

@@ -1,7 +1,10 @@
"""DPO trainer for axolotl"""
"""
DPO trainer for axolotl
"""
import gc
from functools import wraps
from typing import Any
from typing import Any, Dict, Union
import torch
from peft.optimizers import create_loraplus_optimizer
@@ -10,8 +13,7 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.handlers import SequenceParallelHandler
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.mixins import SchedulerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -21,18 +23,18 @@ if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
"""Extend the base DPOTrainer for axolotl helpers"""
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
self.sequence_parallel_handler = SequenceParallelHandler(args=self.args)
def create_optimizer(self):
# pylint: disable=duplicate-code
@@ -86,7 +88,7 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> dict:
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
@@ -115,9 +117,10 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any | None],
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
self.sequence_parallel_handler.prepare_for_training_step(self, inputs)
return super().training_step(model, inputs, num_items_in_batch)
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss

View File

@@ -40,15 +40,18 @@ class GRPOStrategy:
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_guided_decoding_regex:
grpo_args_kwargs["vllm_guided_decoding_regex"] = (
trl.vllm_guided_decoding_regex
grpo_args_kwargs["vllm_device"] = (
trl.vllm_device if trl.vllm_device else "auto"
)
if trl.vllm_gpu_memory_utilization:
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
trl.vllm_gpu_memory_utilization
)
if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
if trl.num_generations:
grpo_args_kwargs["num_generations"] = trl.num_generations
@@ -67,25 +70,6 @@ class GRPOStrategy:
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
if trl.scale_rewards is not None:
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
if trl.temperature is not None:
grpo_args_kwargs["temperature"] = trl.temperature
if trl.top_p is not None:
grpo_args_kwargs["top_p"] = trl.top_p
if trl.top_k is not None:
grpo_args_kwargs["top_k"] = trl.top_k
if trl.min_p is not None:
grpo_args_kwargs["min_p"] = trl.min_p
if trl.repetition_penalty is not None:
grpo_args_kwargs["repetition_penalty"] = trl.repetition_penalty
if trl.num_iterations is not None:
grpo_args_kwargs["num_iterations"] = trl.num_iterations
if trl.epsilon is not None:
grpo_args_kwargs["epsilon"] = trl.epsilon
return grpo_args_kwargs
@classmethod

View File

@@ -1,65 +1,109 @@
"""Axolotl GRPO trainer"""
"""
Axolotl GRPO trainer
"""
from contextlib import nullcontext
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel
from trl import GRPOConfig, GRPOTrainer
from trl.models import unwrap_model_for_generation
from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.mixins import TrainerMixins
if is_deepspeed_available():
import deepspeed
from axolotl.core.trainers.base import SchedulerMixin
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
# mypy: ignore-errors
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"]
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
gather_if_zero3 = (
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# pylint: disable=access-member-before-definition
# Enable gradient checkpointing if requested
if kwargs["args"].gradient_checkpointing:
# Ensure use_cache is disabled
if hasattr(self.model, "config"):
self.model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(self.model) and hasattr(
self.model.base_model, "gradient_checkpointing_enable"
):
self.model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
elif hasattr(self.model, "gradient_checkpointing_enable"):
self.model.gradient_checkpointing_enable()
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition
def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# pylint: disable=unused-argument,redefined-builtin
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs
or gradient_checkpointing_kwargs["use_reentrant"]
)
if is_peft_model(self.model):
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
# adapters in a sharded manner is not supported.
with gather_if_zero3(list(self.model.parameters())):
self.model.merge_adapter()
if use_reentrant:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
# Update vLLM weights while parameters are gathered
for name, param in self.model.named_parameters():
# When using PEFT, we need to recover the original parameter name and discard some parameters
name = (
name.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", "")
)
if self.model.prefix in name:
continue
# When module to save, remove its prefix and discard the original module
if "original_module" in name:
continue
name = name.replace("modules_to_save.default.", "")
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
model.get_input_embeddings().register_forward_hook(
make_inputs_require_grad
)
# Unmerge adapters while parameters are still gathered
self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
else:
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
return model
# pylint: enable=unused-argument,redefined-builtin
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
def _move_model_to_vllm(self):
with unwrap_model_for_generation(
self.model,
self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = (
unwrapped_model._orig_mod # pylint: disable=protected-access
)
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", ""): v
for k, v in state_dict.items()
}
# Remove values with adapter prefix (example: "_lora")
state_dict = {
k: v
for k, v in state_dict.items()
if unwrapped_model.prefix not in k
}
# When module to save, remove its prefix and discard the original module
state_dict = {
k.replace("modules_to_save.default.", ""): v
for k, v in state_dict.items()
if "original_module" not in k
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()

View File

@@ -1,3 +0,0 @@
"""Init for trainer handlers"""
from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler

View File

@@ -1,123 +0,0 @@
"""Handler class for sequence parallel trainer logic"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import DistributedSampler
class SequenceParallelHandler:
"""
Handler class that encapsulates sequence parallelism functionality.
This replaces the SequenceParallelMixin with a composition-based approach.
"""
def __init__(self, args=None):
"""
Initialize the sequence parallel handler.
Args:
args: The arguments object containing sequence parallelism settings.
"""
self.args = args
self.ring_attn_group = None
# Set up sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _setup_sequence_parallel(self):
"""Set up sequence parallelism environment."""
from ring_flash_attn import update_ring_flash_attn_params
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
self.update_ring_flash_attn_params = update_ring_flash_attn_params
self.ring_attn_group = get_ring_attn_group()
def create_sequence_parallel_sampler(
self,
dataset,
shuffle=True,
is_eval=False,
):
"""
Helper method to create sampler for sequence parallelism (SP).
Args:
dataset: Dataset to sample from.
shuffle: Whether to shuffle the dataset.
is_eval: Whether we are creating a sampler for evaluation or training.
Returns:
Distributed sampler.
"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _get_train_sampler(self, dataset):
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset.
Returns:
Configured sequence parallel sampler.
"""
return self.create_sequence_parallel_sampler(
dataset,
shuffle=not self.args.curriculum_sampling,
)
def _get_eval_sampler(self, eval_dataset):
"""
Get an evaluation sampler configured for sequence parallelism.
Args:
eval_dataset: The evaluation dataset.
Returns:
Configured sequence parallel sampler.
"""
return self.create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
def _update_ring_flash_attn_params(self, inputs):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn.
Args:
inputs: Current batch of inputs.
"""
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)

View File

@@ -3,12 +3,6 @@
# pylint: disable=unused-import
# flake8: noqa
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class TrainerMixins(
OptimizerMixin, RngLoaderMixin, SchedulerMixin
):
"""Stub class combining all mixins for Axolotl trainers."""
from .optimizer import OptimizerMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,65 +0,0 @@
"""
Temporary fix/override for bug in resume from checkpoint
See https://github.com/huggingface/transformers/pull/37162
TODO: Remove when upstream added PR to release
"""
import logging
import os
import random
import numpy as np
import torch
from transformers import Trainer, is_torch_npu_available
from transformers.trainer import safe_globals
from transformers.trainer_pt_utils import set_rng_state_for_device
from transformers.training_args import ParallelMode
LOG = logging.getLogger(__name__)
class RngLoaderMixin(Trainer):
"""Mixin for method override to load RNG states from a checkpoint"""
def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint`
if checkpoint is None:
return
if self.args.world_size > 1:
process_index = self.args.process_index
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
if not os.path.isfile(rng_file):
LOG.info(
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
)
return
else:
rng_file = os.path.join(checkpoint, "rng_state.pth")
if not os.path.isfile(rng_file):
LOG.info(
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
"fashion, reproducibility is not guaranteed."
)
return
# Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,
# which requires allowlisted classes when loading with weights_only=True.
with safe_globals():
checkpoint_rng_state = torch.load(rng_file) # nosec B614
random.setstate(checkpoint_rng_state["python"])
np.random.set_state(checkpoint_rng_state["numpy"])
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
if torch.cuda.is_available():
set_rng_state_for_device(
"CUDA", torch.cuda, checkpoint_rng_state, is_distributed
)
if is_torch_npu_available():
set_rng_state_for_device(
"NPU", torch.npu, checkpoint_rng_state, is_distributed
)

View File

@@ -1,5 +1,4 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
# TODO(Dan): remove
import logging
from typing import Any
@@ -71,12 +70,12 @@ class SequenceParallelMixin:
drop_last=not is_eval,
)
def _get_train_sampler(self, dataset) -> Sampler | None:
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset.
dataset: The training dataset
Returns:
Configured sequence parallel sampler.
@@ -86,7 +85,7 @@ class SequenceParallelMixin:
shuffle=not self.args.curriculum_sampling,
)
def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
"""
Get an evaluation sampler configured for sequence parallelism.

View File

@@ -13,10 +13,10 @@ from trl import (
RewardTrainer,
)
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
class TRLPPOTrainer(PPOTrainer):
"""Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"]
@@ -74,8 +74,10 @@ class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
)
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
"""Extend the base ORPOTrainer for axolotl helpers"""
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
@@ -152,14 +154,18 @@ class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
return loss, metrics
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
"""Extend the base KTOTrainer for axolotl helpers"""
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
"""Extend the base CPOTrainer for axolotl helpers"""
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
@@ -238,13 +244,17 @@ class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
return loss, metrics
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
"""Extend the base RewardTrainer for axolotl helpers"""
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
"""Extend the base trl.PRMTrainer for axolotl helpers"""
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]

View File

@@ -12,7 +12,9 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass
class AxolotlTrainingMixins:
"""Mixin class for the Axolotl training args."""
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
@@ -32,12 +34,6 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},

View File

@@ -15,7 +15,6 @@ from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -160,6 +159,4 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
del model
del tokenizer
cleanup_distributed()
return all_metrics

View File

@@ -6,22 +6,11 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2.
"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from accelerate.logging import get_logger
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.logging_config import configure_logging
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
configure_logging()
LOG = get_logger(__name__)
@@ -43,133 +32,19 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
Setter for ring attention group on this rank.
Args:
ring_attn_group: Process group for ring attention.
Process group for ring attention.
"""
global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group
def patch_flash_attention_for_sequential_batch(sequence_parallel_degree: int):
"""
Patch flash attention a second time to handle batched data. This is a hack to
accommodate certain RL trainers which batch data even when `micro_batch_size: 1` is
specified in the Axolotl config.
Args:
sequence_parallel_degree: Sequence parallelism factor.
"""
# Store the original flash attention function
original_flash_attention = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
def sequential_batch_flash_attention(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
dropout: float = 0.0,
scaling: float | None = None,
sliding_window: int | None = None,
softcap: float | None = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
# Check if we have a batch dimension > 1
batch_size = query.shape[0]
if batch_size <= 1:
return original_flash_attention(
module,
query,
key,
value,
attention_mask,
dropout,
scaling,
sliding_window,
softcap,
**kwargs
)
# Process each item in the batch separately
outputs = []
for i in range(batch_size):
# Extract single batch item
q_item = query[i:i+1]
k_item = key[i:i+1]
v_item = value[i:i+1]
# Handle attention mask - it might be None or have different shapes
mask_item = None
if attention_mask is not None:
# The mask could have different formats depending on implementation
if attention_mask.dim() >= 3 and attention_mask.shape[0] == batch_size:
mask_item = attention_mask[i:i+1]
else:
# For broadcast masks that don't have a batch dimension
mask_item = attention_mask
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = q_item.shape[0]
seq_len = q_item.shape[2]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
# Call the original function for a single batch item
output, _ = original_flash_attention(
module,
q_item,
k_item,
v_item,
mask_item,
dropout,
scaling,
sliding_window,
softcap,
**kwargs
)
outputs.append(output)
dist.barrier()
# Concatenate results along batch dimension
concatenated_output = torch.cat(outputs, dim=0)
return concatenated_output, None
# Replace the original function with our sequential version
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = sequential_batch_flash_attention
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
def register_ring_attn(sequence_parallel_degree: int):
"""
Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`.
"""
if get_ring_attn_group() is not None:
LOG.info("Ring attention already registered, exiting early...")
return
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
@@ -209,12 +84,6 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if heads_k_stride is None:
heads_k_stride = 1
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)
patch_flash_attention_for_sequential_batch(sequence_parallel_degree)
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)

View File

@@ -22,7 +22,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"phi3",
"gemma",
"gemma2",
"gemma3",
"gemma3_text",
"cohere",
"cohere2",

View File

@@ -27,7 +27,6 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -158,8 +157,6 @@ def setup_signal_handler(
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
cleanup_distributed()
sys.exit(0)
_model_weakref = weakref.ref(model)
@@ -481,7 +478,7 @@ def train(
Returns:
Tuple of (model, tokenizer) after training
"""
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
# Setup model, tokenizer, (causal or RLHF) trainer etc.
(
trainer,
model,
@@ -490,26 +487,34 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
# Handle untrained tokens if configured
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Handle untrained tokens if configured
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
# Additional setup
# Save initial configs
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
# Set up signal handler for graceful termination
setup_signal_handler(cfg, model, safe_serialization)
# Set up badges and config info for model card
setup_model_card(cfg)
# Execute the training
resume_from_checkpoint = determine_resume_checkpoint(cfg)
execute_training(cfg, trainer, resume_from_checkpoint)
# Save the trained model and cleanup
# Save the trained model
save_trained_model(cfg, trainer, model, safe_serialization)
# Create model card
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()
return model, tokenizer, trainer

View File

@@ -816,6 +816,27 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
return control
class SaveModelCallback(TrainerCallback):
"""Callback to save model on train end"""
def on_step_end( # pylint: disable=unused-argument
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs
):
control.should_save = True
return control
class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache"""

View File

@@ -112,7 +112,6 @@ class DataCollatorForSeq2Seq:
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
has_attn_mask = "attention_mask" in features[0].keys()
labels = None
if return_tensors is None:
return_tensors = self.return_tensors
@@ -165,8 +164,6 @@ class DataCollatorForSeq2Seq:
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
if not has_attn_mask:
del features["attention_mask"]
# prepare decoder_input_ids
if (

View File

@@ -238,8 +238,7 @@ def load_dataset_w_config(
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif config_dataset.data_files:
fp: str | list[str] | None = None
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,

View File

@@ -71,8 +71,8 @@ def barrier():
def is_main_process():
"""
Check if the current process is the main process. If not in distributed mode,
always return `True`.
Check if the current process is the main process.
If not in distributed mode, always return True.
"""
if not is_distributed():
return True
@@ -87,18 +87,6 @@ def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
def cleanup_distributed():
"""
Destroy process group if torch distributed is initialized. Called in training early
termination or when training successfully completes.
"""
# Ensure that all operations are completed before destroying the process group
torch.cuda.synchronize()
# Destroy the process group
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@contextmanager
def zero_only():
"""

View File

@@ -609,10 +609,7 @@ class ModelLoader:
# Initialize ring attn for sequence parallelism. This must be done after
# model init but before the first forward pass, since it modifies flash
# attn to use ring comm for SP training across multiple GPUs.
register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride,
)
register_ring_attn(self.cfg.sequence_parallel_degree)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):
@@ -1351,7 +1348,9 @@ def load_model(
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""Load a model for a given configuration and tokenizer."""
"""
Load a model for a given configuration and tokenizer.
"""
model_loader = ModelLoader(
cfg,
tokenizer,
@@ -1360,16 +1359,12 @@ def load_model(
reference_model=reference_model,
**kwargs,
)
return model_loader.load_model()
def load_adapter(
model: PreTrainedModel,
cfg: DictDefault,
adapter: str | None,
inference: bool = False,
) -> tuple[PreTrainedModel, PeftConfig | None]:
def load_adapter(model, cfg, adapter, inference=False):
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
@@ -1382,9 +1377,8 @@ def load_adapter(
raise NotImplementedError(f"{adapter} peft adapter not available")
def load_llama_adapter(
model: PreTrainedModel, cfg: DictDefault
) -> tuple[PreTrainedModel, PeftConfig | None]:
def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import AdaptionPromptConfig, get_peft_model
peft_config = AdaptionPromptConfig(
@@ -1408,7 +1402,7 @@ def load_llama_adapter(
return model, peft_config
def find_all_linear_names(model: PreTrainedModel):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set()
for name, module in model.named_modules():

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Nikhil Vyas
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,495 @@
# pylint: skip-file
# Copied from https://github.com/nikhilvyas/SOAP
from itertools import chain
import torch
import torch.optim as optim
# Parts of the code are modifications of Pytorch's AdamW optimizer
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py
class SOAP(optim.Optimizer):
"""
Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.003):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
Adam's betas parameters (b1, b2).
shampoo_beta (`float`, *optional*, defaults to -1):
If >= 0, use this beta for the preconditioner (L and R in paper, state["GG"] below) moving average instead of betas[1].
eps (`float`, *optional*, defaults to 1e-08):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
precondition_frequency (`int`, *optional*, defaults to 10):
How often to update the preconditioner.
max_precond_dim (`int`, *optional*, defaults to 10000):
Maximum dimension of the preconditioner.
Set to 10000, so that we exclude most common vocab sizes while including layers.
merge_dims (`bool`, *optional*, defaults to `False`):
Whether or not to merge dimensions of the preconditioner.
precondition_1d (`bool`, *optional*, defaults to `False`):
Whether or not to precondition 1D gradients.
normalize_grads (`bool`, *optional*, defaults to `False`):
Whether or not to normalize gradients per layer.
Helps at large precondition_frequency (~100 in our experiments),
but hurts performance at small precondition_frequency (~10 in our experiments).
data_format (`str`, *optional*, defaults to `channels_first`):
Data format of the input for convolutional layers.
Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias correction in Adam.
"""
def __init__(
self,
params,
lr: float = 3e-3,
betas=(0.95, 0.95),
shampoo_beta: float = -1,
eps: float = 1e-8,
weight_decay: float = 0.01,
precondition_frequency: int = 10,
max_precond_dim: int = 10000, #
merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
precondition_1d: bool = False,
normalize_grads: bool = False,
data_format: str = "channels_first",
correct_bias: bool = True,
):
defaults = {
"lr": lr,
"betas": betas,
"shampoo_beta": shampoo_beta,
"eps": eps,
"weight_decay": weight_decay,
"precondition_frequency": precondition_frequency,
"max_precond_dim": max_precond_dim,
"merge_dims": merge_dims,
"precondition_1d": precondition_1d,
"normalize_grads": normalize_grads,
"correct_bias": correct_bias,
}
super().__init__(params, defaults)
self._data_format = data_format
def merge_dims(self, grad, max_precond_dim):
"""
Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
"""
assert self._data_format in ["channels_first", "channels_last"]
if self._data_format == "channels_last" and grad.dim() == 4:
grad = grad.permute(0, 3, 1, 2)
shape = grad.shape
new_shape = []
curr_shape = 1
for sh in shape:
temp_shape = curr_shape * sh
if temp_shape > max_precond_dim:
if curr_shape > 1:
new_shape.append(curr_shape)
curr_shape = sh
else:
new_shape.append(sh)
curr_shape = 1
else:
curr_shape = temp_shape
if curr_shape > 1 or len(new_shape) == 0:
new_shape.append(curr_shape)
new_grad = grad.reshape(new_shape)
return new_grad
@torch.no_grad()
def step(self, closure=None):
"""
Performs a single optimization step.
Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
if closure is None:
loss = None
else:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
if "step" not in state:
state["step"] = 0
# State initialization
if "exp_avg" not in state:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad)
if "Q" not in state:
self.init_preconditioner(
grad,
state,
precondition_frequency=group["precondition_frequency"],
precondition_1d=group["precondition_1d"],
shampoo_beta=(
group["shampoo_beta"]
if group["shampoo_beta"] >= 0
else group["betas"][1]
),
max_precond_dim=group["max_precond_dim"],
merge_dims=group["merge_dims"],
)
self.update_preconditioner(
grad,
state,
max_precond_dim=group["max_precond_dim"],
merge_dims=group["merge_dims"],
precondition_1d=group["precondition_1d"],
)
continue # first step is skipped so that we never use the current gradients in the projection.
# Projecting gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state["GG"]
grad_projected = self.project(
grad,
state,
merge_dims=group["merge_dims"],
max_precond_dim=group["max_precond_dim"],
)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).add_(
grad_projected.square(), alpha=(1.0 - beta2)
)
denom = exp_avg_sq.sqrt().add_(group["eps"])
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state["GG"]
# exp_avg_projected = self.project(
# exp_avg,
# state,
# merge_dims=group["merge_dims"],
# max_precond_dim=group["max_precond_dim"],
# )
exp_avg_projected = exp_avg
step_size = group["lr"]
if group["correct_bias"]:
bias_correction1 = 1.0 - beta1 ** (state["step"])
bias_correction2 = 1.0 - beta2 ** (state["step"])
step_size = step_size * (bias_correction2**0.5) / bias_correction1
# Projecting back the preconditioned (by Adam) exponential moving average of gradients
# to the original space
norm_grad = self.project_back(
exp_avg_projected / denom,
state,
merge_dims=group["merge_dims"],
max_precond_dim=group["max_precond_dim"],
)
if group["normalize_grads"]:
norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2) ** 0.5)
p.add_(norm_grad, alpha=-step_size)
# From AdamW code: Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
# Update is done after the gradient step to avoid using current gradients in the projection.
self.update_preconditioner(
grad,
state,
max_precond_dim=group["max_precond_dim"],
merge_dims=group["merge_dims"],
precondition_1d=group["precondition_1d"],
)
return loss
def init_preconditioner(
self,
grad,
state,
precondition_frequency=10,
shampoo_beta=0.95,
max_precond_dim=10000,
precondition_1d=False,
merge_dims=False,
):
"""
Initializes the preconditioner matrices (L and R in the paper).
"""
state["GG"] = (
[]
) # Will hold all the preconditioner matrices (L and R in the paper).
if grad.dim() == 1:
if not precondition_1d or grad.shape[0] > max_precond_dim:
state["GG"].append([])
else:
state["GG"].append(
torch.zeros(grad.shape[0], grad.shape[0], device=grad.device)
)
else:
if merge_dims:
grad = self.merge_dims(grad, max_precond_dim)
for sh in grad.shape:
if sh > max_precond_dim:
state["GG"].append([])
else:
state["GG"].append(torch.zeros(sh, sh, device=grad.device))
state["Q"] = None # Will hold all the eigenbases of the preconditioner.
state["precondition_frequency"] = precondition_frequency
state["shampoo_beta"] = shampoo_beta
def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
"""
Projects the gradient to the eigenbases of the preconditioner.
"""
original_shape = grad.shape
if merge_dims:
if grad.dim() == 4 and self._data_format == "channels_last":
permuted_shape = grad.permute(0, 3, 1, 2).shape
grad = self.merge_dims(grad, max_precond_dim)
for mat in state["Q"]:
if len(mat) > 0:
grad = torch.tensordot(
grad,
mat,
dims=[[0], [0]],
)
else:
permute_order = list(range(1, len(grad.shape))) + [0]
grad = grad.permute(permute_order)
if merge_dims:
if self._data_format == "channels_last" and len(original_shape) == 4:
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
else:
grad = grad.reshape(original_shape)
return grad
def update_preconditioner(
self,
grad,
state,
max_precond_dim=10000,
merge_dims=False,
precondition_1d=False,
):
"""
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
"""
if state["Q"] is not None:
state["exp_avg"] = self.project_back(
state["exp_avg"],
state,
merge_dims=merge_dims,
max_precond_dim=max_precond_dim,
)
if grad.dim() == 1:
if precondition_1d and grad.shape[0] <= max_precond_dim:
state["GG"][0].lerp_(
grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"]
)
else:
if merge_dims:
new_grad = self.merge_dims(grad, max_precond_dim)
for idx, sh in enumerate(new_grad.shape):
if sh <= max_precond_dim:
outer_product = torch.tensordot(
new_grad,
new_grad,
dims=[
[
*chain(
range(idx), range(idx + 1, len(new_grad.shape))
)
]
]
* 2,
)
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
else:
for idx, sh in enumerate(grad.shape):
if sh <= max_precond_dim:
outer_product = torch.tensordot(
grad,
grad,
# Contracts across all dimensions except for k.
dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]]
* 2,
)
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
if state["Q"] is None:
state["Q"] = self.get_orthogonal_matrix(state["GG"])
if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0:
state["Q"] = self.get_orthogonal_matrix_QR(
state, max_precond_dim, merge_dims
)
# state["Q"] = self.get_fast_QR(state, max_precond_dim, merge_dims)
if state["step"] > 0:
state["exp_avg"] = self.project(
state["exp_avg"],
state,
merge_dims=merge_dims,
max_precond_dim=max_precond_dim,
)
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
"""
Projects the gradient back to the original space.
"""
original_shape = grad.shape
if merge_dims:
if self._data_format == "channels_last" and grad.dim() == 4:
permuted_shape = grad.permute(0, 3, 1, 2).shape
grad = self.merge_dims(grad, max_precond_dim)
for mat in state["Q"]:
if len(mat) > 0:
grad = torch.tensordot(
grad,
mat,
dims=[[0], [1]],
)
else:
permute_order = list(range(1, len(grad.shape))) + [0]
grad = grad.permute(permute_order)
if merge_dims:
if self._data_format == "channels_last" and len(original_shape) == 4:
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
else:
grad = grad.reshape(original_shape)
return grad
def get_orthogonal_matrix(self, mat):
"""
Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
"""
matrix = []
for m in mat:
if len(m) == 0:
matrix.append([])
continue
if m.data.dtype != torch.float:
float_data = False
original_type = m.data.dtype
original_device = m.data.device
matrix.append(m.data.float())
else:
float_data = True
matrix.append(m.data)
final = []
for m in matrix:
if len(m) == 0:
final.append([])
continue
try:
_, Q = torch.linalg.eigh(
m + 1e-30 * torch.eye(m.shape[0], device=m.device)
)
except: # pylint: disable=bare-except # noqa: E722
_, Q = torch.linalg.eigh(
m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device)
)
Q = Q.to(m.dtype)
Q = torch.flip(Q, [1])
if not float_data:
Q = Q.to(original_device).type(original_type)
final.append(Q)
return final
def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
"""
Computes the eigenbases of the preconditioner using one round of power iteration
followed by torch.linalg.qr decomposition.
"""
precond_list = state["GG"]
orth_list = state["Q"]
matrix = []
orth_matrix = []
for m, o in zip(precond_list, orth_list):
if len(m) == 0:
matrix.append([])
orth_matrix.append([])
continue
if m.data.dtype != torch.float:
float_data = False
original_type = m.data.dtype
original_device = m.data.device
matrix.append(m.data.float())
orth_matrix.append(o.data.float())
else:
float_data = True
matrix.append(m.data.float())
orth_matrix.append(o.data.float())
orig_shape = state["exp_avg_sq"].shape
if self._data_format == "channels_last" and len(orig_shape) == 4:
permuted_shape = state["exp_avg_sq"].permute(0, 3, 1, 2).shape
if merge_dims:
exp_avg_sq = self.merge_dims(state["exp_avg_sq"], max_precond_dim)
else:
exp_avg_sq = state["exp_avg_sq"]
final = []
for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
if len(m) == 0:
final.append([])
continue
est_eig = torch.diag(o.T @ m @ o)
sort_idx = torch.argsort(est_eig, descending=True)
exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
o = o[:, sort_idx]
power_iter = m @ o
Q, _ = torch.linalg.qr(power_iter)
if not float_data:
Q = Q.to(original_device).type(original_type)
final.append(Q)
if merge_dims:
if self._data_format == "channels_last" and len(orig_shape) == 4:
exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
else:
exp_avg_sq = exp_avg_sq.reshape(orig_shape)
state["exp_avg_sq"] = exp_avg_sq
return final

View File

@@ -8,13 +8,11 @@ from typing import Any, Iterable, List, Union
import numba
import numpy as np
from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from torch.utils.data import BatchSampler, Sampler
from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
@numba.njit
@@ -105,55 +103,6 @@ def allocate(
return result, s, len(result) * c * n
@numba.njit
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
"""
Sequential allocator that preserves example order
Parameters:
- lengths: The lengths of all examples
- rank: The current rank (for distributed training)
- c: The capacity of each bin (maximum sequence length)
- n: Number of ranks
Returns:
- result: List of batches for the current rank
- total_used: Number of actual example tokens
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
"""
result = []
total_used = 0
# First, do sequential packing into bins
all_bins = []
current_bin = [0 for i in range(0)] # numba hint
remaining_capacity = c
for idx, size in enumerate(lengths):
if size <= remaining_capacity:
# Example fits in current bin
current_bin.append(idx)
remaining_capacity -= size
total_used += size
else:
# Example doesn't fit, start a new bin
if current_bin: # Add non-empty bin to all_bins
all_bins.append(current_bin)
current_bin = [idx]
remaining_capacity = c - size
total_used += size
# Add the last bin if not empty
if current_bin:
all_bins.append(current_bin)
# Assign bins to ranks - each rank gets every n-th bin
for bin_idx in range(rank, len(all_bins), n):
result.append(all_bins[bin_idx])
return result, total_used, len(all_bins) * c
class MultipackBatchSampler(BatchSampler):
"""Batch sampler class for multipack"""
@@ -166,7 +115,6 @@ class MultipackBatchSampler(BatchSampler):
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
sequential: bool = False,
**kwargs,
):
super().__init__(sampler, batch_size, drop_last)
@@ -174,7 +122,6 @@ class MultipackBatchSampler(BatchSampler):
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential
assert isinstance(self.lengths, np.ndarray)
@@ -189,11 +136,6 @@ class MultipackBatchSampler(BatchSampler):
# the minimum packed dataset length across all ranks determined by a gather/broadcast
self.len_across_ranks = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warn(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)
def set_epoch(self, epoch: int):
self.epoch = epoch
@@ -203,21 +145,13 @@ class MultipackBatchSampler(BatchSampler):
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
if self.sequential:
batches, total_used, total_slots = allocate_sequentially(
lengths=lengths,
rank=0,
c=self.batch_max_len,
n=1,
)
else:
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
batches = [
[

View File

@@ -46,7 +46,6 @@ from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig
LOG = logging.getLogger(__name__)
@@ -87,9 +86,6 @@ class AxolotlInputConfig(
trl: TRLConfig | None = Field(
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
)
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
)
reward_model: bool | None = None
process_reward_model: bool | None = None
num_labels: int | None = None
@@ -192,7 +188,6 @@ class AxolotlInputConfig(
sample_packing: bool | None = None
sample_packing_group_size: int | None = 100_000
sample_packing_bin_size: int | None = 200
sample_packing_sequentially: bool | None = None
eval_sample_packing: bool | None = None
pad_to_sequence_len: bool | None = None
curriculum_sampling: bool | None = None
@@ -253,7 +248,6 @@ class AxolotlInputConfig(
val_set_size: float | None = Field(default=0.0)
sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -1114,7 +1108,7 @@ class AxolotlInputConfig(
@field_validator("sequence_parallel_degree", mode="before")
@classmethod
def check_sequence_parallel_degree(cls, value, info):
def check_sequence_parallel_config(cls, value, info):
if not value:
value = 1
@@ -1135,17 +1129,6 @@ class AxolotlInputConfig(
return value
@model_validator(mode="before")
@classmethod
def check_muon_deepspeed_fsdp(cls, data):
if data.get("optimizer") == "muon" and (
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
):
raise ValueError(
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
@@ -1281,12 +1264,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data["beta"] != data["trl"]["beta"]:
raise ValueError("beta and trl.beta must match or one must be removed")
return data
@model_validator(mode="after")
def check_min_torch_version(self):
if self.env_capabilities and self.env_capabilities.torch_version:
torch_version = self.env_capabilities.torch_version
if version.parse(torch_version) < version.parse("2.5.1"):
LOG.warning(
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
)

View File

@@ -52,3 +52,4 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
soap = "soap" # pylint: disable=invalid-name

View File

@@ -20,30 +20,27 @@ class TRLConfig(BaseModel):
)
# GRPO specific args
# Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
use_vllm: bool = Field(
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
use_vllm: bool | None = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
)
vllm_server_host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host of the vLLM server to connect to"},
vllm_device: str | None = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
vllm_server_port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to connect to"},
vllm_gpu_memory_utilization: float | None = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
vllm_server_timeout: int | None = Field(
vllm_dtype: str | None = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: int | None = Field(
default=None,
json_schema_extra={
"description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
"after the timeout, a `ConnectionError` is raised."
},
)
vllm_guided_decoding_regex: str | None = Field(
default=None,
json_schema_extra={
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
"description": "Maximum length of the model context for VLLM"
},
)
@@ -88,48 +85,3 @@ class TRLConfig(BaseModel):
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
)
scale_rewards: bool = Field(
default=True,
json_schema_extra={
"description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
},
)
temperature: float | None = Field(
default=None,
json_schema_extra={"description": "Sampling temperature for the GRPO policy."},
)
top_p: float | None = Field(
default=None,
json_schema_extra={
"description": "Top-p sampling probability for the generation policy."
},
)
top_k: int | None = Field(
default=None,
json_schema_extra={"description": "Top-k sampling for the generation policy."},
)
min_p: float | None = Field(
default=None,
json_schema_extra={
"description": "Minimum probability for the generation policy."
},
)
repetition_penalty: float | None = Field(
default=None,
json_schema_extra={
"description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
},
)
num_iterations: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
},
)
epsilon: float | None = Field(
default=None,
json_schema_extra={
"description": "Epsilon value for clipping in the GRPO algorithm."
},
)

View File

@@ -1,38 +0,0 @@
"""
Pydantic models for VLLM configuration, used primarily for RL training with TRL + grpo
"""
from pydantic import BaseModel, Field
class VllmConfig(BaseModel):
"""
Configuration for VLLM server
"""
device: str | None = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
tensor_parallel_size: int | None = Field(
default=None,
json_schema_extra={"description": "Tensor parallel size for VLLM"},
)
gpu_memory_utilization: float | None = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
dtype: str | None = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
max_model_len: int | None = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
enable_prefix_caching: bool | None = Field(
default=None,
json_schema_extra={"description": "Enable prefix caching for VLLM"},
)

View File

@@ -13,7 +13,7 @@ import torch
import torch.cuda
from accelerate.logging import get_logger
from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -235,7 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.model_config_type in ["mamba", "gemma3"]:
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
@@ -456,18 +456,13 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
else:
sampler_batch_size = cfg.micro_batch_size
batch_max_len = cfg.sequence_len
if cfg.curriculum_sampling:
sampler = SequentialSampler(train_dataset)
else:
sampler = RandomSampler(train_dataset)
sampler = MultipackBatchSampler(
sampler=sampler,
sampler=RandomSampler(train_dataset),
lengths=get_dataset_lengths(train_dataset),
batch_size=sampler_batch_size,
batch_max_len=batch_max_len,
group_size=cfg.sample_packing_group_size,
bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True,
)

View File

@@ -8,13 +8,11 @@ import shutil
import sys
import tempfile
import time
from pathlib import Path
import datasets
import pytest
import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download
from tokenizers import AddedToken
from transformers import AutoTokenizer
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
@@ -50,14 +48,6 @@ def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs)
@pytest.fixture(scope="session", autouse=True)
def download_ds_fixture_bundle():
ds_dir = snapshot_download_w_retry(
"axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
)
return Path(ds_dir)
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
# download the model
@@ -111,50 +101,42 @@ def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
@pytest.fixture(scope="session", autouse=True)
def download_argilla_distilabel_intel_orca_dpo_dataset():
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
)
# @pytest.fixture(scope="session", autouse=True)
# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# # download the dataset
# snapshot_download_w_retry(
# "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
# )
@pytest.fixture(scope="session", autouse=True)
def download_fozzie_alpaca_dpo_dataset():
# download the dataset
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
)
snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test",
repo_type="dataset",
revision="ea82cff",
)
# @pytest.fixture(scope="session", autouse=True)
# def download_fozzie_alpaca_dpo_dataset():
# # download the dataset
# snapshot_download_w_retry(
# "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
# )
# snapshot_download_w_retry(
# "fozziethebeat/alpaca_messages_2k_dpo_test",
# repo_type="dataset",
# revision="ea82cff",
# )
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
# @pytest.fixture(scope="session")
# @disable_hf_offline
# def dataset_fozzie_alpaca_dpo_dataset(
# download_fozzie_alpaca_dpo_dataset,
# ): # pylint: disable=unused-argument,redefined-outer-name
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
#
#
# @pytest.fixture(scope="session")
# @disable_hf_offline
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
# download_fozzie_alpaca_dpo_dataset,
# ): # pylint: disable=unused-argument,redefined-outer-name
# return load_dataset(
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
# )
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset(
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
)
@pytest.fixture(scope="session", autouse=True)
@@ -281,7 +263,7 @@ def download_mlx_mistral_7b_model_fixture():
)
@pytest.fixture
@pytest.fixture(scope="session", autouse=True)
def download_llama2_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
@@ -291,7 +273,7 @@ def download_llama2_model_fixture():
)
@pytest.fixture
@pytest.fixture(scope="session", autouse=True)
@enable_hf_offline
def tokenizer_huggyllama(
download_huggyllama_model_fixture,
@@ -302,57 +284,6 @@ def tokenizer_huggyllama(
return tokenizer
@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama_w_special_tokens(
tokenizer_huggyllama,
): # pylint: disable=redefined-outer-name
tokenizer_huggyllama.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
return tokenizer_huggyllama
@pytest.fixture
@enable_hf_offline
def tokenizer_llama2_7b(
download_llama2_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
return tokenizer
@pytest.fixture
@enable_hf_offline
def tokenizer_mistral_7b_instruct(
download_mlx_mistral_7b_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
@pytest.fixture
def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
tokenizer_mistral_7b_instruct.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
tokenizer_mistral_7b_instruct.add_tokens(
[
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
]
)
return tokenizer_mistral_7b_instruct
@pytest.fixture
def temp_dir():
# Create a temporary directory
@@ -418,60 +349,6 @@ def cleanup_monkeypatches():
globals().pop(module_global, None)
@pytest.fixture
def dataset_winglian_tiny_shakespeare(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
return datasets.load_from_disk(ds_path)
@pytest.fixture
def dataset_tatsu_lab_alpaca(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_mhenrichsen_alpaca_2k_test(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = (
download_ds_fixture_bundle
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
)
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = (
download_ds_fixture_bundle
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
)
return datasets.load_from_disk(ds_path)["train"]
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,

View File

@@ -1,294 +0,0 @@
"""
GRPO test suite
"""
import os
import random
import subprocess # nosec B404
import sys
import time
from pathlib import Path
import pytest
import requests
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_vllm
def start_vllm(
model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
) -> int:
"""
helper function to start the VLLM server in the background, mostly for testing purposes
"""
cmd = [sys.executable, "-m", "trl.scripts.vllm_serve", "--model", model]
if tensor_parallel_size := kwargs.get("tensor_parallel_size"):
cmd.extend(["--tensor-parallel-size", str(tensor_parallel_size)])
if host := kwargs.get("host"):
cmd.extend(["--host", host])
if port := kwargs.get("port"):
cmd.extend(["--port", str(port)])
if gpu_memory_utilization := kwargs.get("gpu_memory_utilization"):
cmd.extend(["--gpu-memory-utilization", str(gpu_memory_utilization)])
if dtype := kwargs.get("dtype"):
cmd.extend(["--dtype", dtype])
if max_model_len := kwargs.get("max_model_len"):
cmd.extend(["--max-model-len", str(max_model_len)])
if kwargs.get("enable_prefix_caching"):
cmd.extend(["--enable-prefix-caching", "True"])
# print out the command to be executed
print(" ".join(cmd))
# start `trl vllm-serve` command in the background and capture the process id
process = subprocess.Popen( # pylint: disable=consider-using-with
cmd,
env=env,
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
) # nosec B603
# print out the process id so the user can easily kill it later
print(f"VLLM server process started (PID: {process.pid})")
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds
started = False
if wait and host and port:
for _ in range(int(wait)):
try:
response = requests.get(f"http://{host}:{port}", timeout=1)
if int(response.status_code) in [200, 404]:
started = True
break
except requests.exceptions.RequestException:
pass
# also check if the process.pid is still running
if not process.poll() is None:
break
time.sleep(1)
if wait and not started:
print(
f"VLLM server process did not start within {wait} seconds. Please check your server logs."
)
process.kill()
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
# return the process id
return process.pid
class TestGRPO:
"""
Test case for GRPO training using multilpe GPUs
"""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_dora(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"peft_use_dora": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process_id = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=120,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
)
finally:
os.kill(vllm_process_id, 9)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_fft(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process_id = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=120,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
)
finally:
os.kill(vllm_process_id, 9)

View File

@@ -52,9 +52,9 @@ class TestMultiGPUEval:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -121,9 +121,9 @@ class TestMultiGPUEval:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",

View File

@@ -1,100 +0,0 @@
"""
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@pytest.fixture(scope="session", autouse=True)
def download_model():
# download the model
snapshot_download("axolotl-mirrors/gemma-3-4b-pt", repo_type="model")
class TestMultiGPUGemma3:
"""
Test case for Gemma3 models using LoRA
"""
def test_lora_ddp_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-mirrors/gemma-3-4b-pt",
"sequence_len": 2048,
"ddp_find_unused_parameters": True,
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"chat_template": "gemma3",
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {
"use_reentrant": False,
},
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
)

View File

@@ -0,0 +1,175 @@
"""
GRPO test suite
"""
import random
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_vllm
class TestGRPO:
"""
Test case for GRPO training using multilpe GPUs
"""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_dora(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
"vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"peft_use_dora": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_fft(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
"vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)

View File

@@ -58,7 +58,6 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -122,7 +121,6 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -195,7 +193,6 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
@@ -273,7 +270,6 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
@@ -334,7 +330,6 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": gradient_accumulation_steps,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -404,8 +399,7 @@ class TestMultiGPULlama:
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -484,8 +478,7 @@ class TestMultiGPULlama:
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -785,10 +778,9 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",

View File

@@ -46,7 +46,7 @@ class TestMultiGPUQwen2:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 5,
"warmup_steps": 20,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,

View File

@@ -50,7 +50,7 @@ class TestMultiGPURay:
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",

View File

@@ -110,7 +110,7 @@ class TestRingAttention:
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
register_ring_attn(sequence_parallel_degree=4)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2

View File

@@ -201,3 +201,46 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_soap(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "soap",
"adam_beta1": 0.9,
"adam_beta2": 0.95,
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -324,7 +324,7 @@ class TestDatasetPreparation:
@enable_hf_offline
def test_load_hub_with_revision_with_dpo(
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
):
"""Verify that processing dpo data from the hub works with a specific revision"""
@@ -339,10 +339,12 @@ class TestDatasetPreparation:
)
# pylint: disable=duplicate-code
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
)
train_dataset, _ = load_prepare_preference_datasets(cfg)
@@ -352,9 +354,7 @@ class TestDatasetPreparation:
@enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub_with_revision(
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer
):
def test_load_local_hub_with_revision(self, tokenizer):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
@@ -386,23 +386,13 @@ class TestDatasetPreparation:
}
)
with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
dataset, _ = load_tokenized_prepared_datasets(
tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
@enable_hf_offline
def test_loading_local_dataset_folder(self, tokenizer):

View File

@@ -238,22 +238,21 @@ class TestDeduplicateRLDataset:
@enable_hf_offline
def test_load_with_deduplication(
self,
cfg,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
tokenizer_huggyllama,
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
):
"""Verify that loading with deduplication removes duplicates."""
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
@@ -264,20 +263,19 @@ class TestDeduplicateRLDataset:
@enable_hf_offline
def test_load_without_deduplication(
self,
cfg,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
tokenizer_huggyllama,
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
):
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama

View File

@@ -1,7 +1,7 @@
"""Module for testing streaming dataset sequence packing"""
import pytest
from datasets import concatenate_datasets
from datasets import concatenate_datasets, load_dataset
from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer
@@ -27,6 +27,7 @@ class TestBatchedSamplerPacking:
Test class for packing streaming dataset sequences
"""
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
@pytest.mark.parametrize(
"batch_size, num_workers",
[
@@ -37,20 +38,14 @@ class TestBatchedSamplerPacking:
],
)
@pytest.mark.parametrize("max_seq_length", [4096, 512])
@pytest.mark.parametrize("sequential", [True, False])
@enable_hf_offline
def test_packing(
self,
dataset_winglian_tiny_shakespeare,
batch_size,
num_workers,
tokenizer,
max_seq_length,
sequential,
):
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
dataset = dataset_winglian_tiny_shakespeare["train"]
dataset = load_dataset(
"winglian/tiny-shakespeare",
split="train",
)
cfg = DictDefault(
{
@@ -60,7 +55,7 @@ class TestBatchedSamplerPacking:
)
ds_cfg = DictDefault(
{
"field": "text",
"field": "Text",
}
)
completion_strategy = load(tokenizer, cfg, ds_cfg)
@@ -80,7 +75,6 @@ class TestBatchedSamplerPacking:
batch_max_len=max_seq_length,
group_size=100000,
bin_size=200,
sequential=sequential,
)
loader = DataLoader(

View File

@@ -2,8 +2,13 @@
import json
import logging
import unittest
from pathlib import Path
import pytest
from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy,
@@ -56,13 +61,24 @@ test_data = {
}
class TestPromptTokenizationStrategies:
class TestPromptTokenizationStrategies(unittest.TestCase):
"""
Test class for prompt tokenization strategies.
"""
@enable_hf_offline
def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens):
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
def test_no_sys_prompt(self):
"""
tests the interface between the user and assistant parts
"""
@@ -70,7 +86,7 @@ class TestPromptTokenizationStrategies:
# pylint: disable=duplicate-code
strat = AlpacaPromptTokenizingStrategy(
prompter,
tokenizer_huggyllama_w_special_tokens,
self.tokenizer,
False,
2048,
)
@@ -83,8 +99,7 @@ class TestPromptTokenizationStrategies:
assert example["labels"][world_idx] == 3186
assert example["labels"][world_idx - 1] == -100
@enable_hf_offline
def test_alpaca(self, tokenizer_huggyllama_w_special_tokens):
def test_alpaca(self):
"""
tests the interface between the user and assistant parts
"""
@@ -92,7 +107,7 @@ class TestPromptTokenizationStrategies:
prompter = AlpacaPrompter()
strat = AlpacaPromptTokenizingStrategy(
prompter,
tokenizer_huggyllama_w_special_tokens,
self.tokenizer,
False,
2048,
)
@@ -103,17 +118,28 @@ class TestPromptTokenizationStrategies:
assert example["labels"][world_idx - 1] == -100
class TestInstructionWSystemPromptTokenizingStrategy:
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
"""
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
@enable_hf_offline
def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens):
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
def test_system_alpaca(self):
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
strat = InstructionWSystemPromptTokenizingStrategy(
prompter,
tokenizer_huggyllama_w_special_tokens,
self.tokenizer,
False,
2048,
)
@@ -134,13 +160,18 @@ class TestInstructionWSystemPromptTokenizingStrategy:
assert example["input_ids"][8] == 11889 # USER
class Llama2ChatTokenizationTest:
class Llama2ChatTokenizationTest(unittest.TestCase):
"""
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
@enable_hf_offline
def test_llama2_chat_integration(self, tokenizer_llama2_7b):
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
# woraround because official Meta repos are not open
def test_llama2_chat_integration(self):
with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin:
@@ -155,18 +186,16 @@ class Llama2ChatTokenizationTest:
prompter = Llama2ChatPrompter()
strat = LLama2ChatTokenizingStrategy(
prompter,
tokenizer_llama2_7b,
self.tokenizer,
False,
4096,
)
example = strat.tokenize_prompt(conversation)
for fields in ["input_ids", "attention_mask", "labels"]:
# pytest assert equals
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields])
assert len(example[fields]) == len(tokenized_conversation[fields])
assert example[fields] == tokenized_conversation[fields]
def compare_with_transformers_integration(self, tokenizer_llama2_7b):
def compare_with_transformers_integration(self):
# this needs transformers >= v4.31.0
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
from transformers.pipelines.conversational import Conversation
@@ -205,27 +234,49 @@ If a question does not make any sense, or is not factually coherent, explain why
generated_responses=answers,
)
# pylint: disable=W0212
hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)
hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]
self.assertEqual(
hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
)
class OrpoTokenizationTest:
class OrpoTokenizationTest(unittest.TestCase):
"""test case for the ORPO tokenization"""
@enable_hf_offline
def test_orpo_integration(
self,
tokenizer_mistral_7b_instruct_chatml,
dataset_argilla_ultrafeedback_binarized_preferences_cleaned,
):
ds = dataset_argilla_ultrafeedback_binarized_preferences_cleaned.select([0])
def setUp(self) -> None:
# pylint: disable=duplicate-code
tokenizer = LlamaTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)
tokenizer.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
tokenizer.add_tokens(
[
AddedToken(
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
),
]
)
self.tokenizer = tokenizer
self.dataset = load_dataset(
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
).select([0])
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
def test_orpo_integration(self):
strat = load(
tokenizer_mistral_7b_instruct_chatml,
self.tokenizer,
DictDefault({"train_on_inputs": False}),
DictDefault({"chat_template": "chatml"}),
)
res = strat.tokenize_prompt(ds[0])
res = strat.tokenize_prompt(self.dataset[0])
assert "rejected_input_ids" in res
assert "rejected_labels" in res
assert "input_ids" in res
@@ -244,3 +295,7 @@ class OrpoTokenizationTest:
assert res["prompt_attention_mask"][0] == 1
assert res["prompt_attention_mask"][-1] == 0
if __name__ == "__main__":
unittest.main()

View File

@@ -321,48 +321,3 @@ class TestValidationCheckDatasetConfig(BaseValidation):
)
validate_config(cfg)
class TestOptimizerValidation(BaseValidation):
"""
Test muon optimizer validation
"""
def test_muon_deepspeed(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"optimizer": "muon",
"deepspeed": "deepspeed_configs/zero3.json",
}
)
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
validate_config(cfg)
def test_muon_fsdp(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"optimizer": "muon",
"fsdp": ["full_shard"],
"fsdp_config": {
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
}
)
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
validate_config(cfg)