Compare commits
11 Commits
fix/xforme
...
muon-valid
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c578c8f256 | ||
|
|
f4ae8816bb | ||
|
|
9b95e06cbb | ||
|
|
e0aba74dd0 | ||
|
|
328d598114 | ||
|
|
4d36ecc724 | ||
|
|
7acf93b59f | ||
|
|
b6fc46ada8 | ||
|
|
b35992262e | ||
|
|
ef6eb77cc8 | ||
|
|
5410195e0b |
6
.github/workflows/base.yml
vendored
6
.github/workflows/base.yml
vendored
@@ -40,6 +40,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
|
||||
4
.github/workflows/main.yml
vendored
4
.github/workflows/main.yml
vendored
@@ -25,12 +25,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -87,12 +87,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
3
.github/workflows/multi-gpu-e2e.yml
vendored
3
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -42,8 +42,7 @@ jobs:
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
# awaiting vllm#12721
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
runs-on: [self-hosted, modal]
|
||||
|
||||
23
.github/workflows/tests-nightly.yml
vendored
23
.github/workflows/tests-nightly.yml
vendored
@@ -33,6 +33,15 @@ jobs:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -46,7 +55,7 @@ jobs:
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Update requirements.txt
|
||||
run: |
|
||||
@@ -58,8 +67,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2
|
||||
pip3 show torch
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
@@ -73,10 +81,15 @@ jobs:
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||
pytest tests/patched/
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -v tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
|
||||
6
.github/workflows/tests.yml
vendored
6
.github/workflows/tests.yml
vendored
@@ -96,6 +96,10 @@ jobs:
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
@@ -256,7 +260,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -40,6 +40,7 @@ quartodoc:
|
||||
- cli.preprocess
|
||||
- cli.sweeps
|
||||
- cli.utils
|
||||
- cli.vllm_serve
|
||||
- cli.cloud.base
|
||||
- cli.cloud.modal_
|
||||
- title: Trainers
|
||||
@@ -243,6 +244,7 @@ website:
|
||||
- docs/unsloth.qmd
|
||||
- docs/torchao.qmd
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
|
||||
@@ -2,4 +2,5 @@
|
||||
set -e
|
||||
|
||||
# only run one test at a time so as not to OOM the GPU
|
||||
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
|
||||
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
|
||||
pytest -v -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
|
||||
|
||||
40
docs/cli.qmd
40
docs/cli.qmd
@@ -170,7 +170,7 @@ axolotl merge-sharded-fsdp-weights config.yml
|
||||
|
||||
### evaluate
|
||||
|
||||
Evaluates a model's performance using metrics specified in the config.
|
||||
Evaluates a model's performance (loss etc) on the train and eval datasets.
|
||||
|
||||
```bash
|
||||
# Basic evaluation
|
||||
@@ -197,6 +197,8 @@ lm_eval_batch_size: # Batch size for evaluation
|
||||
output_dir: # Directory to save evaluation results
|
||||
```
|
||||
|
||||
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
|
||||
|
||||
## Legacy CLI Usage
|
||||
|
||||
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
|
||||
@@ -235,7 +237,7 @@ Create a cloud config YAML with your Modal settings:
|
||||
```yaml
|
||||
# cloud_config.yml
|
||||
provider: modal
|
||||
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
|
||||
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
|
||||
gpu_count: 1 # Number of GPUs to use
|
||||
timeout: 86400 # Maximum runtime in seconds (24 hours)
|
||||
branch: main # Git branch to use (optional)
|
||||
@@ -248,7 +250,7 @@ volumes: # Persistent storage volumes
|
||||
- name: axolotl-artifacts
|
||||
mount: /workspace/artifacts
|
||||
|
||||
env: # Environment variables
|
||||
secrets: # Secrets to inject
|
||||
- WANDB_API_KEY
|
||||
- HF_TOKEN
|
||||
```
|
||||
@@ -274,15 +276,27 @@ axolotl lm-eval config.yml --cloud cloud_config.yml
|
||||
### Cloud Configuration Options
|
||||
|
||||
```yaml
|
||||
provider: # compute provider, currently only `modal` is supported
|
||||
gpu: # GPU type to use
|
||||
gpu_count: # Number of GPUs (default: 1)
|
||||
memory: # RAM in GB (default: 128)
|
||||
timeout: # Maximum runtime in seconds
|
||||
provider: # compute provider, currently only `modal` is supported
|
||||
gpu: # GPU type to use
|
||||
gpu_count: # Number of GPUs (default: 1)
|
||||
memory: # RAM in GB (default: 128)
|
||||
timeout: # Maximum runtime in seconds
|
||||
timeout_preprocess: # Preprocessing timeout
|
||||
branch: # Git branch to use
|
||||
docker_tag: # Custom Docker image tag
|
||||
volumes: # List of persistent storage volumes
|
||||
env: # Environment variables to pass
|
||||
secrets: # Secrets to inject
|
||||
branch: # Git branch to use
|
||||
docker_tag: # Custom Docker image tag
|
||||
volumes: # List of persistent storage volumes
|
||||
|
||||
# Environment variables to pass. Can be specified in two ways:
|
||||
# 1. As a string: Will load the value from the host computer's environment variables
|
||||
# 2. As a key-value pair: Will use the specified value directly
|
||||
# Example:
|
||||
# env:
|
||||
# - CUSTOM_VAR # Loads from host's $CUSTOM_VAR
|
||||
# - {CUSTOM_VAR: "value"} # Uses "value" directly
|
||||
env:
|
||||
|
||||
# Secrets to inject. Same input format as `env` but for sensitive data.
|
||||
secrets:
|
||||
# - HF_TOKEN
|
||||
# - WANDB_API_KEY
|
||||
```
|
||||
|
||||
@@ -238,10 +238,10 @@ simpo_gamma: 0.5 # Target reward margin for the SimPO loss
|
||||
# grpo
|
||||
trl:
|
||||
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
|
||||
vllm_device: # Optional[str]. Device to use for VLLM.
|
||||
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
|
||||
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
|
||||
vllm_dtype: # Optional[str]. Data type for VLLM.
|
||||
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
|
||||
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
|
||||
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
|
||||
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
|
||||
|
||||
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
|
||||
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
|
||||
@@ -320,9 +320,13 @@ total_num_tokens:
|
||||
sample_packing_group_size: 100000
|
||||
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
|
||||
sample_packing_bin_size: 200
|
||||
sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially.
|
||||
|
||||
# whether to concatenate samples during pretraining
|
||||
pretraining_sample_concatenation:
|
||||
|
||||
curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning
|
||||
|
||||
# Use batch flattening for speedups when not using sample_packing
|
||||
batch_flattening:
|
||||
|
||||
@@ -354,7 +358,27 @@ lora_target_modules:
|
||||
# - down_proj
|
||||
# - up_proj
|
||||
lora_target_linear: # If true, will target all linear modules
|
||||
peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers
|
||||
|
||||
# List[int] | int. # The layer indices to transform, otherwise, apply to all layers
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform
|
||||
peft_layers_to_transform:
|
||||
|
||||
# Optional[bool]. Whether to use DoRA.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora
|
||||
peft_use_dora:
|
||||
|
||||
# Optional[bool]. Whether to use RSLoRA.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora
|
||||
peft_use_rslora:
|
||||
|
||||
# Optional[list[tuple[int, int]]]. List of layer indices to replicate.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora
|
||||
peft_layer_replication:
|
||||
|
||||
# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
|
||||
# How to initialize LoRA weights. Default to True which is MS original implementation.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization
|
||||
peft_init_lora_weights:
|
||||
|
||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||
@@ -587,26 +611,31 @@ max_grad_norm:
|
||||
# currently only supported on Llama and Mistral
|
||||
neftune_noise_alpha:
|
||||
|
||||
# Whether to bettertransformers
|
||||
# Optional[bool]. Whether to bettertransformers
|
||||
flash_optimum:
|
||||
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
|
||||
# Note: Only one of the following attention patches can be used at a time.
|
||||
# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`.
|
||||
|
||||
# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
flash_attention:
|
||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
||||
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||
# Whether to use scaled-dot-product attention
|
||||
flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only
|
||||
flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation
|
||||
flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation
|
||||
# Optional[bool]. Whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
||||
# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
||||
s2_attention:
|
||||
|
||||
# Optional[bool]. Whether to use low_cpu_mem_usage
|
||||
low_cpu_mem_usage:
|
||||
# Resume from a specific checkpoint dir
|
||||
# Optional[str]. Resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# Optional[bool]. If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# Be careful with this being turned on between different models.
|
||||
auto_resume_from_checkpoints: false
|
||||
|
||||
@@ -658,6 +687,9 @@ ddp_broadcast_buffers:
|
||||
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
||||
sequence_parallel_degree:
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
# Must evenly divide the number of KV heads in your model.
|
||||
heads_k_stride: 1
|
||||
|
||||
# Path to torch distx for optim 'adamw_anyprecision'
|
||||
torchdistx_path:
|
||||
|
||||
12
docs/faq.qmd
12
docs/faq.qmd
@@ -35,12 +35,22 @@ description: Frequently asked questions
|
||||
|
||||
**Q: How to call Axolotl via custom python scripts?**
|
||||
|
||||
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
||||
> A: Since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
||||
|
||||
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
|
||||
|
||||
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
|
||||
|
||||
**Q: ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as pad_token**
|
||||
|
||||
> A: This is because the tokenizer does not have a padding token. Please add a padding token to the tokenizer via:
|
||||
|
||||
> ```yaml
|
||||
> special_tokens:
|
||||
> # str. If you're not sure, set to same as `eos_token`.
|
||||
> pad_token: "..."
|
||||
> ```
|
||||
|
||||
### Chat templates
|
||||
|
||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||
|
||||
@@ -18,6 +18,7 @@ Axolotl supports several methods for multi-GPU training:
|
||||
|
||||
- DeepSpeed (recommended)
|
||||
- FSDP (Fully Sharded Data Parallel)
|
||||
- Sequence parallelism
|
||||
- FSDP + QLoRA
|
||||
|
||||
## DeepSpeed {#sec-deepspeed}
|
||||
@@ -66,6 +67,28 @@ fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
## Sequence parallelism {#sec-sequence-parallelism}
|
||||
|
||||
We support sequence parallelism (SP) via the
|
||||
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
|
||||
allows one to split up sequences across GPUs, which is useful in the event that a
|
||||
single sequence causes OOM errors during model training.
|
||||
|
||||
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
|
||||
or from source with `pip install .[ring-flash-attn]`.
|
||||
|
||||
Your Axolotl YAML config should contain the following lines:
|
||||
|
||||
```{.yaml}
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
flash_attention: true # Required with sequence parallelism
|
||||
|
||||
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
|
||||
heads_k_stride: 1
|
||||
```
|
||||
|
||||
See our [dedicated guide](sequence_parallelism.qmd) for more details.
|
||||
|
||||
### FSDP + QLoRA {#sec-fsdp-qlora}
|
||||
|
||||
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
|
||||
|
||||
@@ -502,9 +502,48 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
|
||||
:::
|
||||
|
||||
If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
|
||||
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
|
||||
using 4 GPUs - 2 for training, and 2 for vLLM:
|
||||
|
||||
::: {.callout-important}
|
||||
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
||||
|
||||
vllm:
|
||||
host: 0.0.0.0
|
||||
port: 8000
|
||||
tensor_parallel_size: 2
|
||||
gpu_memory_utilization: 0.85
|
||||
dtype: auto
|
||||
# max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand
|
||||
|
||||
rl: grpo
|
||||
trl:
|
||||
use_vllm: true
|
||||
vllm_server_host: 0.0.0.0
|
||||
vllm_server_port: 8000
|
||||
vllm_server_timeout: 300
|
||||
```
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm_serve grpo.yaml
|
||||
```
|
||||
|
||||
Your `vLLM` instance will now attempt to spin up, and it's time to kick off training utilizing our remaining two GPUs. In another terminal, execute:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
|
||||
```
|
||||
|
||||
#### Reward functions
|
||||
|
||||
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
||||
|
||||
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
||||
For example, to load OpenAI's GSM8K and use a random reward for completions:
|
||||
|
||||
```python
|
||||
# rewards.py
|
||||
@@ -530,8 +569,6 @@ trl:
|
||||
beta: 0.001
|
||||
max_completion_length: 256
|
||||
use_vllm: True
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.15
|
||||
num_generations: 4
|
||||
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
||||
reward_weights: [1.0]
|
||||
|
||||
@@ -25,6 +25,8 @@ To enable sequence parallelism, add the following to your configuration file:
|
||||
```yaml
|
||||
# Set to a divisor (> 1) of the number of GPUs available
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
```
|
||||
|
||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
@@ -58,11 +60,16 @@ To use sequence parallelism, you need:
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
# Example config with sequence parallelism
|
||||
base_model: meta-llama/Llama-3-8B-Instruct
|
||||
sequence_len: 8192
|
||||
sequence_parallel_degree: 2 # Split each sequence into 4 parts
|
||||
|
||||
...
|
||||
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
flash_attention: true # Required with sequence parallelism
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
@@ -5,6 +5,9 @@ tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
@@ -54,6 +57,8 @@ fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
|
||||
@@ -7,6 +7,9 @@ skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
@@ -48,6 +51,8 @@ fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
80
examples/llama-3/lora-1b-sample-packing-sequentially.yml
Normal file
80
examples/llama-3/lora-1b-sample-packing-sequentially.yml
Normal file
@@ -0,0 +1,80 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
test_value: true
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
sample_packing_sequentially: true
|
||||
curriculum_sampling: true
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_modules_to_save:
|
||||
- embed_tokens
|
||||
- lm_head
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.45.3
|
||||
bitsandbytes==0.45.4
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
@@ -12,12 +12,12 @@ liger-kernel==0.5.5
|
||||
packaging==23.2
|
||||
|
||||
peft==0.15.0
|
||||
transformers==4.50.0
|
||||
transformers==4.50.3
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.5.2
|
||||
datasets==3.5.0
|
||||
deepspeed==0.16.4
|
||||
trl==0.15.1
|
||||
trl==0.16.0
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
|
||||
83
setup.py
83
setup.py
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def parse_requirements():
|
||||
def parse_requirements(extras_require_map):
|
||||
_install_requires = []
|
||||
_dependency_links = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
@@ -67,6 +67,7 @@ def parse_requirements():
|
||||
if (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.29.post2")
|
||||
extras_require_map["vllm"] = ["vllm==0.8.1"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
@@ -86,7 +87,7 @@ def parse_requirements():
|
||||
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
return _install_requires, _dependency_links
|
||||
return _install_requires, _dependency_links, extras_require_map
|
||||
|
||||
|
||||
def get_package_version():
|
||||
@@ -103,7 +104,46 @@ def get_package_version():
|
||||
return version_
|
||||
|
||||
|
||||
install_requires, dependency_links = parse_requirements()
|
||||
extras_require = {
|
||||
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.16.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
],
|
||||
"auto-gptq": [
|
||||
"auto-gptq==0.5.1",
|
||||
],
|
||||
"mlflow": [
|
||||
"mlflow",
|
||||
],
|
||||
"galore": [
|
||||
"galore_torch",
|
||||
],
|
||||
"apollo": [
|
||||
"apollo-torch",
|
||||
],
|
||||
"optimizers": [
|
||||
"galore_torch",
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.7.2",
|
||||
],
|
||||
}
|
||||
|
||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||
extras_require
|
||||
)
|
||||
|
||||
setup(
|
||||
version=get_package_version(),
|
||||
@@ -116,40 +156,5 @@ setup(
|
||||
"axolotl=axolotl.cli.main:main",
|
||||
],
|
||||
},
|
||||
extras_require={
|
||||
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.16.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
],
|
||||
"auto-gptq": [
|
||||
"auto-gptq==0.5.1",
|
||||
],
|
||||
"mlflow": [
|
||||
"mlflow",
|
||||
],
|
||||
"galore": [
|
||||
"galore_torch",
|
||||
],
|
||||
"apollo": [
|
||||
"apollo-torch",
|
||||
],
|
||||
"optimizers": [
|
||||
"galore_torch",
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.7.2",
|
||||
],
|
||||
},
|
||||
extras_require=extras_require_build,
|
||||
)
|
||||
|
||||
@@ -35,6 +35,55 @@ class TrainerCliArgs:
|
||||
num_processes: Optional[int] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VllmServeCliArgs:
|
||||
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
|
||||
|
||||
tensor_parallel_size: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of tensor parallel workers to use."},
|
||||
)
|
||||
host: str = field(
|
||||
default="0.0.0.0", # nosec B104
|
||||
metadata={"help": "Host address to run the server on."},
|
||||
)
|
||||
port: int = field(
|
||||
default=8000,
|
||||
metadata={"help": "Port to run the server on."},
|
||||
)
|
||||
gpu_memory_utilization: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
|
||||
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
|
||||
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
|
||||
"out-of-memory (OOM) errors during initialization."
|
||||
},
|
||||
)
|
||||
dtype: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
|
||||
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
||||
},
|
||||
)
|
||||
max_model_len: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
|
||||
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
|
||||
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
||||
},
|
||||
)
|
||||
enable_prefix_caching: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
|
||||
"hardware support this feature."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluateCliArgs:
|
||||
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
|
||||
|
||||
@@ -14,7 +14,12 @@ import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import axolotl
|
||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.cli.args import (
|
||||
EvaluateCliArgs,
|
||||
PreprocessCliArgs,
|
||||
TrainerCliArgs,
|
||||
VllmServeCliArgs,
|
||||
)
|
||||
from axolotl.cli.sweeps import generate_sweep_configs
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
@@ -23,6 +28,7 @@ from axolotl.cli.utils import (
|
||||
fetch_from_github,
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from axolotl.cli.vllm_serve import do_vllm_serve
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
@@ -316,6 +322,14 @@ def fetch(directory: str, dest: Optional[str]) -> None:
|
||||
fetch_from_github(f"{directory}/", dest)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@add_options_from_dataclass(VllmServeCliArgs)
|
||||
@filter_none_kwargs
|
||||
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
|
||||
do_vllm_serve(config, cli_args)
|
||||
|
||||
|
||||
cli.add_command(lm_eval)
|
||||
|
||||
|
||||
|
||||
55
src/axolotl/cli/vllm_serve.py
Normal file
55
src/axolotl/cli/vllm_serve.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
CLI to start the vllm server for online RL
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from trl.scripts.vllm_serve import ScriptArguments
|
||||
from trl.scripts.vllm_serve import main as vllm_serve_main
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
|
||||
|
||||
def do_vllm_serve(
|
||||
config: Union[Path, str],
|
||||
cli_args: dict,
|
||||
):
|
||||
"""
|
||||
Starts the VLLM server for serving LLM models used for online RL
|
||||
|
||||
Args
|
||||
:param cfg: Parsed doct of the YAML config
|
||||
:param cli_args: dict of additional command-line arguments of type VllmServeCliArgs
|
||||
|
||||
Returns:
|
||||
process_id: the process id of the started VLLM server
|
||||
"""
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
host = cli_args.get("host") or cfg.vllm.host
|
||||
port = cli_args.get("port") or cfg.vllm.port
|
||||
gpu_memory_utilization = (
|
||||
cli_args.get("gpu_memory_utilization") or cfg.vllm.gpu_memory_utilization
|
||||
)
|
||||
dtype = cli_args.get("dtype") or cfg.vllm.dtype
|
||||
max_model_len = cli_args.get("max_model_len") or cfg.vllm.max_model_len
|
||||
enable_prefix_caching = (
|
||||
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
|
||||
)
|
||||
|
||||
vllm_script_args = ScriptArguments(
|
||||
model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
host=host,
|
||||
port=port,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
)
|
||||
vllm_serve_main(vllm_script_args)
|
||||
@@ -69,7 +69,6 @@ from axolotl.utils.callbacks import (
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
SaveModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
@@ -249,7 +248,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
callbacks.append(SaveModelCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -526,9 +524,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
and self.cfg.eval_steps
|
||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||
) or False
|
||||
|
||||
# handle ddp
|
||||
ddp_find_unused_parameters = None
|
||||
if self.cfg.ddp:
|
||||
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
|
||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||
False if self.cfg.ddp else None
|
||||
ddp_find_unused_parameters
|
||||
)
|
||||
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
report_to = []
|
||||
@@ -937,7 +941,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
callbacks.append(SaveModelCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.mixins import (
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
SequenceParallelMixin,
|
||||
)
|
||||
@@ -40,7 +41,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
|
||||
class AxolotlTrainer(
|
||||
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
|
||||
):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
@@ -112,6 +115,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_size=batch_size,
|
||||
sequential=self.args.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from transformers import Trainer
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.trainers.mixins import SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
@@ -40,18 +40,15 @@ class GRPOStrategy:
|
||||
|
||||
if trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||
grpo_args_kwargs["vllm_device"] = (
|
||||
trl.vllm_device if trl.vllm_device else "auto"
|
||||
)
|
||||
|
||||
if trl.vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
trl.vllm_gpu_memory_utilization
|
||||
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
|
||||
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
|
||||
if trl.vllm_server_timeout:
|
||||
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
|
||||
if trl.vllm_guided_decoding_regex:
|
||||
grpo_args_kwargs["vllm_guided_decoding_regex"] = (
|
||||
trl.vllm_guided_decoding_regex
|
||||
)
|
||||
|
||||
if trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
|
||||
|
||||
if trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = trl.num_generations
|
||||
|
||||
@@ -70,6 +67,25 @@ class GRPOStrategy:
|
||||
if trl.reward_weights:
|
||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||
|
||||
if trl.scale_rewards is not None:
|
||||
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
|
||||
|
||||
if trl.temperature is not None:
|
||||
grpo_args_kwargs["temperature"] = trl.temperature
|
||||
if trl.top_p is not None:
|
||||
grpo_args_kwargs["top_p"] = trl.top_p
|
||||
if trl.top_k is not None:
|
||||
grpo_args_kwargs["top_k"] = trl.top_k
|
||||
if trl.min_p is not None:
|
||||
grpo_args_kwargs["min_p"] = trl.min_p
|
||||
if trl.repetition_penalty is not None:
|
||||
grpo_args_kwargs["repetition_penalty"] = trl.repetition_penalty
|
||||
|
||||
if trl.num_iterations is not None:
|
||||
grpo_args_kwargs["num_iterations"] = trl.num_iterations
|
||||
if trl.epsilon is not None:
|
||||
grpo_args_kwargs["epsilon"] = trl.epsilon
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -2,108 +2,68 @@
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
|
||||
from accelerate.utils import is_peft_model
|
||||
from accelerate.utils.other import is_compiled_module
|
||||
from transformers import PreTrainedModel
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
from trl.models import unwrap_model_for_generation
|
||||
from contextlib import nullcontext
|
||||
|
||||
from axolotl.core.trainers.base import SchedulerMixin
|
||||
from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||
from trl import GRPOTrainer
|
||||
from trl.extras.profiling import profiling_decorator
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
# mypy: ignore-errors
|
||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
"""
|
||||
Extend the base GRPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# pylint: disable=access-member-before-definition
|
||||
# Enable gradient checkpointing if requested
|
||||
if kwargs["args"].gradient_checkpointing:
|
||||
# Ensure use_cache is disabled
|
||||
if hasattr(self.model, "config"):
|
||||
self.model.config.use_cache = False
|
||||
|
||||
# Enable gradient checkpointing on the base model for PEFT
|
||||
if is_peft_model(self.model) and hasattr(
|
||||
self.model.base_model, "gradient_checkpointing_enable"
|
||||
):
|
||||
self.model.base_model.gradient_checkpointing_enable()
|
||||
# Enable gradient checkpointing for non-PEFT models
|
||||
elif hasattr(self.model, "gradient_checkpointing_enable"):
|
||||
self.model.gradient_checkpointing_enable()
|
||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||
# pylint: enable=access-member-before-definition
|
||||
|
||||
def _enable_gradient_checkpointing(
|
||||
self, model: PreTrainedModel, args: GRPOConfig
|
||||
) -> PreTrainedModel:
|
||||
"""Enables gradient checkpointing for the model."""
|
||||
# pylint: disable=unused-argument,redefined-builtin
|
||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
||||
use_reentrant = (
|
||||
"use_reentrant" not in gradient_checkpointing_kwargs
|
||||
or gradient_checkpointing_kwargs["use_reentrant"]
|
||||
@profiling_decorator
|
||||
def _move_model_to_vllm(self):
|
||||
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
||||
gather_if_zero3 = (
|
||||
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
|
||||
)
|
||||
|
||||
if use_reentrant:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
if is_peft_model(self.model):
|
||||
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
|
||||
# adapters in a sharded manner is not supported.
|
||||
with gather_if_zero3(list(self.model.parameters())):
|
||||
self.model.merge_adapter()
|
||||
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
# Update vLLM weights while parameters are gathered
|
||||
for name, param in self.model.named_parameters():
|
||||
# When using PEFT, we need to recover the original parameter name and discard some parameters
|
||||
name = (
|
||||
name.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", "")
|
||||
)
|
||||
if self.model.prefix in name:
|
||||
continue
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
if "original_module" in name:
|
||||
continue
|
||||
name = name.replace("modules_to_save.default.", "")
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(
|
||||
make_inputs_require_grad
|
||||
)
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
|
||||
return model
|
||||
# pylint: enable=unused-argument,redefined-builtin
|
||||
# Unmerge adapters while parameters are still gathered
|
||||
self.model.unmerge_adapter()
|
||||
# Parameters will automatically be repartitioned when exiting the context
|
||||
else:
|
||||
# For non-PEFT models, simply gather and update each parameter individually.
|
||||
for name, param in self.model.named_parameters():
|
||||
with gather_if_zero3([param]):
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
|
||||
def _move_model_to_vllm(self):
|
||||
with unwrap_model_for_generation(
|
||||
self.model,
|
||||
self.accelerator,
|
||||
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
||||
) as unwrapped_model:
|
||||
if is_compiled_module(unwrapped_model):
|
||||
unwrapped_model = (
|
||||
unwrapped_model._orig_mod # pylint: disable=protected-access
|
||||
)
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.merge_adapter()
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
# Remove base_model and base_layer prefixes
|
||||
state_dict = {
|
||||
k.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", ""): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
# Remove values with adapter prefix (example: "_lora")
|
||||
state_dict = {
|
||||
k: v
|
||||
for k, v in state_dict.items()
|
||||
if unwrapped_model.prefix not in k
|
||||
}
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
state_dict = {
|
||||
k.replace("modules_to_save.default.", ""): v
|
||||
for k, v in state_dict.items()
|
||||
if "original_module" not in k
|
||||
}
|
||||
else:
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if self.accelerator.is_main_process:
|
||||
llm_model = (
|
||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
)
|
||||
llm_model.load_weights(state_dict.items())
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.unmerge_adapter()
|
||||
# Reset cache on main process
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.reset_prefix_cache()
|
||||
|
||||
@@ -4,5 +4,6 @@
|
||||
# flake8: noqa
|
||||
|
||||
from .optimizer import OptimizerMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
from .sequence_parallel import SequenceParallelMixin
|
||||
|
||||
67
src/axolotl/core/trainers/mixins/rng_state_loader.py
Normal file
67
src/axolotl/core/trainers/mixins/rng_state_loader.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
Temporary fix/override for bug in resume from checkpoint
|
||||
|
||||
See https://github.com/huggingface/transformers/pull/37162
|
||||
|
||||
TODO: Remove when upstream added PR to release
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Trainer, is_torch_npu_available
|
||||
from transformers.trainer import safe_globals
|
||||
from transformers.trainer_pt_utils import set_rng_state_for_device
|
||||
from transformers.training_args import ParallelMode
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RngLoaderMixin(Trainer):
|
||||
"""
|
||||
mixin for method override to load RNG states from a checkpoint
|
||||
"""
|
||||
|
||||
def _load_rng_state(self, checkpoint):
|
||||
# Load RNG states from `checkpoint`
|
||||
if checkpoint is None:
|
||||
return
|
||||
|
||||
if self.args.world_size > 1:
|
||||
process_index = self.args.process_index
|
||||
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
|
||||
if not os.path.isfile(rng_file):
|
||||
LOG.info(
|
||||
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
|
||||
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
|
||||
)
|
||||
return
|
||||
else:
|
||||
rng_file = os.path.join(checkpoint, "rng_state.pth")
|
||||
if not os.path.isfile(rng_file):
|
||||
LOG.info(
|
||||
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
|
||||
"fashion, reproducibility is not guaranteed."
|
||||
)
|
||||
return
|
||||
|
||||
# Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,
|
||||
# which requires allowlisted classes when loading with weights_only=True.
|
||||
with safe_globals():
|
||||
checkpoint_rng_state = torch.load(rng_file) # nosec B614
|
||||
random.setstate(checkpoint_rng_state["python"])
|
||||
np.random.set_state(checkpoint_rng_state["numpy"])
|
||||
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
|
||||
|
||||
is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
if torch.cuda.is_available():
|
||||
set_rng_state_for_device(
|
||||
"CUDA", torch.cuda, checkpoint_rng_state, is_distributed
|
||||
)
|
||||
if is_torch_npu_available():
|
||||
set_rng_state_for_device(
|
||||
"NPU", torch.npu, checkpoint_rng_state, is_distributed
|
||||
)
|
||||
@@ -13,6 +13,7 @@ from trl import (
|
||||
RewardTrainer,
|
||||
)
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
@@ -74,7 +75,7 @@ class TRLPPOTrainer(PPOTrainer):
|
||||
)
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
@@ -154,7 +155,7 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
@@ -162,7 +163,7 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
@@ -244,7 +245,7 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
@@ -252,7 +253,7 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
@@ -34,6 +34,12 @@ class AxolotlTrainingMixins:
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
sample_packing_sequentially: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||
},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
|
||||
@@ -15,6 +15,7 @@ from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
@@ -159,4 +160,6 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
|
||||
del model
|
||||
del tokenizer
|
||||
|
||||
cleanup_distributed()
|
||||
|
||||
return all_metrics
|
||||
|
||||
@@ -38,13 +38,19 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
def register_ring_attn(sequence_parallel_degree: int):
|
||||
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
||||
"""
|
||||
Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed
|
||||
through to `ring_flash_attn.substitute_hf_flash_attn`.
|
||||
"""
|
||||
if get_ring_attn_group() is not None:
|
||||
LOG.info("Ring attention already registered, exiting early...")
|
||||
return
|
||||
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
@@ -84,6 +90,11 @@ def register_ring_attn(sequence_parallel_degree: int):
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
if heads_k_stride is None:
|
||||
heads_k_stride = 1
|
||||
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
|
||||
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||
)
|
||||
|
||||
@@ -1,113 +1,153 @@
|
||||
"""
|
||||
Hijack the LlamaAttention forward method to use xformers if available.
|
||||
|
||||
Updated for transformers v4.50.0.
|
||||
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.llama.modeling_llama import repeat_kv
|
||||
import torch.nn.functional as F
|
||||
import transformers.models.llama.modeling_llama
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
try:
|
||||
import xformers.ops
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
def xformers_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
"""
|
||||
Implements xformers memory-efficient attention for LlamaAttention with support for GQA.
|
||||
|
||||
Args:
|
||||
module: The LlamaAttention module
|
||||
query: Query states of shape [batch, num_heads, seq_len, head_dim]
|
||||
key: Key states of shape [batch, num_kv_heads, seq_len, head_dim]
|
||||
value: Value states of shape [batch, num_kv_heads, seq_len, head_dim]
|
||||
attention_mask: Attention mask
|
||||
scaling: Scaling factor for attention scores
|
||||
dropout: Dropout probability
|
||||
|
||||
Returns:
|
||||
attn_output: Output of xformers memory-efficient attention
|
||||
attn_weights: None
|
||||
"""
|
||||
# First, handle grouped-query attention (GQA)
|
||||
# We need to repeat key and value states to match the number of query heads
|
||||
num_key_value_groups = getattr(module, "num_key_value_groups", 1)
|
||||
key = repeat_kv(key, num_key_value_groups)
|
||||
value = repeat_kv(value, num_key_value_groups)
|
||||
|
||||
# xformers expects inputs in shape [batch, seq_len, num_heads, head_dim]
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
# Determine if we need a causal mask
|
||||
is_causal = getattr(module, "is_causal", True)
|
||||
|
||||
# Set up the attention bias for xformers
|
||||
if is_causal:
|
||||
# Use xformers built-in causal mask
|
||||
attn_bias = xformers.ops.LowerTriangularMask()
|
||||
elif attention_mask is not None:
|
||||
# For non-causal attention with a mask, we'd need to convert the mask
|
||||
# This is a simplification - you might need to adapt based on your mask format
|
||||
attn_bias = attention_mask
|
||||
else:
|
||||
# No mask needed
|
||||
attn_bias = None
|
||||
|
||||
# Apply xformers memory-efficient attention
|
||||
attn_output = xformers.ops.memory_efficient_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias,
|
||||
p=dropout if module.training else 0.0,
|
||||
scale=scaling,
|
||||
)
|
||||
|
||||
# Reshape back to [batch, seq_len, hidden_size]
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
return attn_output, None # Return None for attn_weights to match interface
|
||||
logging.error("xformers not found! Please install it before trying to use it.")
|
||||
|
||||
|
||||
def hijack_llama_attention():
|
||||
"""
|
||||
Patch the LlamaAttention forward method to use xformers if available.
|
||||
"""
|
||||
if not XFORMERS_AVAILABLE:
|
||||
raise ValueError(
|
||||
"xformers not available. Please install it following axolotl's requirements."
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||
|
||||
|
||||
def xformers_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, "pretraining_tp"):
|
||||
self.pretraining_tp = 1
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
key_value_slicing = (
|
||||
self.num_key_value_heads * self.head_dim
|
||||
) // self.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [
|
||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [
|
||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [
|
||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if output_attentions:
|
||||
warnings.warn(
|
||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||
)
|
||||
|
||||
import transformers.models.llama.modeling_llama as llama_modeling
|
||||
#
|
||||
# xformers-attn start
|
||||
#
|
||||
|
||||
# Add xformers to the available attention implementations
|
||||
llama_modeling.ALL_ATTENTION_FUNCTIONS["xformers"] = xformers_attention_forward
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
# Create a wrapper for the original LlamaAttention forward method
|
||||
original_forward = llama_modeling.LlamaAttention.forward
|
||||
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(
|
||||
query_states, key_states, value_states, attn_bias=None
|
||||
)
|
||||
else:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
# attn_bias=attention_mask,
|
||||
attn_bias=xformers.ops.LowerTriangularMask(),
|
||||
)
|
||||
|
||||
def patched_forward(self, *args, **kwargs):
|
||||
# Set the attention implementation to xformers
|
||||
# pylint: disable=protected-access
|
||||
self.config._attn_implementation = "xformers"
|
||||
return original_forward(self, *args, **kwargs)
|
||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
# Apply the patch
|
||||
llama_modeling.LlamaAttention.forward = patched_forward
|
||||
#
|
||||
# xformers-attn end
|
||||
#
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(
|
||||
self.hidden_size // self.pretraining_tp, dim=1
|
||||
)
|
||||
attn_output = sum(
|
||||
F.linear(attn_output[i], o_proj_slices[i])
|
||||
for i in range(self.pretraining_tp)
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
@@ -22,6 +22,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"phi3",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"gemma3",
|
||||
"gemma3_text",
|
||||
"cohere",
|
||||
"cohere2",
|
||||
|
||||
@@ -27,6 +27,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
@@ -157,6 +158,8 @@ def setup_signal_handler(
|
||||
_model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
cleanup_distributed()
|
||||
sys.exit(0)
|
||||
|
||||
_model_weakref = weakref.ref(model)
|
||||
@@ -478,7 +481,7 @@ def train(
|
||||
Returns:
|
||||
Tuple of (model, tokenizer) after training
|
||||
"""
|
||||
# Setup model, tokenizer, (causal or RLHF) trainer etc.
|
||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||
(
|
||||
trainer,
|
||||
model,
|
||||
@@ -487,34 +490,26 @@ def train(
|
||||
processor,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
# Determine if we need to resume from a checkpoint
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
|
||||
# Configuration for saving
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
handle_untrained_tokens_fix(
|
||||
cfg, model, tokenizer, train_dataset, safe_serialization
|
||||
)
|
||||
|
||||
# Save initial configs
|
||||
# Additional setup
|
||||
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
|
||||
|
||||
# Set up signal handler for graceful termination
|
||||
setup_signal_handler(cfg, model, safe_serialization)
|
||||
|
||||
# Set up badges and config info for model card
|
||||
setup_model_card(cfg)
|
||||
|
||||
# Execute the training
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||
|
||||
# Save the trained model
|
||||
# Save the trained model and cleanup
|
||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||
|
||||
# Create model card
|
||||
create_model_card(cfg, trainer)
|
||||
if not cfg.use_ray:
|
||||
cleanup_distributed()
|
||||
|
||||
return model, tokenizer, trainer
|
||||
|
||||
@@ -816,27 +816,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
return control
|
||||
|
||||
|
||||
class SaveModelCallback(TrainerCallback):
|
||||
"""Callback to save model on train end"""
|
||||
|
||||
def on_step_end( # pylint: disable=unused-argument
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
# Save
|
||||
if state.global_step >= state.max_steps:
|
||||
control.should_save = True
|
||||
|
||||
def on_train_end( # pylint: disable=unused-argument
|
||||
self, args, state, control, **kwargs
|
||||
):
|
||||
control.should_save = True
|
||||
return control
|
||||
|
||||
|
||||
class GCCallback(TrainerCallback):
|
||||
"""Callback to garbage collect torch cache"""
|
||||
|
||||
|
||||
@@ -112,6 +112,7 @@ class DataCollatorForSeq2Seq:
|
||||
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
has_attn_mask = "attention_mask" in features[0].keys()
|
||||
labels = None
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
@@ -164,6 +165,8 @@ class DataCollatorForSeq2Seq:
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
if not has_attn_mask:
|
||||
del features["attention_mask"]
|
||||
|
||||
# prepare decoder_input_ids
|
||||
if (
|
||||
|
||||
@@ -238,7 +238,8 @@ def load_dataset_w_config(
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
else:
|
||||
elif config_dataset.data_files:
|
||||
fp: str | list[str] | None = None
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
|
||||
@@ -71,8 +71,8 @@ def barrier():
|
||||
|
||||
def is_main_process():
|
||||
"""
|
||||
Check if the current process is the main process.
|
||||
If not in distributed mode, always return True.
|
||||
Check if the current process is the main process. If not in distributed mode,
|
||||
always return `True`.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return True
|
||||
@@ -87,6 +87,18 @@ def get_world_size():
|
||||
return int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
|
||||
def cleanup_distributed():
|
||||
"""
|
||||
Destroy process group if torch distributed is initialized. Called in training early
|
||||
termination or when training successfully completes.
|
||||
"""
|
||||
# Ensure that all operations are completed before destroying the process group
|
||||
torch.cuda.synchronize()
|
||||
# Destroy the process group
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_only():
|
||||
"""
|
||||
|
||||
@@ -609,7 +609,10 @@ class ModelLoader:
|
||||
# Initialize ring attn for sequence parallelism. This must be done after
|
||||
# model init but before the first forward pass, since it modifies flash
|
||||
# attn to use ring comm for SP training across multiple GPUs.
|
||||
register_ring_attn(self.cfg.sequence_parallel_degree)
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
|
||||
heads_k_stride=self.cfg.heads_k_stride,
|
||||
)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
if hasattr(self.model_config, "model_type"):
|
||||
|
||||
@@ -8,11 +8,13 @@ from typing import Any, Iterable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler, Sampler
|
||||
from torch.utils.data import BatchSampler, Sampler, SequentialSampler
|
||||
|
||||
from axolotl.utils.distributed import reduce_and_broadcast
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
|
||||
@numba.njit
|
||||
@@ -103,6 +105,55 @@ def allocate(
|
||||
return result, s, len(result) * c * n
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
|
||||
"""
|
||||
Sequential allocator that preserves example order
|
||||
|
||||
Parameters:
|
||||
- lengths: The lengths of all examples
|
||||
- rank: The current rank (for distributed training)
|
||||
- c: The capacity of each bin (maximum sequence length)
|
||||
- n: Number of ranks
|
||||
|
||||
Returns:
|
||||
- result: List of batches for the current rank
|
||||
- total_used: Number of actual example tokens
|
||||
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
|
||||
"""
|
||||
result = []
|
||||
total_used = 0
|
||||
|
||||
# First, do sequential packing into bins
|
||||
all_bins = []
|
||||
current_bin = [0 for i in range(0)] # numba hint
|
||||
remaining_capacity = c
|
||||
|
||||
for idx, size in enumerate(lengths):
|
||||
if size <= remaining_capacity:
|
||||
# Example fits in current bin
|
||||
current_bin.append(idx)
|
||||
remaining_capacity -= size
|
||||
total_used += size
|
||||
else:
|
||||
# Example doesn't fit, start a new bin
|
||||
if current_bin: # Add non-empty bin to all_bins
|
||||
all_bins.append(current_bin)
|
||||
current_bin = [idx]
|
||||
remaining_capacity = c - size
|
||||
total_used += size
|
||||
|
||||
# Add the last bin if not empty
|
||||
if current_bin:
|
||||
all_bins.append(current_bin)
|
||||
|
||||
# Assign bins to ranks - each rank gets every n-th bin
|
||||
for bin_idx in range(rank, len(all_bins), n):
|
||||
result.append(all_bins[bin_idx])
|
||||
|
||||
return result, total_used, len(all_bins) * c
|
||||
|
||||
|
||||
class MultipackBatchSampler(BatchSampler):
|
||||
"""Batch sampler class for multipack"""
|
||||
|
||||
@@ -115,6 +166,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
drop_last: bool = False,
|
||||
num_count_samples: int = 16,
|
||||
sequential: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
@@ -122,6 +174,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.sequential = sequential
|
||||
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
@@ -136,6 +189,11 @@ class MultipackBatchSampler(BatchSampler):
|
||||
# the minimum packed dataset length across all ranks determined by a gather/broadcast
|
||||
self.len_across_ranks = None
|
||||
|
||||
if self.sequential and not isinstance(sampler, SequentialSampler):
|
||||
LOG.warn(
|
||||
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
|
||||
)
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
self.epoch = epoch
|
||||
|
||||
@@ -145,13 +203,21 @@ class MultipackBatchSampler(BatchSampler):
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
if self.sequential:
|
||||
batches, total_used, total_slots = allocate_sequentially(
|
||||
lengths=lengths,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
else:
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
|
||||
batches = [
|
||||
[
|
||||
|
||||
@@ -46,6 +46,7 @@ from axolotl.utils.schemas.multimodal import MultiModalConfig
|
||||
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
||||
from axolotl.utils.schemas.training import HyperparametersConfig
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
from axolotl.utils.schemas.vllm import VllmConfig
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -86,6 +87,9 @@ class AxolotlInputConfig(
|
||||
trl: TRLConfig | None = Field(
|
||||
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
vllm: VllmConfig | None = Field(
|
||||
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
reward_model: bool | None = None
|
||||
process_reward_model: bool | None = None
|
||||
num_labels: int | None = None
|
||||
@@ -188,6 +192,7 @@ class AxolotlInputConfig(
|
||||
sample_packing: bool | None = None
|
||||
sample_packing_group_size: int | None = 100_000
|
||||
sample_packing_bin_size: int | None = 200
|
||||
sample_packing_sequentially: bool | None = None
|
||||
eval_sample_packing: bool | None = None
|
||||
pad_to_sequence_len: bool | None = None
|
||||
curriculum_sampling: bool | None = None
|
||||
@@ -248,6 +253,7 @@ class AxolotlInputConfig(
|
||||
val_set_size: float | None = Field(default=0.0)
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
heads_k_stride: int | None = None
|
||||
|
||||
special_tokens: SpecialTokensConfig | None = None
|
||||
tokens: list[str] | None = None
|
||||
@@ -1108,7 +1114,7 @@ class AxolotlInputConfig(
|
||||
|
||||
@field_validator("sequence_parallel_degree", mode="before")
|
||||
@classmethod
|
||||
def check_sequence_parallel_config(cls, value, info):
|
||||
def check_sequence_parallel_degree(cls, value, info):
|
||||
if not value:
|
||||
value = 1
|
||||
|
||||
@@ -1129,6 +1135,17 @@ class AxolotlInputConfig(
|
||||
|
||||
return value
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_muon_deepspeed_fsdp(cls, data):
|
||||
if data.get("optimizer") == "muon" and (
|
||||
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
|
||||
):
|
||||
raise ValueError(
|
||||
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
@@ -1264,3 +1281,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
if data["beta"] != data["trl"]["beta"]:
|
||||
raise ValueError("beta and trl.beta must match or one must be removed")
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_min_torch_version(self):
|
||||
if self.env_capabilities and self.env_capabilities.torch_version:
|
||||
torch_version = self.env_capabilities.torch_version
|
||||
if version.parse(torch_version) < version.parse("2.5.1"):
|
||||
LOG.warning(
|
||||
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
|
||||
)
|
||||
|
||||
@@ -20,27 +20,30 @@ class TRLConfig(BaseModel):
|
||||
)
|
||||
|
||||
# GRPO specific args
|
||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||
use_vllm: bool | None = Field(
|
||||
# Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
|
||||
use_vllm: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||
)
|
||||
vllm_device: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Device to use for VLLM"},
|
||||
vllm_server_host: str | None = Field(
|
||||
default="0.0.0.0", # nosec B104
|
||||
json_schema_extra={"description": "Host of the vLLM server to connect to"},
|
||||
)
|
||||
vllm_gpu_memory_utilization: float | None = Field(
|
||||
default=0.9,
|
||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||
vllm_server_port: int | None = Field(
|
||||
default=8000,
|
||||
json_schema_extra={"description": "Port of the vLLM server to connect to"},
|
||||
)
|
||||
vllm_dtype: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Data type for VLLM"},
|
||||
)
|
||||
vllm_max_model_len: int | None = Field(
|
||||
vllm_server_timeout: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Maximum length of the model context for VLLM"
|
||||
"description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
|
||||
"after the timeout, a `ConnectionError` is raised."
|
||||
},
|
||||
)
|
||||
vllm_guided_decoding_regex: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
|
||||
},
|
||||
)
|
||||
|
||||
@@ -85,3 +88,48 @@ class TRLConfig(BaseModel):
|
||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||
},
|
||||
)
|
||||
scale_rewards: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
|
||||
},
|
||||
)
|
||||
|
||||
temperature: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Sampling temperature for the GRPO policy."},
|
||||
)
|
||||
top_p: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Top-p sampling probability for the generation policy."
|
||||
},
|
||||
)
|
||||
top_k: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Top-k sampling for the generation policy."},
|
||||
)
|
||||
min_p: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Minimum probability for the generation policy."
|
||||
},
|
||||
)
|
||||
repetition_penalty: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
|
||||
},
|
||||
)
|
||||
num_iterations: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
|
||||
},
|
||||
)
|
||||
epsilon: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Epsilon value for clipping in the GRPO algorithm."
|
||||
},
|
||||
)
|
||||
|
||||
38
src/axolotl/utils/schemas/vllm.py
Normal file
38
src/axolotl/utils/schemas/vllm.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Pydantic models for VLLM configuration, used primarily for RL training with TRL + grpo
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VllmConfig(BaseModel):
|
||||
"""
|
||||
Configuration for VLLM server
|
||||
"""
|
||||
|
||||
device: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Device to use for VLLM"},
|
||||
)
|
||||
tensor_parallel_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Tensor parallel size for VLLM"},
|
||||
)
|
||||
gpu_memory_utilization: float | None = Field(
|
||||
default=0.9,
|
||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||
)
|
||||
dtype: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Data type for VLLM"},
|
||||
)
|
||||
max_model_len: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Maximum length of the model context for VLLM"
|
||||
},
|
||||
)
|
||||
enable_prefix_caching: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Enable prefix caching for VLLM"},
|
||||
)
|
||||
@@ -13,7 +13,7 @@ import torch
|
||||
import torch.cuda
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import IterableDataset, disable_caching, enable_caching
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
@@ -235,7 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if cfg.model_config_type == "mamba":
|
||||
if cfg.model_config_type in ["mamba", "gemma3"]:
|
||||
LOG.info("dropping attention_mask column")
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
@@ -456,13 +456,18 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
else:
|
||||
sampler_batch_size = cfg.micro_batch_size
|
||||
batch_max_len = cfg.sequence_len
|
||||
if cfg.curriculum_sampling:
|
||||
sampler = SequentialSampler(train_dataset)
|
||||
else:
|
||||
sampler = RandomSampler(train_dataset)
|
||||
sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
sampler=sampler,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
batch_size=sampler_batch_size,
|
||||
batch_max_len=batch_max_len,
|
||||
group_size=cfg.sample_packing_group_size,
|
||||
bin_size=cfg.sample_packing_bin_size,
|
||||
sequential=cfg.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -8,11 +8,13 @@ import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
import requests
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
|
||||
@@ -48,6 +50,14 @@ def snapshot_download_w_retry(*args, **kwargs):
|
||||
return snapshot_download(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_ds_fixture_bundle():
|
||||
ds_dir = snapshot_download_w_retry(
|
||||
"axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
|
||||
)
|
||||
return Path(ds_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_smollm2_135m_model():
|
||||
# download the model
|
||||
@@ -101,42 +111,50 @@ def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
||||
def download_argilla_distilabel_intel_orca_dpo_dataset():
|
||||
# download the dataset
|
||||
snapshot_download_w_retry(
|
||||
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
|
||||
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_fozzie_alpaca_dpo_dataset():
|
||||
# download the dataset
|
||||
snapshot_download_w_retry(
|
||||
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
||||
)
|
||||
snapshot_download_w_retry(
|
||||
"fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
repo_type="dataset",
|
||||
revision="ea82cff",
|
||||
)
|
||||
# @pytest.fixture(scope="session", autouse=True)
|
||||
# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
||||
# # download the dataset
|
||||
# snapshot_download_w_retry(
|
||||
# "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
|
||||
# )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@disable_hf_offline
|
||||
def dataset_fozzie_alpaca_dpo_dataset(
|
||||
download_fozzie_alpaca_dpo_dataset,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||
# @pytest.fixture(scope="session", autouse=True)
|
||||
# def download_fozzie_alpaca_dpo_dataset():
|
||||
# # download the dataset
|
||||
# snapshot_download_w_retry(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
||||
# )
|
||||
# snapshot_download_w_retry(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
# repo_type="dataset",
|
||||
# revision="ea82cff",
|
||||
# )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@disable_hf_offline
|
||||
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||
download_fozzie_alpaca_dpo_dataset,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
return load_dataset(
|
||||
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||
)
|
||||
# @pytest.fixture(scope="session")
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||
#
|
||||
#
|
||||
# @pytest.fixture(scope="session")
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# return load_dataset(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||
# )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@@ -263,7 +281,7 @@ def download_mlx_mistral_7b_model_fixture():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture
|
||||
def download_llama2_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
@@ -273,7 +291,7 @@ def download_llama2_model_fixture():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama(
|
||||
download_huggyllama_model_fixture,
|
||||
@@ -284,6 +302,57 @@ def tokenizer_huggyllama(
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama_w_special_tokens(
|
||||
tokenizer_huggyllama,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
tokenizer_huggyllama.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
return tokenizer_huggyllama
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_llama2_7b(
|
||||
download_llama2_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_mistral_7b_instruct(
|
||||
download_mlx_mistral_7b_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
|
||||
tokenizer_mistral_7b_instruct.add_special_tokens(
|
||||
{
|
||||
"eos_token": AddedToken(
|
||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||
)
|
||||
}
|
||||
)
|
||||
tokenizer_mistral_7b_instruct.add_tokens(
|
||||
[
|
||||
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
|
||||
]
|
||||
)
|
||||
return tokenizer_mistral_7b_instruct
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
# Create a temporary directory
|
||||
@@ -349,6 +418,60 @@ def cleanup_monkeypatches():
|
||||
globals().pop(module_global, None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_winglian_tiny_shakespeare(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
|
||||
return datasets.load_from_disk(ds_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_tatsu_lab_alpaca(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_mhenrichsen_alpaca_2k_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
|
||||
)
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
|
||||
)
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
# # pylint: disable=redefined-outer-name,unused-argument
|
||||
# def test_load_fixtures(
|
||||
# download_smollm2_135m_model,
|
||||
|
||||
0
tests/e2e/multigpu/solo/__init__.py
Normal file
0
tests/e2e/multigpu/solo/__init__.py
Normal file
294
tests/e2e/multigpu/solo/test_grpo.py
Normal file
294
tests/e2e/multigpu/solo/test_grpo.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
GRPO test suite
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import subprocess # nosec B404
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import require_vllm
|
||||
|
||||
|
||||
def start_vllm(
|
||||
model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
|
||||
) -> int:
|
||||
"""
|
||||
helper function to start the VLLM server in the background, mostly for testing purposes
|
||||
"""
|
||||
cmd = [sys.executable, "-m", "trl.scripts.vllm_serve", "--model", model]
|
||||
|
||||
if tensor_parallel_size := kwargs.get("tensor_parallel_size"):
|
||||
cmd.extend(["--tensor-parallel-size", str(tensor_parallel_size)])
|
||||
if host := kwargs.get("host"):
|
||||
cmd.extend(["--host", host])
|
||||
if port := kwargs.get("port"):
|
||||
cmd.extend(["--port", str(port)])
|
||||
if gpu_memory_utilization := kwargs.get("gpu_memory_utilization"):
|
||||
cmd.extend(["--gpu-memory-utilization", str(gpu_memory_utilization)])
|
||||
if dtype := kwargs.get("dtype"):
|
||||
cmd.extend(["--dtype", dtype])
|
||||
if max_model_len := kwargs.get("max_model_len"):
|
||||
cmd.extend(["--max-model-len", str(max_model_len)])
|
||||
if kwargs.get("enable_prefix_caching"):
|
||||
cmd.extend(["--enable-prefix-caching", "True"])
|
||||
|
||||
# print out the command to be executed
|
||||
print(" ".join(cmd))
|
||||
|
||||
# start `trl vllm-serve` command in the background and capture the process id
|
||||
process = subprocess.Popen( # pylint: disable=consider-using-with
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
|
||||
) # nosec B603
|
||||
|
||||
# print out the process id so the user can easily kill it later
|
||||
print(f"VLLM server process started (PID: {process.pid})")
|
||||
|
||||
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds
|
||||
started = False
|
||||
if wait and host and port:
|
||||
for _ in range(int(wait)):
|
||||
try:
|
||||
response = requests.get(f"http://{host}:{port}", timeout=1)
|
||||
if int(response.status_code) in [200, 404]:
|
||||
started = True
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
|
||||
# also check if the process.pid is still running
|
||||
if not process.poll() is None:
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
if wait and not started:
|
||||
print(
|
||||
f"VLLM server process did not start within {wait} seconds. Please check your server logs."
|
||||
)
|
||||
process.kill()
|
||||
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
|
||||
|
||||
# return the process id
|
||||
return process.pid
|
||||
|
||||
|
||||
class TestGRPO:
|
||||
"""
|
||||
Test case for GRPO training using multilpe GPUs
|
||||
"""
|
||||
|
||||
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
|
||||
fout.write(
|
||||
"""import random
|
||||
def rand_reward_func(completions, **kwargs) -> list[float]:
|
||||
return [random.uniform(0, 1) for _ in completions]
|
||||
|
||||
def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
def transform_fn(example, tokenizer=None):
|
||||
label = example["answer"].split("####")[-1].strip().replace(",", "")
|
||||
return {
|
||||
"prompt": [{"role": "user", "content": example["question"]},],
|
||||
"answer": label,
|
||||
}
|
||||
return transform_fn, {"remove_columns": ["question"]}
|
||||
"""
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
)
|
||||
@require_vllm
|
||||
def test_llama_dora(self, temp_dir, num_gpus):
|
||||
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"chat_template": "llama3",
|
||||
"rl": "grpo",
|
||||
"trl": {
|
||||
"beta": 0.001,
|
||||
"max_completion_length": 256,
|
||||
"use_vllm": True,
|
||||
"num_generations": 4,
|
||||
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||
},
|
||||
"vllm": {
|
||||
"max_model_len": 800,
|
||||
"enable_prefix_caching": True,
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "openai/gsm8k",
|
||||
"name": "main",
|
||||
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||
},
|
||||
],
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"peft_use_dora": True,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"max_steps": 3,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||
|
||||
current_env = os.environ.copy()
|
||||
env = {
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
}
|
||||
vllm_process_id = start_vllm(
|
||||
cfg.base_model,
|
||||
env=env,
|
||||
quiet=True,
|
||||
wait=120,
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=cfg.vllm.max_model_len,
|
||||
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
)
|
||||
|
||||
try:
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
str(num_gpus),
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
],
|
||||
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
|
||||
)
|
||||
finally:
|
||||
os.kill(vllm_process_id, 9)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
)
|
||||
@require_vllm
|
||||
def test_llama_fft(self, temp_dir, num_gpus):
|
||||
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"chat_template": "llama3",
|
||||
"rl": "grpo",
|
||||
"trl": {
|
||||
"beta": 0.001,
|
||||
"max_completion_length": 256,
|
||||
"use_vllm": True,
|
||||
"num_generations": 4,
|
||||
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||
},
|
||||
"vllm": {
|
||||
"max_model_len": 800,
|
||||
"enable_prefix_caching": True,
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "openai/gsm8k",
|
||||
"name": "main",
|
||||
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||
},
|
||||
],
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"max_steps": 3,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||
|
||||
current_env = os.environ.copy()
|
||||
env = {
|
||||
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
}
|
||||
vllm_process_id = start_vllm(
|
||||
cfg.base_model,
|
||||
env=env,
|
||||
quiet=True,
|
||||
wait=120,
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=cfg.vllm.max_model_len,
|
||||
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
)
|
||||
|
||||
try:
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
str(num_gpus),
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
],
|
||||
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
|
||||
)
|
||||
finally:
|
||||
os.kill(vllm_process_id, 9)
|
||||
@@ -52,9 +52,9 @@ class TestMultiGPUEval:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -121,9 +121,9 @@ class TestMultiGPUEval:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
|
||||
100
tests/e2e/multigpu/test_gemma3.py
Normal file
100
tests/e2e/multigpu/test_gemma3.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
E2E tests for multigpu lora tinyllama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_model():
|
||||
# download the model
|
||||
snapshot_download("axolotl-mirrors/gemma-3-4b-pt", repo_type="model")
|
||||
|
||||
|
||||
class TestMultiGPUGemma3:
|
||||
"""
|
||||
Test case for Gemma3 models using LoRA
|
||||
"""
|
||||
|
||||
def test_lora_ddp_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-mirrors/gemma-3-4b-pt",
|
||||
"sequence_len": 2048,
|
||||
"ddp_find_unused_parameters": True,
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.0,
|
||||
"chat_template": "gemma3",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"split": "train[:10%]",
|
||||
"field_messages": "conversations",
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {
|
||||
"use_reentrant": False,
|
||||
},
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
|
||||
)
|
||||
@@ -1,175 +0,0 @@
|
||||
"""
|
||||
GRPO test suite
|
||||
"""
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import require_vllm
|
||||
|
||||
|
||||
class TestGRPO:
|
||||
"""
|
||||
Test case for GRPO training using multilpe GPUs
|
||||
"""
|
||||
|
||||
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
|
||||
fout.write(
|
||||
"""import random
|
||||
def rand_reward_func(completions, **kwargs) -> list[float]:
|
||||
return [random.uniform(0, 1) for _ in completions]
|
||||
|
||||
def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
def transform_fn(example, tokenizer=None):
|
||||
label = example["answer"].split("####")[-1].strip().replace(",", "")
|
||||
return {
|
||||
"prompt": [{"role": "user", "content": example["question"]},],
|
||||
"answer": label,
|
||||
}
|
||||
return transform_fn, {"remove_columns": ["question"]}
|
||||
"""
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
)
|
||||
@require_vllm
|
||||
def test_llama_dora(self, temp_dir, num_gpus):
|
||||
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"chat_template": "llama3",
|
||||
"rl": "grpo",
|
||||
"trl": {
|
||||
"beta": 0.001,
|
||||
"max_completion_length": 256,
|
||||
"use_vllm": True,
|
||||
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
|
||||
"vllm_gpu_memory_utilization": 0.15,
|
||||
"num_generations": 4,
|
||||
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "openai/gsm8k",
|
||||
"name": "main",
|
||||
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||
},
|
||||
],
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"peft_use_dora": True,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"max_steps": 5,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
str(num_gpus),
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
)
|
||||
@require_vllm
|
||||
def test_llama_fft(self, temp_dir, num_gpus):
|
||||
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"chat_template": "llama3",
|
||||
"rl": "grpo",
|
||||
"trl": {
|
||||
"beta": 0.001,
|
||||
"max_completion_length": 256,
|
||||
"use_vllm": True,
|
||||
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
|
||||
"vllm_gpu_memory_utilization": 0.15,
|
||||
"num_generations": 4,
|
||||
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "openai/gsm8k",
|
||||
"name": "main",
|
||||
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||
},
|
||||
],
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"max_steps": 5,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
str(num_gpus),
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
@@ -58,6 +58,7 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -121,6 +122,7 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -193,6 +195,7 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
@@ -270,6 +273,7 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
@@ -330,6 +334,7 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
@@ -399,7 +404,8 @@ class TestMultiGPULlama:
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
@@ -478,7 +484,8 @@ class TestMultiGPULlama:
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
@@ -778,9 +785,10 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
|
||||
@@ -46,7 +46,7 @@ class TestMultiGPUQwen2:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"max_steps": 2,
|
||||
"warmup_steps": 20,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
|
||||
@@ -50,7 +50,7 @@ class TestMultiGPURay:
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
|
||||
@@ -110,7 +110,7 @@ class TestRingAttention:
|
||||
mock_new_group.return_value = mock_group
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(sequence_parallel_degree=4)
|
||||
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
|
||||
|
||||
# Verify the number of calls without examining the arguments
|
||||
assert mock_new_group.call_count == 2
|
||||
|
||||
@@ -324,7 +324,7 @@ class TestDatasetPreparation:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_hub_with_revision_with_dpo(
|
||||
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
):
|
||||
"""Verify that processing dpo data from the hub works with a specific revision"""
|
||||
|
||||
@@ -339,12 +339,10 @@ class TestDatasetPreparation:
|
||||
)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset:
|
||||
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
@@ -354,7 +352,9 @@ class TestDatasetPreparation:
|
||||
|
||||
@enable_hf_offline
|
||||
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||
def test_load_local_hub_with_revision(self, tokenizer):
|
||||
def test_load_local_hub_with_revision(
|
||||
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer
|
||||
):
|
||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||
@@ -386,13 +386,23 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset:
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, prepared_path
|
||||
)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
@enable_hf_offline
|
||||
def test_loading_local_dataset_folder(self, tokenizer):
|
||||
|
||||
@@ -238,21 +238,22 @@ class TestDeduplicateRLDataset:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_with_deduplication(
|
||||
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||
self,
|
||||
cfg,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
tokenizer_huggyllama,
|
||||
):
|
||||
"""Verify that loading with deduplication removes duplicates."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.side_effect = [
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
@@ -263,19 +264,20 @@ class TestDeduplicateRLDataset:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_without_deduplication(
|
||||
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||
self,
|
||||
cfg,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
tokenizer_huggyllama,
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.side_effect = [
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Module for testing streaming dataset sequence packing"""
|
||||
|
||||
import pytest
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from datasets import concatenate_datasets
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@@ -27,7 +27,6 @@ class TestBatchedSamplerPacking:
|
||||
Test class for packing streaming dataset sequences
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size, num_workers",
|
||||
[
|
||||
@@ -38,14 +37,20 @@ class TestBatchedSamplerPacking:
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("max_seq_length", [4096, 512])
|
||||
@pytest.mark.parametrize("sequential", [True, False])
|
||||
@enable_hf_offline
|
||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
||||
def test_packing(
|
||||
self,
|
||||
dataset_winglian_tiny_shakespeare,
|
||||
batch_size,
|
||||
num_workers,
|
||||
tokenizer,
|
||||
max_seq_length,
|
||||
sequential,
|
||||
):
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
dataset = load_dataset(
|
||||
"winglian/tiny-shakespeare",
|
||||
split="train",
|
||||
)
|
||||
dataset = dataset_winglian_tiny_shakespeare["train"]
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -55,7 +60,7 @@ class TestBatchedSamplerPacking:
|
||||
)
|
||||
ds_cfg = DictDefault(
|
||||
{
|
||||
"field": "Text",
|
||||
"field": "text",
|
||||
}
|
||||
)
|
||||
completion_strategy = load(tokenizer, cfg, ds_cfg)
|
||||
@@ -75,6 +80,7 @@ class TestBatchedSamplerPacking:
|
||||
batch_max_len=max_seq_length,
|
||||
group_size=100000,
|
||||
bin_size=200,
|
||||
sequential=sequential,
|
||||
)
|
||||
|
||||
loader = DataLoader(
|
||||
|
||||
@@ -2,13 +2,8 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||
InstructionWSystemPromptTokenizingStrategy,
|
||||
@@ -61,24 +56,13 @@ test_data = {
|
||||
}
|
||||
|
||||
|
||||
class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
class TestPromptTokenizationStrategies:
|
||||
"""
|
||||
Test class for prompt tokenization strategies.
|
||||
"""
|
||||
|
||||
@enable_hf_offline
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
@@ -86,7 +70,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
tokenizer_huggyllama_w_special_tokens,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
@@ -99,7 +83,8 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
assert example["labels"][world_idx] == 3186
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
def test_alpaca(self):
|
||||
@enable_hf_offline
|
||||
def test_alpaca(self, tokenizer_huggyllama_w_special_tokens):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
@@ -107,7 +92,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
prompter = AlpacaPrompter()
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
tokenizer_huggyllama_w_special_tokens,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
@@ -118,28 +103,17 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
|
||||
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
||||
class TestInstructionWSystemPromptTokenizingStrategy:
|
||||
"""
|
||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||
"""
|
||||
|
||||
@enable_hf_offline
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
def test_system_alpaca(self):
|
||||
def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens):
|
||||
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
|
||||
strat = InstructionWSystemPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
tokenizer_huggyllama_w_special_tokens,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
@@ -160,18 +134,13 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
||||
assert example["input_ids"][8] == 11889 # USER
|
||||
|
||||
|
||||
class Llama2ChatTokenizationTest(unittest.TestCase):
|
||||
class Llama2ChatTokenizationTest:
|
||||
"""
|
||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||
"""
|
||||
|
||||
@enable_hf_offline
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
# woraround because official Meta repos are not open
|
||||
|
||||
def test_llama2_chat_integration(self):
|
||||
def test_llama2_chat_integration(self, tokenizer_llama2_7b):
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||
) as fin:
|
||||
@@ -186,16 +155,18 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
|
||||
prompter = Llama2ChatPrompter()
|
||||
strat = LLama2ChatTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
tokenizer_llama2_7b,
|
||||
False,
|
||||
4096,
|
||||
)
|
||||
example = strat.tokenize_prompt(conversation)
|
||||
for fields in ["input_ids", "attention_mask", "labels"]:
|
||||
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
||||
self.assertEqual(example[fields], tokenized_conversation[fields])
|
||||
# pytest assert equals
|
||||
|
||||
def compare_with_transformers_integration(self):
|
||||
assert len(example[fields]) == len(tokenized_conversation[fields])
|
||||
assert example[fields] == tokenized_conversation[fields]
|
||||
|
||||
def compare_with_transformers_integration(self, tokenizer_llama2_7b):
|
||||
# this needs transformers >= v4.31.0
|
||||
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
@@ -234,49 +205,27 @@ If a question does not make any sense, or is not factually coherent, explain why
|
||||
generated_responses=answers,
|
||||
)
|
||||
# pylint: disable=W0212
|
||||
hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
|
||||
hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)
|
||||
|
||||
self.assertEqual(
|
||||
hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
|
||||
)
|
||||
assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]
|
||||
|
||||
|
||||
class OrpoTokenizationTest(unittest.TestCase):
|
||||
class OrpoTokenizationTest:
|
||||
"""test case for the ORPO tokenization"""
|
||||
|
||||
@enable_hf_offline
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
"casperhansen/mistral-7b-instruct-v0.1-awq"
|
||||
)
|
||||
tokenizer.add_special_tokens(
|
||||
{
|
||||
"eos_token": AddedToken(
|
||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||
)
|
||||
}
|
||||
)
|
||||
tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken(
|
||||
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
|
||||
),
|
||||
]
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
self.dataset = load_dataset(
|
||||
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
|
||||
).select([0])
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||
def test_orpo_integration(self):
|
||||
def test_orpo_integration(
|
||||
self,
|
||||
tokenizer_mistral_7b_instruct_chatml,
|
||||
dataset_argilla_ultrafeedback_binarized_preferences_cleaned,
|
||||
):
|
||||
ds = dataset_argilla_ultrafeedback_binarized_preferences_cleaned.select([0])
|
||||
strat = load(
|
||||
self.tokenizer,
|
||||
tokenizer_mistral_7b_instruct_chatml,
|
||||
DictDefault({"train_on_inputs": False}),
|
||||
DictDefault({"chat_template": "chatml"}),
|
||||
)
|
||||
res = strat.tokenize_prompt(self.dataset[0])
|
||||
res = strat.tokenize_prompt(ds[0])
|
||||
assert "rejected_input_ids" in res
|
||||
assert "rejected_labels" in res
|
||||
assert "input_ids" in res
|
||||
@@ -295,7 +244,3 @@ class OrpoTokenizationTest(unittest.TestCase):
|
||||
|
||||
assert res["prompt_attention_mask"][0] == 1
|
||||
assert res["prompt_attention_mask"][-1] == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -321,3 +321,48 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
|
||||
class TestOptimizerValidation(BaseValidation):
|
||||
"""
|
||||
Test muon optimizer validation
|
||||
"""
|
||||
|
||||
def test_muon_deepspeed(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"optimizer": "muon",
|
||||
"deepspeed": "deepspeed_configs/zero3.json",
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_muon_fsdp(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"optimizer": "muon",
|
||||
"fsdp": ["full_shard"],
|
||||
"fsdp_config": {
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
|
||||
validate_config(cfg)
|
||||
|
||||
Reference in New Issue
Block a user