Compare commits

3 Commits: sp-rl...feat/soap-

| Author | SHA1 | Date |
|---|---|---|
|  | 1a7f048c6b |  |
|  | 76d26366ad |  |
|  | 64fe284765 |  |
.github/workflows/base.yml (vendored, 6 changes)
@@ -40,12 +40,6 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
.github/workflows/main.yml (vendored, 4 changes)
@@ -25,12 +25,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras: vllm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -87,12 +87,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
.github/workflows/multi-gpu-e2e.yml (vendored, 3 changes)
@@ -42,7 +42,8 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras: vllm
# awaiting vllm#12721
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
.github/workflows/tests-nightly.yml (vendored, 23 changes)
@@ -33,15 +33,6 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4

- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2

- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -55,7 +46,7 @@ jobs:

- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu

- name: Update requirements.txt
run: |
@@ -67,7 +58,8 @@ jobs:

- name: Install dependencies
run: |
pip3 show torch
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
@@ -81,15 +73,10 @@ jobs:
run: |
axolotl --help

- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures

- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/

- name: cleanup pip cache
run: |
.github/workflows/tests.yml (vendored, 6 changes)
@@ -96,10 +96,6 @@ jobs:
run: |
axolotl --help

- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures

- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
@@ -260,7 +256,7 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -40,7 +40,6 @@ quartodoc:
- cli.preprocess
- cli.sweeps
- cli.utils
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- title: Trainers
@@ -244,7 +243,6 @@ website:
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd

- section: "Troubleshooting"
contents:
@@ -2,5 +2,4 @@
set -e

# only run one test at a time so as not to OOM the GPU
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
docs/cli.qmd (40 changes)
@@ -170,7 +170,7 @@ axolotl merge-sharded-fsdp-weights config.yml

### evaluate

Evaluates a model's performance (loss etc) on the train and eval datasets.
Evaluates a model's performance using metrics specified in the config.

```bash
# Basic evaluation
@@ -197,8 +197,6 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```

See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.

## Legacy CLI Usage

While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
@@ -237,7 +235,7 @@ Create a cloud config YAML with your Modal settings:

```yaml
# cloud_config.yml
provider: modal
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu_count: 1 # Number of GPUs to use
timeout: 86400 # Maximum runtime in seconds (24 hours)
branch: main # Git branch to use (optional)
@@ -250,7 +248,7 @@ volumes: # Persistent storage volumes
- name: axolotl-artifacts
mount: /workspace/artifacts

secrets: # Secrets to inject
env: # Environment variables
- WANDB_API_KEY
- HF_TOKEN
```
@@ -276,27 +274,15 @@ axolotl lm-eval config.yml --cloud cloud_config.yml

### Cloud Configuration Options

```yaml
provider: # compute provider, currently only `modal` is supported
gpu: # GPU type to use
gpu_count: # Number of GPUs (default: 1)
memory: # RAM in GB (default: 128)
timeout: # Maximum runtime in seconds
provider: # compute provider, currently only `modal` is supported
gpu: # GPU type to use
gpu_count: # Number of GPUs (default: 1)
memory: # RAM in GB (default: 128)
timeout: # Maximum runtime in seconds
timeout_preprocess: # Preprocessing timeout
branch: # Git branch to use
docker_tag: # Custom Docker image tag
volumes: # List of persistent storage volumes

# Environment variables to pass. Can be specified in two ways:
# 1. As a string: Will load the value from the host computer's environment variables
# 2. As a key-value pair: Will use the specified value directly
# Example:
# env:
# - CUSTOM_VAR # Loads from host's $CUSTOM_VAR
# - {CUSTOM_VAR: "value"} # Uses "value" directly
env:

# Secrets to inject. Same input format as `env` but for sensitive data.
secrets:
# - HF_TOKEN
# - WANDB_API_KEY
branch: # Git branch to use
docker_tag: # Custom Docker image tag
volumes: # List of persistent storage volumes
env: # Environment variables to pass
secrets: # Secrets to inject
```
@@ -238,10 +238,10 @@ simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo
trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
vllm_device: # Optional[str]. Device to use for VLLM.
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
vllm_dtype: # Optional[str]. Data type for VLLM.

beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
@@ -320,13 +320,9 @@ total_num_tokens:
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially.

# whether to concatenate samples during pretraining
pretraining_sample_concatenation:

curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning

# Use batch flattening for speedups when not using sample_packing
batch_flattening:

@@ -358,27 +354,7 @@ lora_target_modules:
# - down_proj
# - up_proj
lora_target_linear: # If true, will target all linear modules

# List[int] | int. # The layer indices to transform, otherwise, apply to all layers
# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform
peft_layers_to_transform:

# Optional[bool]. Whether to use DoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora
peft_use_dora:

# Optional[bool]. Whether to use RSLoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora
peft_use_rslora:

# Optional[list[tuple[int, int]]]. List of layer indices to replicate.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora
peft_layer_replication:

# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
# How to initialize LoRA weights. Default to True which is MS original implementation.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization
peft_init_lora_weights:
peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers

# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
@@ -611,31 +587,26 @@ max_grad_norm:
# currently only supported on Llama and Mistral
neftune_noise_alpha:

# Optional[bool]. Whether to bettertransformers
# Whether to bettertransformers
flash_optimum:

# Note: Only one of the following attention patches can be used at a time.
# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`.

# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation
# Optional[bool]. Whether to use scaled-dot-product attention
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:

# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Optional[str]. Resume from a specific checkpoint dir
# Resume from a specific checkpoint dir
resume_from_checkpoint:
# Optional[bool]. If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# Be careful with this being turned on between different models.
auto_resume_from_checkpoints: false

@@ -686,11 +657,7 @@ ddp_broadcast_buffers:
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1
sequence_parallel_degree:

# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:
docs/faq.qmd (12 changes)
@@ -35,22 +35,12 @@ description: Frequently asked questions

**Q: How to call Axolotl via custom python scripts?**

> A: Since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
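As an illustration of the answer above (an addition, not part of the diff), the simplest route from a custom script is to invoke the installed `axolotl` entry point, which setup.py maps to `axolotl.cli.main:main`; a minimal sketch:

```python
# Minimal sketch: run a training command from Python by shelling out to the CLI.
# "config.yml" is a placeholder path.
import subprocess

subprocess.run(["axolotl", "train", "config.yml"], check=True)
```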
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**

> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
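For illustration (an addition, not diff content), a minimal sketch of looking this up programmatically, assuming a Llama model and the standard `transformers` layout:

```python
# Inspect the class attribute without downloading any weights; the printed value
# is what goes into fsdp_transformer_layer_cls_to_wrap.
from transformers.models.llama.modeling_llama import LlamaForCausalLM

print(LlamaForCausalLM._no_split_modules)  # e.g. ['LlamaDecoderLayer']
```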
**Q: ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as pad_token**

> A: This is because the tokenizer does not have a padding token. Please add a padding token to the tokenizer via:

> ```yaml
> special_tokens:
>   # str. If you're not sure, set to same as `eos_token`.
>   pad_token: "..."
> ```

### Chat templates

**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
@@ -18,7 +18,6 @@ Axolotl supports several methods for multi-GPU training:

- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA

## DeepSpeed {#sec-deepspeed}
@@ -67,28 +66,6 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```

## Sequence parallelism {#sec-sequence-parallelism}

We support sequence parallelism (SP) via the
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.

First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.

Your Axolotl YAML config should contain the following lines:

```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism

# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```

See our [dedicated guide](sequence_parallelism.qmd) for more details.

### FSDP + QLoRA {#sec-fsdp-qlora}

For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
@@ -502,48 +502,9 @@ The input format is a simple JSON input with customizable fields based on the ab
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::

If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
using 4 GPUs - 2 for training, and 2 for vLLM:

::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
:::

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
host: 0.0.0.0
port: 8000
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
dtype: auto
# max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand

rl: grpo
trl:
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_server_timeout: 300
```

```bash
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm_serve grpo.yaml
```

Your `vLLM` instance will now attempt to spin up, and it's time to kick off training utilizing our remaining two GPUs. In another terminal, execute:

```bash
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```

#### Reward functions

GRPO uses custom reward functions and transformations. Please have them ready locally.

For example, to load OpenAI's GSM8K and use a random reward for completions:
For ex, to load OpenAI's GSM8K and use a random reward for completions:

```python
# rewards.py
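# (The hunk above cuts the example off here. As an illustrative sketch only, an
# assumption rather than the diff's verbatim text, a random-reward function
# matching the `reward_funcs: ["rewards.rand_reward_func"]` setting shown in the
# next hunk might look like this.)
import random

def rand_reward_func(completions, **kwargs) -> list[float]:
    """Return one random reward per completion."""
    return [random.random() for _ in completions]
```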
@@ -569,6 +530,8 @@ trl:
beta: 0.001
max_completion_length: 256
use_vllm: True
vllm_device: auto
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]
@@ -23,11 +23,8 @@ Use sequence parallelism when:
To enable sequence parallelism, add the following to your configuration file:

```yaml
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
```

The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
@@ -61,22 +58,16 @@ To use sequence parallelism, you need:
## Example

```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192

...

sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # SP requires flash attention
micro_batch_size: 1 # SP requires this is set to 1
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
heads_k_stride: 1

sequence_parallel_degree: 2 # Split each sequence into 4 parts
flash_attention: true # Required with sequence parallelism
...
```

This will train the Llama 3 8B model with 8192 context length, with each sequence split
into 4 subsequences of length 2048 across 4 GPUs.
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.

## Sample Packing with Sequence Parallelism

@@ -88,14 +79,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.

## Effect on Batch Size

First, note that sequence parallelism supports only the case where `micro_batch_size: 1`.

When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:

- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases

For example:
- With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 1, the global batch size decreases from 8 to 2
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
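As a quick sanity check of the arithmetic above (an illustrative addition, not part of the diff), a minimal sketch of the relationship:

```python
def effective_global_batch_size(
    num_gpus: int,
    micro_batch_size: int,
    gradient_accumulation_steps: int = 1,
    sequence_parallel_degree: int = 1,
) -> int:
    """GPUs in one SP group share the same batch, so the number of distinct
    batches per step shrinks by a factor of sequence_parallel_degree."""
    data_parallel_groups = num_gpus // sequence_parallel_degree
    return data_parallel_groups * micro_batch_size * gradient_accumulation_steps

# Matches the example above: 8 GPUs, micro_batch_size=1, SP degree 4 -> 2 batches.
assert effective_global_batch_size(8, 1) == 8
assert effective_global_batch_size(8, 1, sequence_parallel_degree=4) == 2
```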
@@ -5,9 +5,6 @@ tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true

load_in_8bit: false
load_in_4bit: true
strict: false
@@ -57,8 +54,6 @@ fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
@@ -7,9 +7,6 @@ skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true

chat_template: gemma3
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
@@ -51,8 +48,6 @@ fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
local_rank:
logging_steps: 1
flash_attention: true
@@ -82,6 +82,3 @@ deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

special_tokens:
pad_token: "<|end_of_text|>"
@@ -1,80 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
# optionally might have model_type or tokenizer_type
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/lora-out

test_value: true

sequence_len: 4096
sample_packing: true
sample_packing_sequentially: true
curriculum_sampling: true
eval_sample_packing: false
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.4
bitsandbytes==0.45.3
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
@@ -12,12 +12,12 @@ liger-kernel==0.5.5
packaging==23.2

peft==0.15.0
transformers==4.50.3
transformers==4.50.0
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.5.0
deepspeed==0.16.4
trl==0.16.0
trl==0.15.1

optimum==1.16.2
hf_transfer
setup.py (87 changes)
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def parse_requirements(extras_require_map):
|
||||
def parse_requirements():
|
||||
_install_requires = []
|
||||
_dependency_links = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
@@ -67,7 +67,6 @@ def parse_requirements(extras_require_map):
|
||||
if (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.29.post2")
|
||||
extras_require_map["vllm"] = ["vllm==0.8.1"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
@@ -87,7 +86,7 @@ def parse_requirements(extras_require_map):
|
||||
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
return _install_requires, _dependency_links, extras_require_map
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
def get_package_version():
|
||||
@@ -104,50 +103,7 @@ def get_package_version():
|
||||
return version_
|
||||
|
||||
|
||||
extras_require = {
|
||||
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||
"ring-flash-attn": [
|
||||
"flash-attn==2.7.4.post1",
|
||||
"ring-flash-attn>=0.1.4",
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.16.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
],
|
||||
"auto-gptq": [
|
||||
"auto-gptq==0.5.1",
|
||||
],
|
||||
"mlflow": [
|
||||
"mlflow",
|
||||
],
|
||||
"galore": [
|
||||
"galore_torch",
|
||||
],
|
||||
"apollo": [
|
||||
"apollo-torch",
|
||||
],
|
||||
"optimizers": [
|
||||
"galore_torch",
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.7.2",
|
||||
],
|
||||
}
|
||||
|
||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||
extras_require
|
||||
)
|
||||
install_requires, dependency_links = parse_requirements()
|
||||
|
||||
setup(
|
||||
version=get_package_version(),
|
||||
@@ -160,5 +116,40 @@ setup(
|
||||
"axolotl=axolotl.cli.main:main",
|
||||
],
|
||||
},
|
||||
extras_require=extras_require_build,
|
||||
extras_require={
|
||||
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.16.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
],
|
||||
"auto-gptq": [
|
||||
"auto-gptq==0.5.1",
|
||||
],
|
||||
"mlflow": [
|
||||
"mlflow",
|
||||
],
|
||||
"galore": [
|
||||
"galore_torch",
|
||||
],
|
||||
"apollo": [
|
||||
"apollo-torch",
|
||||
],
|
||||
"optimizers": [
|
||||
"galore_torch",
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.7.2",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -35,55 +35,6 @@ class TrainerCliArgs:
|
||||
num_processes: Optional[int] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VllmServeCliArgs:
|
||||
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
|
||||
|
||||
tensor_parallel_size: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of tensor parallel workers to use."},
|
||||
)
|
||||
host: str = field(
|
||||
default="0.0.0.0", # nosec B104
|
||||
metadata={"help": "Host address to run the server on."},
|
||||
)
|
||||
port: int = field(
|
||||
default=8000,
|
||||
metadata={"help": "Port to run the server on."},
|
||||
)
|
||||
gpu_memory_utilization: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
|
||||
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
|
||||
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
|
||||
"out-of-memory (OOM) errors during initialization."
|
||||
},
|
||||
)
|
||||
dtype: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
|
||||
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
||||
},
|
||||
)
|
||||
max_model_len: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
|
||||
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
|
||||
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
||||
},
|
||||
)
|
||||
enable_prefix_caching: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
|
||||
"hardware support this feature."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluateCliArgs:
|
||||
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
|
||||
|
||||
@@ -14,12 +14,7 @@ import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import axolotl
|
||||
from axolotl.cli.args import (
|
||||
EvaluateCliArgs,
|
||||
PreprocessCliArgs,
|
||||
TrainerCliArgs,
|
||||
VllmServeCliArgs,
|
||||
)
|
||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.cli.sweeps import generate_sweep_configs
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
@@ -28,7 +23,6 @@ from axolotl.cli.utils import (
|
||||
fetch_from_github,
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from axolotl.cli.vllm_serve import do_vllm_serve
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
@@ -322,14 +316,6 @@ def fetch(directory: str, dest: Optional[str]) -> None:
|
||||
fetch_from_github(f"{directory}/", dest)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@add_options_from_dataclass(VllmServeCliArgs)
|
||||
@filter_none_kwargs
|
||||
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
|
||||
do_vllm_serve(config, cli_args)
|
||||
|
||||
|
||||
cli.add_command(lm_eval)
|
||||
|
||||
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
"""
|
||||
CLI to start the vllm server for online RL
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from trl.scripts.vllm_serve import ScriptArguments
|
||||
from trl.scripts.vllm_serve import main as vllm_serve_main
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
|
||||
|
||||
def do_vllm_serve(
|
||||
config: Union[Path, str],
|
||||
cli_args: dict,
|
||||
):
|
||||
"""
|
||||
Starts the VLLM server for serving LLM models used for online RL
|
||||
|
||||
Args
|
||||
:param cfg: Parsed doct of the YAML config
|
||||
:param cli_args: dict of additional command-line arguments of type VllmServeCliArgs
|
||||
|
||||
Returns:
|
||||
process_id: the process id of the started VLLM server
|
||||
"""
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
host = cli_args.get("host") or cfg.vllm.host
|
||||
port = cli_args.get("port") or cfg.vllm.port
|
||||
gpu_memory_utilization = (
|
||||
cli_args.get("gpu_memory_utilization") or cfg.vllm.gpu_memory_utilization
|
||||
)
|
||||
dtype = cli_args.get("dtype") or cfg.vllm.dtype
|
||||
max_model_len = cli_args.get("max_model_len") or cfg.vllm.max_model_len
|
||||
enable_prefix_caching = (
|
||||
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
|
||||
)
|
||||
|
||||
vllm_script_args = ScriptArguments(
|
||||
model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
host=host,
|
||||
port=port,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
)
|
||||
vllm_serve_main(vllm_script_args)
|
||||
@@ -69,6 +69,7 @@ from axolotl.utils.callbacks import (
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
SaveModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
@@ -248,6 +249,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
callbacks.append(SaveModelCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -524,15 +526,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
and self.cfg.eval_steps
|
||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||
) or False
|
||||
|
||||
# handle ddp
|
||||
ddp_find_unused_parameters = None
|
||||
if self.cfg.ddp:
|
||||
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
|
||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||
ddp_find_unused_parameters
|
||||
False if self.cfg.ddp else None
|
||||
)
|
||||
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
report_to = []
|
||||
@@ -667,6 +663,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "soap":
|
||||
from axolotl.utils.optimizers.soap import SOAP
|
||||
|
||||
optimizer_cls = SOAP
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
@@ -941,6 +942,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
callbacks.append(SaveModelCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -1043,10 +1045,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rpo_alpha is not None:
|
||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||
|
||||
training_args_kwargs["sequence_parallel_degree"] = (
|
||||
self.cfg.sequence_parallel_degree
|
||||
)
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl == "simpo":
|
||||
@@ -1165,7 +1163,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
|
||||
dpo_trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
@@ -1183,3 +1180,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer.add_callback(callback)
|
||||
|
||||
return dpo_trainer
|
||||
|
||||
|
||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
HF Factory class for PPO Trainer
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
# build PPOConfig
|
||||
pass
|
||||
|
||||
@@ -3,16 +3,16 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
|
||||
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
|
||||
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
|
||||
from axolotl.core.trainers.relora import ReLoRATrainer
|
||||
from axolotl.core.trainers.trl import (
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .relora import ReLoRATrainer
|
||||
from .trl import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
AxolotlPPOTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
TRLPPOTrainer,
|
||||
)
|
||||
|
||||
@@ -12,8 +12,8 @@ from typing import Any, Literal
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datasets import Dataset
|
||||
from torch import nn
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
DataLoader,
|
||||
@@ -26,8 +26,11 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
from axolotl.core.trainers.mixins import (
|
||||
OptimizerMixin,
|
||||
SchedulerMixin,
|
||||
SequenceParallelMixin,
|
||||
)
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
@@ -37,7 +40,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
@@ -63,7 +66,9 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
self.sequence_parallel_handler = SequenceParallelHandler(self.args)
|
||||
# Initialize sequence parallelism if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._setup_sequence_parallel()
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
@@ -107,7 +112,6 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_size=batch_size,
|
||||
sequential=self.args.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
@@ -124,7 +128,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
|
||||
# Determine the base sampler first
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset)
|
||||
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
||||
elif self.args.curriculum_sampling:
|
||||
base_sampler = SequentialSampler(self.train_dataset)
|
||||
elif use_sample_packing:
|
||||
@@ -160,7 +164,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
|
||||
# Determine the base sampler
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset)
|
||||
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
||||
elif use_multipack:
|
||||
base_sampler = SequentialSampler(eval_dataset)
|
||||
else:
|
||||
@@ -232,10 +236,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
dataloader = self.accelerator.prepare_data_loader(dataloader)
|
||||
|
||||
return dataloader
|
||||
|
||||
return self.accelerator.prepare_data_loader(dataloader)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
"""Get dataloader for training"""
|
||||
@@ -344,57 +345,7 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
num_items_in_batch: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform training step for.
|
||||
inputs: Dictionary mapping of inputs.
|
||||
num_items_in_batch: The number of items in the batch.
|
||||
"""
|
||||
# Set up sequence parallelism for this step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal training step
|
||||
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: list[str] | None = None,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Perform a prediction step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform prediction step for.
|
||||
inputs: Dictionary mapping of inputs.
|
||||
prediction_loss_only: Whether to return only the loss.
|
||||
ignore_keys: Keys to ignore in the inputs.
|
||||
|
||||
Returns:
|
||||
Tuple of (loss, logits, labels).
|
||||
"""
|
||||
# Set up sequence parallelism for this prediction step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal prediction step
|
||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
@@ -638,3 +589,27 @@ class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
num_items_in_batch: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform training step for.
|
||||
inputs: Dictionary mapping.
|
||||
"""
|
||||
# Set up sequence parallelism for this step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal training step
|
||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||
|
||||
return loss
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
"""DPO Specific Strategy for training"""
|
||||
"""
|
||||
DPO Specific Strategy for training
|
||||
"""
|
||||
|
||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||
|
||||
|
||||
class DPOStrategy:
|
||||
"""Strategy for DPO training"""
|
||||
"""
|
||||
Strategy for DPO training
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Axolotl specific DPO args"""
|
||||
"""
|
||||
Axolotl specific DPO args
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -9,4 +11,6 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
@dataclass
|
||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""DPO config for DPO training"""
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
"""DPO trainer for axolotl"""
|
||||
"""
|
||||
DPO trainer for axolotl
|
||||
"""
|
||||
|
||||
import gc
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
@@ -10,8 +13,7 @@ from transformers import Trainer
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
from axolotl.core.trainers.mixins import SchedulerMixin
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
@@ -21,18 +23,18 @@ if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
||||
"""Extend the base DPOTrainer for axolotl helpers"""
|
||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "dpo"]
|
||||
|
||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
self.model_accepts_loss_kwargs = False
|
||||
self.sequence_parallel_handler = SequenceParallelHandler(args=self.args)
|
||||
|
||||
def create_optimizer(self):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -86,7 +88,7 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> dict:
|
||||
) -> Dict:
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
@@ -115,9 +117,10 @@ class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any | None],
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch=None,
|
||||
) -> torch.Tensor:
|
||||
self.sequence_parallel_handler.prepare_for_training_step(self, inputs)
|
||||
|
||||
return super().training_step(model, inputs, num_items_in_batch)
|
||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
|
||||
@@ -40,15 +40,18 @@ class GRPOStrategy:
|
||||
|
||||
if trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
|
||||
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
|
||||
if trl.vllm_server_timeout:
|
||||
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
|
||||
if trl.vllm_guided_decoding_regex:
|
||||
grpo_args_kwargs["vllm_guided_decoding_regex"] = (
|
||||
trl.vllm_guided_decoding_regex
|
||||
grpo_args_kwargs["vllm_device"] = (
|
||||
trl.vllm_device if trl.vllm_device else "auto"
|
||||
)
|
||||
|
||||
if trl.vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
trl.vllm_gpu_memory_utilization
|
||||
)
|
||||
|
||||
if trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
|
||||
|
||||
if trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = trl.num_generations
|
||||
|
||||
@@ -67,25 +70,6 @@ class GRPOStrategy:
|
||||
if trl.reward_weights:
|
||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||
|
||||
if trl.scale_rewards is not None:
|
||||
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
|
||||
|
||||
if trl.temperature is not None:
|
||||
grpo_args_kwargs["temperature"] = trl.temperature
|
||||
if trl.top_p is not None:
|
||||
grpo_args_kwargs["top_p"] = trl.top_p
|
||||
if trl.top_k is not None:
|
||||
grpo_args_kwargs["top_k"] = trl.top_k
|
||||
if trl.min_p is not None:
|
||||
grpo_args_kwargs["min_p"] = trl.min_p
|
||||
if trl.repetition_penalty is not None:
|
||||
grpo_args_kwargs["repetition_penalty"] = trl.repetition_penalty
|
||||
|
||||
if trl.num_iterations is not None:
|
||||
grpo_args_kwargs["num_iterations"] = trl.num_iterations
|
||||
if trl.epsilon is not None:
|
||||
grpo_args_kwargs["epsilon"] = trl.epsilon
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,65 +1,109 @@
|
||||
"""Axolotl GRPO trainer"""
|
||||
"""
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
|
||||
from contextlib import nullcontext
|
||||
from accelerate.utils import is_peft_model
|
||||
from accelerate.utils.other import is_compiled_module
|
||||
from transformers import PreTrainedModel
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
from trl.models import unwrap_model_for_generation
|
||||
|
||||
from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||
from trl import GRPOTrainer
|
||||
from trl.extras.profiling import profiling_decorator
|
||||
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
from axolotl.core.trainers.base import SchedulerMixin
|
||||
|
||||
|
||||
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||
# mypy: ignore-errors
|
||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
"""
|
||||
Extend the base GRPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
@profiling_decorator
|
||||
def _move_model_to_vllm(self):
|
||||
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
||||
gather_if_zero3 = (
|
||||
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# pylint: disable=access-member-before-definition
|
||||
# Enable gradient checkpointing if requested
|
||||
if kwargs["args"].gradient_checkpointing:
|
||||
# Ensure use_cache is disabled
|
||||
if hasattr(self.model, "config"):
|
||||
self.model.config.use_cache = False
|
||||
|
||||
# Enable gradient checkpointing on the base model for PEFT
|
||||
if is_peft_model(self.model) and hasattr(
|
||||
self.model.base_model, "gradient_checkpointing_enable"
|
||||
):
|
||||
self.model.base_model.gradient_checkpointing_enable()
|
||||
# Enable gradient checkpointing for non-PEFT models
|
||||
elif hasattr(self.model, "gradient_checkpointing_enable"):
|
||||
self.model.gradient_checkpointing_enable()
|
||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||
# pylint: enable=access-member-before-definition
|
||||
|
||||
def _enable_gradient_checkpointing(
|
||||
self, model: PreTrainedModel, args: GRPOConfig
|
||||
) -> PreTrainedModel:
|
||||
"""Enables gradient checkpointing for the model."""
|
||||
# pylint: disable=unused-argument,redefined-builtin
|
||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
||||
use_reentrant = (
|
||||
"use_reentrant" not in gradient_checkpointing_kwargs
|
||||
or gradient_checkpointing_kwargs["use_reentrant"]
|
||||
)
|
||||
|
||||
if is_peft_model(self.model):
|
||||
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
|
||||
# adapters in a sharded manner is not supported.
|
||||
with gather_if_zero3(list(self.model.parameters())):
|
||||
self.model.merge_adapter()
|
||||
if use_reentrant:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
|
||||
# Update vLLM weights while parameters are gathered
|
||||
for name, param in self.model.named_parameters():
|
||||
# When using PEFT, we need to recover the original parameter name and discard some parameters
|
||||
name = (
|
||||
name.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", "")
|
||||
)
|
||||
if self.model.prefix in name:
|
||||
continue
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
if "original_module" in name:
|
||||
continue
|
||||
name = name.replace("modules_to_save.default.", "")
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
model.get_input_embeddings().register_forward_hook(
|
||||
make_inputs_require_grad
|
||||
)
|
||||
|
||||
# Unmerge adapters while parameters are still gathered
|
||||
self.model.unmerge_adapter()
|
||||
# Parameters will automatically be repartitioned when exiting the context
|
||||
else:
|
||||
# For non-PEFT models, simply gather and update each parameter individually.
|
||||
for name, param in self.model.named_parameters():
|
||||
with gather_if_zero3([param]):
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
return model
|
||||
# pylint: enable=unused-argument,redefined-builtin
|
||||
|
||||
# Reset cache on main process
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.reset_prefix_cache()
|
||||
def _move_model_to_vllm(self):
|
||||
with unwrap_model_for_generation(
|
||||
self.model,
|
||||
self.accelerator,
|
||||
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
||||
) as unwrapped_model:
|
||||
if is_compiled_module(unwrapped_model):
|
||||
unwrapped_model = (
|
||||
unwrapped_model._orig_mod # pylint: disable=protected-access
|
||||
)
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.merge_adapter()
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
# Remove base_model and base_layer prefixes
|
||||
state_dict = {
|
||||
k.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", ""): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
# Remove values with adapter prefix (example: "_lora")
|
||||
state_dict = {
|
||||
k: v
|
||||
for k, v in state_dict.items()
|
||||
if unwrapped_model.prefix not in k
|
||||
}
|
||||
# For modules_to_save, remove the prefix and discard the original module
|
||||
state_dict = {
|
||||
k.replace("modules_to_save.default.", ""): v
|
||||
for k, v in state_dict.items()
|
||||
if "original_module" not in k
|
||||
}
|
||||
else:
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if self.accelerator.is_main_process:
|
||||
llm_model = (
|
||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
)
|
||||
llm_model.load_weights(state_dict.items())
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.unmerge_adapter()
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Init for trainer handlers"""
|
||||
|
||||
from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler
|
||||
@@ -1,123 +0,0 @@
|
||||
"""Handler class for sequence parallel trainer logic"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DistributedSampler
|
||||
|
||||
|
||||
class SequenceParallelHandler:
|
||||
"""
|
||||
Handler class that encapsulates sequence parallelism functionality.
|
||||
This replaces the SequenceParallelMixin with a composition-based approach.
|
||||
"""
|
||||
|
||||
def __init__(self, args=None):
|
||||
"""
|
||||
Initialize the sequence parallel handler.
|
||||
|
||||
Args:
|
||||
args: The arguments object containing sequence parallelism settings.
|
||||
"""
|
||||
self.args = args
|
||||
self.ring_attn_group = None
|
||||
|
||||
# Set up sequence parallelism if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._setup_sequence_parallel()
|
||||
|
||||
def _setup_sequence_parallel(self):
|
||||
"""Set up sequence parallelism environment."""
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
self.update_ring_flash_attn_params = update_ring_flash_attn_params
|
||||
self.ring_attn_group = get_ring_attn_group()
|
||||
|
||||
def create_sequence_parallel_sampler(
|
||||
self,
|
||||
dataset,
|
||||
shuffle=True,
|
||||
is_eval=False,
|
||||
):
|
||||
"""
|
||||
Helper method to create sampler for sequence parallelism (SP).
|
||||
|
||||
Args:
|
||||
dataset: Dataset to sample from.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||
|
||||
Returns:
|
||||
Distributed sampler.
|
||||
"""
|
||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||
|
||||
return DistributedSampler(
|
||||
dataset,
|
||||
num_replicas=num_sp_groups,
|
||||
rank=sp_group_id,
|
||||
seed=self.args.seed if shuffle else None,
|
||||
shuffle=shuffle,
|
||||
drop_last=not is_eval,
|
||||
)
|
||||
|
||||
def _get_train_sampler(self, dataset):
|
||||
"""
|
||||
Get a training sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
dataset: The training dataset.
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self.create_sequence_parallel_sampler(
|
||||
dataset,
|
||||
shuffle=not self.args.curriculum_sampling,
|
||||
)
|
||||
|
||||
def _get_eval_sampler(self, eval_dataset):
|
||||
"""
|
||||
Get an evaluation sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
eval_dataset: The evaluation dataset.
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self.create_sequence_parallel_sampler(
|
||||
eval_dataset, shuffle=False, is_eval=True
|
||||
)
|
||||
|
||||
def _update_ring_flash_attn_params(self, inputs):
|
||||
"""
|
||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||
the substituted ring_flash_attn.
|
||||
|
||||
Args:
|
||||
inputs: Current batch of inputs.
|
||||
"""
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = inputs["input_ids"].shape[0]
|
||||
seq_len = inputs["input_ids"].shape[1]
|
||||
packed_seq_lens = [seq_len] * batch_size
|
||||
|
||||
# Calculate the full sequence length across all GPUs in this SP group
|
||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||
|
||||
cu_seqlens = torch.cumsum(
|
||||
torch.tensor(
|
||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||
),
|
||||
dim=-1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(
|
||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||
)
|
||||
|
||||
self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
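For readers following this hunk, a minimal standalone sketch of the cu_seqlens construction used above, with purely illustrative values (batch_size=2, per-rank seq_len=512, sequence_parallel_degree=4); it runs on CPU and is not part of the diff:

import torch
import torch.nn.functional as F

batch_size, seq_len, sp_degree = 2, 512, 4     # illustrative values only
packed_seq_lens = [seq_len] * batch_size       # [512, 512]
total_seq_len = seq_len * sp_degree            # 2048, full sequence length across the SP group

cu_seqlens = torch.cumsum(
    torch.tensor(packed_seq_lens, dtype=torch.int32), dim=-1, dtype=torch.int32
)                                              # tensor([512, 1024])
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
print(cu_seqlens)                              # tensor([0, 512, 1024, 2048], dtype=torch.int32)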
|
||||
@@ -3,12 +3,6 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
|
||||
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
class TrainerMixins(
|
||||
OptimizerMixin, RngLoaderMixin, SchedulerMixin
|
||||
):
|
||||
"""Stub class combining all mixins for Axolotl trainers."""
|
||||
from .optimizer import OptimizerMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
from .sequence_parallel import SequenceParallelMixin
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
"""
|
||||
Temporary fix/override for bug in resume from checkpoint
|
||||
|
||||
See https://github.com/huggingface/transformers/pull/37162
|
||||
|
||||
TODO: Remove when upstream added PR to release
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Trainer, is_torch_npu_available
|
||||
from transformers.trainer import safe_globals
|
||||
from transformers.trainer_pt_utils import set_rng_state_for_device
|
||||
from transformers.training_args import ParallelMode
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RngLoaderMixin(Trainer):
|
||||
"""Mixin for method override to load RNG states from a checkpoint"""
|
||||
|
||||
def _load_rng_state(self, checkpoint):
|
||||
# Load RNG states from `checkpoint`
|
||||
if checkpoint is None:
|
||||
return
|
||||
|
||||
if self.args.world_size > 1:
|
||||
process_index = self.args.process_index
|
||||
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
|
||||
if not os.path.isfile(rng_file):
|
||||
LOG.info(
|
||||
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
|
||||
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
|
||||
)
|
||||
return
|
||||
else:
|
||||
rng_file = os.path.join(checkpoint, "rng_state.pth")
|
||||
if not os.path.isfile(rng_file):
|
||||
LOG.info(
|
||||
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
|
||||
"fashion, reproducibility is not guaranteed."
|
||||
)
|
||||
return
|
||||
|
||||
# Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,
|
||||
# which requires allowlisted classes when loading with weights_only=True.
|
||||
with safe_globals():
|
||||
checkpoint_rng_state = torch.load(rng_file) # nosec B614
|
||||
random.setstate(checkpoint_rng_state["python"])
|
||||
np.random.set_state(checkpoint_rng_state["numpy"])
|
||||
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
|
||||
|
||||
is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
if torch.cuda.is_available():
|
||||
set_rng_state_for_device(
|
||||
"CUDA", torch.cuda, checkpoint_rng_state, is_distributed
|
||||
)
|
||||
if is_torch_npu_available():
|
||||
set_rng_state_for_device(
|
||||
"NPU", torch.npu, checkpoint_rng_state, is_distributed
|
||||
)
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||
# TODO(Dan): remove
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
@@ -71,12 +70,12 @@ class SequenceParallelMixin:
|
||||
drop_last=not is_eval,
|
||||
)
|
||||
|
||||
def _get_train_sampler(self, dataset) -> Sampler | None:
|
||||
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
||||
"""
|
||||
Get a training sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
dataset: The training dataset.
|
||||
dataset: The training dataset
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
@@ -86,7 +85,7 @@ class SequenceParallelMixin:
|
||||
shuffle=not self.args.curriculum_sampling,
|
||||
)
|
||||
|
||||
def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||
"""
|
||||
Get an evaluation sampler configured for sequence parallelism.
|
||||
|
||||
|
||||
@@ -13,10 +13,10 @@ from trl import (
|
||||
RewardTrainer,
|
||||
)
|
||||
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
|
||||
class TRLPPOTrainer(PPOTrainer):
|
||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||
|
||||
tag_names = ["axolotl", "ppo"]
|
||||
@@ -74,8 +74,10 @@ class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
|
||||
)
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
|
||||
"""Extend the base ORPOTrainer for axolotl helpers"""
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
@@ -152,14 +154,18 @@ class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
|
||||
"""Extend the base KTOTrainer for axolotl helpers"""
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
|
||||
"""Extend the base CPOTrainer for axolotl helpers"""
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
@@ -238,13 +244,17 @@ class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
|
||||
"""Extend the base RewardTrainer for axolotl helpers"""
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
|
||||
"""Extend the base trl.PRMTrainer for axolotl helpers"""
|
||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "prm"]
|
||||
|
||||
@@ -12,7 +12,9 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""Mixin class for the Axolotl training args."""
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
@@ -32,12 +34,6 @@ class AxolotlTrainingMixins:
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
sample_packing_sequentially: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||
},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
|
||||
@@ -15,7 +15,6 @@ from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
@@ -160,6 +159,4 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
|
||||
del model
|
||||
del tokenizer
|
||||
|
||||
cleanup_distributed()
|
||||
|
||||
return all_metrics
|
||||
|
||||
@@ -6,22 +6,11 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from accelerate.logging import get_logger
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
try:
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
except ImportError:
|
||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||
pass
|
||||
|
||||
|
||||
configure_logging()
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -43,133 +32,19 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
Setter for ring attention group on this rank.
|
||||
|
||||
Args:
|
||||
ring_attn_group: Process group for ring attention.
|
||||
Process group for ring attention.
|
||||
"""
|
||||
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
def patch_flash_attention_for_sequential_batch(sequence_parallel_degree: int):
|
||||
"""
|
||||
Patch flash attention a second time to handle batched data. This is a hack to
|
||||
accommodate certain RL trainers which batch data even when `micro_batch_size: 1` is
|
||||
specified in the Axolotl config.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
"""
|
||||
# Store the original flash attention function
|
||||
original_flash_attention = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
|
||||
|
||||
def sequential_batch_flash_attention(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
dropout: float = 0.0,
|
||||
scaling: float | None = None,
|
||||
sliding_window: int | None = None,
|
||||
softcap: float | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
# Check if we have a batch dimension > 1
|
||||
batch_size = query.shape[0]
|
||||
|
||||
if batch_size <= 1:
|
||||
return original_flash_attention(
|
||||
module,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attention_mask,
|
||||
dropout,
|
||||
scaling,
|
||||
sliding_window,
|
||||
softcap,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Process each item in the batch separately
|
||||
outputs = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Extract single batch item
|
||||
q_item = query[i:i+1]
|
||||
k_item = key[i:i+1]
|
||||
v_item = value[i:i+1]
|
||||
|
||||
# Handle attention mask - it might be None or have different shapes
|
||||
mask_item = None
|
||||
if attention_mask is not None:
|
||||
# The mask could have different formats depending on implementation
|
||||
if attention_mask.dim() >= 3 and attention_mask.shape[0] == batch_size:
|
||||
mask_item = attention_mask[i:i+1]
|
||||
else:
|
||||
# For broadcast masks that don't have a batch dimension
|
||||
mask_item = attention_mask
|
||||
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = q_item.shape[0]
|
||||
seq_len = q_item.shape[2]
|
||||
packed_seq_lens = [seq_len] * batch_size
|
||||
|
||||
# Calculate the full sequence length across all GPUs in this SP group
|
||||
total_seq_len = seq_len * sequence_parallel_degree
|
||||
|
||||
cu_seqlens = torch.cumsum(
|
||||
torch.tensor(
|
||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||
),
|
||||
dim=-1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(
|
||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||
)
|
||||
|
||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||
|
||||
# Call the original function for a single batch item
|
||||
output, _ = original_flash_attention(
|
||||
module,
|
||||
q_item,
|
||||
k_item,
|
||||
v_item,
|
||||
mask_item,
|
||||
dropout,
|
||||
scaling,
|
||||
sliding_window,
|
||||
softcap,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
outputs.append(output)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
# Concatenate results along batch dimension
|
||||
concatenated_output = torch.cat(outputs, dim=0)
|
||||
return concatenated_output, None
|
||||
|
||||
# Replace the original function with our sequential version
|
||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = sequential_batch_flash_attention
|
||||
|
||||
|
||||
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
||||
def register_ring_attn(sequence_parallel_degree: int):
|
||||
"""
|
||||
Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed
|
||||
through to `ring_flash_attn.substitute_hf_flash_attn`.
|
||||
"""
|
||||
if get_ring_attn_group() is not None:
|
||||
LOG.info("Ring attention already registered, exiting early...")
|
||||
return
|
||||
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
@@ -209,12 +84,6 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
if heads_k_stride is None:
|
||||
heads_k_stride = 1
|
||||
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||
)
|
||||
patch_flash_attention_for_sequential_batch(sequence_parallel_degree)
|
||||
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
||||
|
||||
@@ -22,7 +22,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"phi3",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"gemma3",
|
||||
"gemma3_text",
|
||||
"cohere",
|
||||
"cohere2",
|
||||
|
||||
@@ -27,7 +27,6 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
@@ -158,8 +157,6 @@ def setup_signal_handler(
|
||||
_model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
cleanup_distributed()
|
||||
sys.exit(0)
|
||||
|
||||
_model_weakref = weakref.ref(model)
|
||||
@@ -481,7 +478,7 @@ def train(
|
||||
Returns:
|
||||
Tuple of (model, tokenizer) after training
|
||||
"""
|
||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||
# Setup model, tokenizer, (causal or RLHF) trainer etc.
|
||||
(
|
||||
trainer,
|
||||
model,
|
||||
@@ -490,26 +487,34 @@ def train(
|
||||
processor,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
# Determine if we need to resume from a checkpoint
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
|
||||
# Configuration for saving
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
handle_untrained_tokens_fix(
|
||||
cfg, model, tokenizer, train_dataset, safe_serialization
|
||||
)
|
||||
|
||||
# Additional setup
|
||||
# Save initial configs
|
||||
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
|
||||
|
||||
# Set up signal handler for graceful termination
|
||||
setup_signal_handler(cfg, model, safe_serialization)
|
||||
|
||||
# Set up badges and config info for model card
|
||||
setup_model_card(cfg)
|
||||
|
||||
# Execute the training
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||
|
||||
# Save the trained model and cleanup
|
||||
# Save the trained model
|
||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||
|
||||
# Create model card
|
||||
create_model_card(cfg, trainer)
|
||||
if not cfg.use_ray:
|
||||
cleanup_distributed()
|
||||
|
||||
return model, tokenizer, trainer
|
||||
|
||||
@@ -816,6 +816,27 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
return control
|
||||
|
||||
|
||||
class SaveModelCallback(TrainerCallback):
|
||||
"""Callback to save model on train end"""
|
||||
|
||||
def on_step_end( # pylint: disable=unused-argument
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
# Flag a save once the final training step is reached
|
||||
if state.global_step >= state.max_steps:
|
||||
control.should_save = True
|
||||
|
||||
def on_train_end( # pylint: disable=unused-argument
|
||||
self, args, state, control, **kwargs
|
||||
):
|
||||
control.should_save = True
|
||||
return control
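A short sketch of how such a callback is typically wired up, assuming a standard transformers Trainer; model, args, and train_dataset are hypothetical and not shown in this diff:

from transformers import Trainer

# `model`, `args`, and `train_dataset` are assumed to exist elsewhere.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.add_callback(SaveModelCallback())  # SaveModelCallback as defined in the hunk above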
|
||||
|
||||
|
||||
class GCCallback(TrainerCallback):
|
||||
"""Callback to garbage collect torch cache"""
|
||||
|
||||
|
||||
@@ -112,7 +112,6 @@ class DataCollatorForSeq2Seq:
|
||||
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
has_attn_mask = "attention_mask" in features[0].keys()
|
||||
labels = None
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
@@ -165,8 +164,6 @@ class DataCollatorForSeq2Seq:
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
if not has_attn_mask:
|
||||
del features["attention_mask"]
|
||||
|
||||
# prepare decoder_input_ids
|
||||
if (
|
||||
|
||||
@@ -238,8 +238,7 @@ def load_dataset_w_config(
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif config_dataset.data_files:
|
||||
fp: str | list[str] | None = None
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
|
||||
@@ -71,8 +71,8 @@ def barrier():
|
||||
|
||||
def is_main_process():
|
||||
"""
|
||||
Check if the current process is the main process. If not in distributed mode,
|
||||
always return `True`.
|
||||
Check if the current process is the main process.
|
||||
If not in distributed mode, always return True.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return True
|
||||
@@ -87,18 +87,6 @@ def get_world_size():
|
||||
return int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
|
||||
def cleanup_distributed():
|
||||
"""
|
||||
Destroy process group if torch distributed is initialized. Called in training early
|
||||
termination or when training successfully completes.
|
||||
"""
|
||||
# Ensure that all operations are completed before destroying the process group
|
||||
torch.cuda.synchronize()
|
||||
# Destroy the process group
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_only():
|
||||
"""
|
||||
|
||||
@@ -609,10 +609,7 @@ class ModelLoader:
|
||||
# Initialize ring attn for sequence parallelism. This must be done after
|
||||
# model init but before the first forward pass, since it modifies flash
|
||||
# attn to use ring comm for SP training across multiple GPUs.
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
|
||||
heads_k_stride=self.cfg.heads_k_stride,
|
||||
)
|
||||
register_ring_attn(self.cfg.sequence_parallel_degree)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
if hasattr(self.model_config, "model_type"):
|
||||
@@ -1351,7 +1348,9 @@ def load_model(
|
||||
reference_model: bool = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||
"""Load a model for a given configuration and tokenizer."""
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
model_loader = ModelLoader(
|
||||
cfg,
|
||||
tokenizer,
|
||||
@@ -1360,16 +1359,12 @@ def load_model(
|
||||
reference_model=reference_model,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return model_loader.load_model()
|
||||
|
||||
|
||||
def load_adapter(
|
||||
model: PreTrainedModel,
|
||||
cfg: DictDefault,
|
||||
adapter: str | None,
|
||||
inference: bool = False,
|
||||
) -> tuple[PreTrainedModel, PeftConfig | None]:
|
||||
def load_adapter(model, cfg, adapter, inference=False):
|
||||
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
|
||||
if adapter is None:
|
||||
return model, None
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
@@ -1382,9 +1377,8 @@ def load_adapter(
|
||||
raise NotImplementedError(f"{adapter} peft adapter not available")
|
||||
|
||||
|
||||
def load_llama_adapter(
|
||||
model: PreTrainedModel, cfg: DictDefault
|
||||
) -> tuple[PreTrainedModel, PeftConfig | None]:
|
||||
def load_llama_adapter(model, cfg):
|
||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
from peft import AdaptionPromptConfig, get_peft_model
|
||||
|
||||
peft_config = AdaptionPromptConfig(
|
||||
@@ -1408,7 +1402,7 @@ def load_llama_adapter(
|
||||
return model, peft_config
|
||||
|
||||
|
||||
def find_all_linear_names(model: PreTrainedModel):
|
||||
def find_all_linear_names(model):
|
||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
|
||||
21
src/axolotl/utils/optimizers/soap/LICENSE
Normal file
21
src/axolotl/utils/optimizers/soap/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Nikhil Vyas
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
495
src/axolotl/utils/optimizers/soap/__init__.py
Normal file
495
src/axolotl/utils/optimizers/soap/__init__.py
Normal file
@@ -0,0 +1,495 @@
|
||||
# pylint: skip-file
|
||||
# Copied from https://github.com/nikhilvyas/SOAP
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
# Parts of the code are modifications of Pytorch's AdamW optimizer
|
||||
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py
|
||||
|
||||
|
||||
class SOAP(optim.Optimizer):
|
||||
"""
|
||||
Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).
|
||||
|
||||
Parameters:
|
||||
params (`Iterable[nn.parameter.Parameter]`):
|
||||
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
||||
lr (`float`, *optional*, defaults to 0.003):
|
||||
The learning rate to use.
|
||||
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
|
||||
Adam's betas parameters (b1, b2).
|
||||
shampoo_beta (`float`, *optional*, defaults to -1):
|
||||
If >= 0, use this beta for the preconditioner (L and R in paper, state["GG"] below) moving average instead of betas[1].
|
||||
eps (`float`, *optional*, defaults to 1e-08):
|
||||
Adam's epsilon for numerical stability.
|
||||
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
|
||||
precondition_frequency (`int`, *optional*, defaults to 10):
|
||||
How often to update the preconditioner.
|
||||
max_precond_dim (`int`, *optional*, defaults to 10000):
|
||||
Maximum dimension of the preconditioner.
|
||||
Set to 10000, so that we exclude most common vocab sizes while including layers.
|
||||
merge_dims (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to merge dimensions of the preconditioner.
|
||||
precondition_1d (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to precondition 1D gradients.
|
||||
normalize_grads (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to normalize gradients per layer.
|
||||
Helps at large precondition_frequency (~100 in our experiments),
|
||||
but hurts performance at small precondition_frequency (~10 in our experiments).
|
||||
data_format (`str`, *optional*, defaults to `channels_first`):
|
||||
Data format of the input for convolutional layers.
|
||||
Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
|
||||
correct_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to use bias correction in Adam.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr: float = 3e-3,
|
||||
betas=(0.95, 0.95),
|
||||
shampoo_beta: float = -1,
|
||||
eps: float = 1e-8,
|
||||
weight_decay: float = 0.01,
|
||||
precondition_frequency: int = 10,
|
||||
max_precond_dim: int = 10000, #
|
||||
merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
|
||||
precondition_1d: bool = False,
|
||||
normalize_grads: bool = False,
|
||||
data_format: str = "channels_first",
|
||||
correct_bias: bool = True,
|
||||
):
|
||||
defaults = {
|
||||
"lr": lr,
|
||||
"betas": betas,
|
||||
"shampoo_beta": shampoo_beta,
|
||||
"eps": eps,
|
||||
"weight_decay": weight_decay,
|
||||
"precondition_frequency": precondition_frequency,
|
||||
"max_precond_dim": max_precond_dim,
|
||||
"merge_dims": merge_dims,
|
||||
"precondition_1d": precondition_1d,
|
||||
"normalize_grads": normalize_grads,
|
||||
"correct_bias": correct_bias,
|
||||
}
|
||||
super().__init__(params, defaults)
|
||||
self._data_format = data_format
|
||||
|
||||
def merge_dims(self, grad, max_precond_dim):
|
||||
"""
|
||||
Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
|
||||
"""
|
||||
assert self._data_format in ["channels_first", "channels_last"]
|
||||
if self._data_format == "channels_last" and grad.dim() == 4:
|
||||
grad = grad.permute(0, 3, 1, 2)
|
||||
shape = grad.shape
|
||||
new_shape = []
|
||||
|
||||
curr_shape = 1
|
||||
for sh in shape:
|
||||
temp_shape = curr_shape * sh
|
||||
if temp_shape > max_precond_dim:
|
||||
if curr_shape > 1:
|
||||
new_shape.append(curr_shape)
|
||||
curr_shape = sh
|
||||
else:
|
||||
new_shape.append(sh)
|
||||
curr_shape = 1
|
||||
else:
|
||||
curr_shape = temp_shape
|
||||
|
||||
if curr_shape > 1 or len(new_shape) == 0:
|
||||
new_shape.append(curr_shape)
|
||||
|
||||
new_grad = grad.reshape(new_shape)
|
||||
return new_grad
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""
|
||||
Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
|
||||
"""
|
||||
if closure is None:
|
||||
loss = None
|
||||
else:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
if "step" not in state:
|
||||
state["step"] = 0
|
||||
|
||||
# State initialization
|
||||
if "exp_avg" not in state:
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(grad)
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(grad)
|
||||
|
||||
if "Q" not in state:
|
||||
self.init_preconditioner(
|
||||
grad,
|
||||
state,
|
||||
precondition_frequency=group["precondition_frequency"],
|
||||
precondition_1d=group["precondition_1d"],
|
||||
shampoo_beta=(
|
||||
group["shampoo_beta"]
|
||||
if group["shampoo_beta"] >= 0
|
||||
else group["betas"][1]
|
||||
),
|
||||
max_precond_dim=group["max_precond_dim"],
|
||||
merge_dims=group["merge_dims"],
|
||||
)
|
||||
self.update_preconditioner(
|
||||
grad,
|
||||
state,
|
||||
max_precond_dim=group["max_precond_dim"],
|
||||
merge_dims=group["merge_dims"],
|
||||
precondition_1d=group["precondition_1d"],
|
||||
)
|
||||
continue # first step is skipped so that we never use the current gradients in the projection.
|
||||
|
||||
# Projecting gradients to the eigenbases of Shampoo's preconditioner
|
||||
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
||||
grad_projected = self.project(
|
||||
grad,
|
||||
state,
|
||||
merge_dims=group["merge_dims"],
|
||||
max_precond_dim=group["max_precond_dim"],
|
||||
)
|
||||
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
state["step"] += 1
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# In-place operations to update the averages at the same time
|
||||
exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
|
||||
exp_avg_sq.mul_(beta2).add_(
|
||||
grad_projected.square(), alpha=(1.0 - beta2)
|
||||
)
|
||||
|
||||
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||
|
||||
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
||||
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
||||
# exp_avg_projected = self.project(
|
||||
# exp_avg,
|
||||
# state,
|
||||
# merge_dims=group["merge_dims"],
|
||||
# max_precond_dim=group["max_precond_dim"],
|
||||
# )
|
||||
exp_avg_projected = exp_avg
|
||||
|
||||
step_size = group["lr"]
|
||||
if group["correct_bias"]:
|
||||
bias_correction1 = 1.0 - beta1 ** (state["step"])
|
||||
bias_correction2 = 1.0 - beta2 ** (state["step"])
|
||||
step_size = step_size * (bias_correction2**0.5) / bias_correction1
|
||||
|
||||
# Projecting back the preconditioned (by Adam) exponential moving average of gradients
|
||||
# to the original space
|
||||
norm_grad = self.project_back(
|
||||
exp_avg_projected / denom,
|
||||
state,
|
||||
merge_dims=group["merge_dims"],
|
||||
max_precond_dim=group["max_precond_dim"],
|
||||
)
|
||||
|
||||
if group["normalize_grads"]:
|
||||
norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2) ** 0.5)
|
||||
|
||||
p.add_(norm_grad, alpha=-step_size)
|
||||
|
||||
# From AdamW code: Just adding the square of the weights to the loss function is *not*
|
||||
# the correct way of using L2 regularization/weight decay with Adam,
|
||||
# since that will interact with the m and v parameters in strange ways.
|
||||
#
|
||||
# Instead we want to decay the weights in a manner that doesn't interact
|
||||
# with the m/v parameters. This is equivalent to adding the square
|
||||
# of the weights to the loss with plain (non-momentum) SGD.
|
||||
# Add weight decay at the end (fixed version)
|
||||
if group["weight_decay"] > 0.0:
|
||||
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
|
||||
|
||||
# Update is done after the gradient step to avoid using current gradients in the projection.
|
||||
self.update_preconditioner(
|
||||
grad,
|
||||
state,
|
||||
max_precond_dim=group["max_precond_dim"],
|
||||
merge_dims=group["merge_dims"],
|
||||
precondition_1d=group["precondition_1d"],
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def init_preconditioner(
|
||||
self,
|
||||
grad,
|
||||
state,
|
||||
precondition_frequency=10,
|
||||
shampoo_beta=0.95,
|
||||
max_precond_dim=10000,
|
||||
precondition_1d=False,
|
||||
merge_dims=False,
|
||||
):
|
||||
"""
|
||||
Initializes the preconditioner matrices (L and R in the paper).
|
||||
"""
|
||||
state["GG"] = (
|
||||
[]
|
||||
) # Will hold all the preconditioner matrices (L and R in the paper).
|
||||
if grad.dim() == 1:
|
||||
if not precondition_1d or grad.shape[0] > max_precond_dim:
|
||||
state["GG"].append([])
|
||||
else:
|
||||
state["GG"].append(
|
||||
torch.zeros(grad.shape[0], grad.shape[0], device=grad.device)
|
||||
)
|
||||
else:
|
||||
if merge_dims:
|
||||
grad = self.merge_dims(grad, max_precond_dim)
|
||||
|
||||
for sh in grad.shape:
|
||||
if sh > max_precond_dim:
|
||||
state["GG"].append([])
|
||||
else:
|
||||
state["GG"].append(torch.zeros(sh, sh, device=grad.device))
|
||||
|
||||
state["Q"] = None # Will hold all the eigenbases of the preconditioner.
|
||||
state["precondition_frequency"] = precondition_frequency
|
||||
state["shampoo_beta"] = shampoo_beta
|
||||
|
||||
def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
|
||||
"""
|
||||
Projects the gradient to the eigenbases of the preconditioner.
|
||||
"""
|
||||
original_shape = grad.shape
|
||||
if merge_dims:
|
||||
if grad.dim() == 4 and self._data_format == "channels_last":
|
||||
permuted_shape = grad.permute(0, 3, 1, 2).shape
|
||||
grad = self.merge_dims(grad, max_precond_dim)
|
||||
|
||||
for mat in state["Q"]:
|
||||
if len(mat) > 0:
|
||||
grad = torch.tensordot(
|
||||
grad,
|
||||
mat,
|
||||
dims=[[0], [0]],
|
||||
)
|
||||
else:
|
||||
permute_order = list(range(1, len(grad.shape))) + [0]
|
||||
grad = grad.permute(permute_order)
|
||||
|
||||
if merge_dims:
|
||||
if self._data_format == "channels_last" and len(original_shape) == 4:
|
||||
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
|
||||
else:
|
||||
grad = grad.reshape(original_shape)
|
||||
return grad
|
||||
|
||||
def update_preconditioner(
|
||||
self,
|
||||
grad,
|
||||
state,
|
||||
max_precond_dim=10000,
|
||||
merge_dims=False,
|
||||
precondition_1d=False,
|
||||
):
|
||||
"""
|
||||
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
|
||||
"""
|
||||
if state["Q"] is not None:
|
||||
state["exp_avg"] = self.project_back(
|
||||
state["exp_avg"],
|
||||
state,
|
||||
merge_dims=merge_dims,
|
||||
max_precond_dim=max_precond_dim,
|
||||
)
|
||||
if grad.dim() == 1:
|
||||
if precondition_1d and grad.shape[0] <= max_precond_dim:
|
||||
state["GG"][0].lerp_(
|
||||
grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"]
|
||||
)
|
||||
else:
|
||||
if merge_dims:
|
||||
new_grad = self.merge_dims(grad, max_precond_dim)
|
||||
for idx, sh in enumerate(new_grad.shape):
|
||||
if sh <= max_precond_dim:
|
||||
outer_product = torch.tensordot(
|
||||
new_grad,
|
||||
new_grad,
|
||||
dims=[
|
||||
[
|
||||
*chain(
|
||||
range(idx), range(idx + 1, len(new_grad.shape))
|
||||
)
|
||||
]
|
||||
]
|
||||
* 2,
|
||||
)
|
||||
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
|
||||
else:
|
||||
for idx, sh in enumerate(grad.shape):
|
||||
if sh <= max_precond_dim:
|
||||
outer_product = torch.tensordot(
|
||||
grad,
|
||||
grad,
|
||||
# Contracts across all dimensions except for k.
|
||||
dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]]
|
||||
* 2,
|
||||
)
|
||||
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
|
||||
|
||||
if state["Q"] is None:
|
||||
state["Q"] = self.get_orthogonal_matrix(state["GG"])
|
||||
if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0:
|
||||
state["Q"] = self.get_orthogonal_matrix_QR(
|
||||
state, max_precond_dim, merge_dims
|
||||
)
|
||||
# state["Q"] = self.get_fast_QR(state, max_precond_dim, merge_dims)
|
||||
|
||||
if state["step"] > 0:
|
||||
state["exp_avg"] = self.project(
|
||||
state["exp_avg"],
|
||||
state,
|
||||
merge_dims=merge_dims,
|
||||
max_precond_dim=max_precond_dim,
|
||||
)
|
||||
|
||||
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
|
||||
"""
|
||||
Projects the gradient back to the original space.
|
||||
"""
|
||||
original_shape = grad.shape
|
||||
if merge_dims:
|
||||
if self._data_format == "channels_last" and grad.dim() == 4:
|
||||
permuted_shape = grad.permute(0, 3, 1, 2).shape
|
||||
grad = self.merge_dims(grad, max_precond_dim)
|
||||
for mat in state["Q"]:
|
||||
if len(mat) > 0:
|
||||
grad = torch.tensordot(
|
||||
grad,
|
||||
mat,
|
||||
dims=[[0], [1]],
|
||||
)
|
||||
else:
|
||||
permute_order = list(range(1, len(grad.shape))) + [0]
|
||||
grad = grad.permute(permute_order)
|
||||
|
||||
if merge_dims:
|
||||
if self._data_format == "channels_last" and len(original_shape) == 4:
|
||||
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
|
||||
else:
|
||||
grad = grad.reshape(original_shape)
|
||||
return grad
|
||||
|
||||
def get_orthogonal_matrix(self, mat):
|
||||
"""
|
||||
Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
|
||||
"""
|
||||
matrix = []
|
||||
for m in mat:
|
||||
if len(m) == 0:
|
||||
matrix.append([])
|
||||
continue
|
||||
if m.data.dtype != torch.float:
|
||||
float_data = False
|
||||
original_type = m.data.dtype
|
||||
original_device = m.data.device
|
||||
matrix.append(m.data.float())
|
||||
else:
|
||||
float_data = True
|
||||
matrix.append(m.data)
|
||||
|
||||
final = []
|
||||
for m in matrix:
|
||||
if len(m) == 0:
|
||||
final.append([])
|
||||
continue
|
||||
try:
|
||||
_, Q = torch.linalg.eigh(
|
||||
m + 1e-30 * torch.eye(m.shape[0], device=m.device)
|
||||
)
|
||||
except: # pylint: disable=bare-except # noqa: E722
|
||||
_, Q = torch.linalg.eigh(
|
||||
m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device)
|
||||
)
|
||||
Q = Q.to(m.dtype)
|
||||
Q = torch.flip(Q, [1])
|
||||
|
||||
if not float_data:
|
||||
Q = Q.to(original_device).type(original_type)
|
||||
final.append(Q)
|
||||
return final
|
||||
|
||||
def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
|
||||
"""
|
||||
Computes the eigenbases of the preconditioner using one round of power iteration
|
||||
followed by torch.linalg.qr decomposition.
|
||||
"""
|
||||
precond_list = state["GG"]
|
||||
orth_list = state["Q"]
|
||||
|
||||
matrix = []
|
||||
orth_matrix = []
|
||||
for m, o in zip(precond_list, orth_list):
|
||||
if len(m) == 0:
|
||||
matrix.append([])
|
||||
orth_matrix.append([])
|
||||
continue
|
||||
if m.data.dtype != torch.float:
|
||||
float_data = False
|
||||
original_type = m.data.dtype
|
||||
original_device = m.data.device
|
||||
matrix.append(m.data.float())
|
||||
orth_matrix.append(o.data.float())
|
||||
else:
|
||||
float_data = True
|
||||
matrix.append(m.data.float())
|
||||
orth_matrix.append(o.data.float())
|
||||
|
||||
orig_shape = state["exp_avg_sq"].shape
|
||||
if self._data_format == "channels_last" and len(orig_shape) == 4:
|
||||
permuted_shape = state["exp_avg_sq"].permute(0, 3, 1, 2).shape
|
||||
if merge_dims:
|
||||
exp_avg_sq = self.merge_dims(state["exp_avg_sq"], max_precond_dim)
|
||||
else:
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
|
||||
final = []
|
||||
for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
|
||||
if len(m) == 0:
|
||||
final.append([])
|
||||
continue
|
||||
est_eig = torch.diag(o.T @ m @ o)
|
||||
sort_idx = torch.argsort(est_eig, descending=True)
|
||||
exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
|
||||
o = o[:, sort_idx]
|
||||
power_iter = m @ o
|
||||
Q, _ = torch.linalg.qr(power_iter)
|
||||
|
||||
if not float_data:
|
||||
Q = Q.to(original_device).type(original_type)
|
||||
final.append(Q)
|
||||
|
||||
if merge_dims:
|
||||
if self._data_format == "channels_last" and len(orig_shape) == 4:
|
||||
exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
|
||||
else:
|
||||
exp_avg_sq = exp_avg_sq.reshape(orig_shape)
|
||||
|
||||
state["exp_avg_sq"] = exp_avg_sq
|
||||
return final
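For context on the SOAP optimizer file added above, a minimal usage sketch (not part of the diff); it assumes SOAP is importable from axolotl.utils.optimizers.soap as introduced here, and the hyperparameter values are illustrative only:

import torch
import torch.nn.functional as F

from axolotl.utils.optimizers.soap import SOAP  # import path assumed from the file added above

model = torch.nn.Linear(128, 64)
optimizer = SOAP(
    model.parameters(),
    lr=3e-3,
    betas=(0.95, 0.95),
    weight_decay=0.01,
    precondition_frequency=10,
)

for step in range(3):
    inputs, targets = torch.randn(8, 128), torch.randn(8, 64)
    loss = F.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()      # the very first step only initializes the preconditioner state
    optimizer.zero_grad()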
|
||||
@@ -8,13 +8,11 @@ from typing import Any, Iterable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler, Sampler, SequentialSampler
|
||||
from torch.utils.data import BatchSampler, Sampler
|
||||
|
||||
from axolotl.utils.distributed import reduce_and_broadcast
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
LOG.setLevel(logging.INFO)
|
||||
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
|
||||
|
||||
|
||||
@numba.njit
|
||||
@@ -105,55 +103,6 @@ def allocate(
|
||||
return result, s, len(result) * c * n
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
|
||||
"""
|
||||
Sequential allocator that preserves example order
|
||||
|
||||
Parameters:
|
||||
- lengths: The lengths of all examples
|
||||
- rank: The current rank (for distributed training)
|
||||
- c: The capacity of each bin (maximum sequence length)
|
||||
- n: Number of ranks
|
||||
|
||||
Returns:
|
||||
- result: List of batches for the current rank
|
||||
- total_used: Number of actual example tokens
|
||||
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
|
||||
"""
|
||||
result = []
|
||||
total_used = 0
|
||||
|
||||
# First, do sequential packing into bins
|
||||
all_bins = []
|
||||
current_bin = [0 for i in range(0)]  # empty int list, written this way so numba can infer the element type
|
||||
remaining_capacity = c
|
||||
|
||||
for idx, size in enumerate(lengths):
|
||||
if size <= remaining_capacity:
|
||||
# Example fits in current bin
|
||||
current_bin.append(idx)
|
||||
remaining_capacity -= size
|
||||
total_used += size
|
||||
else:
|
||||
# Example doesn't fit, start a new bin
|
||||
if current_bin: # Add non-empty bin to all_bins
|
||||
all_bins.append(current_bin)
|
||||
current_bin = [idx]
|
||||
remaining_capacity = c - size
|
||||
total_used += size
|
||||
|
||||
# Add the last bin if not empty
|
||||
if current_bin:
|
||||
all_bins.append(current_bin)
|
||||
|
||||
# Assign bins to ranks - each rank gets every n-th bin
|
||||
for bin_idx in range(rank, len(all_bins), n):
|
||||
result.append(all_bins[bin_idx])
|
||||
|
||||
return result, total_used, len(all_bins) * c
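To make concrete what behavior is dropped with this removal, a small plain-Python rendering of the sequential allocator (no numba decorator, illustrative values; not part of the diff):

import numpy as np

def allocate_sequentially_py(lengths, rank, c, n):
    # Plain-Python rendering of the removed allocator, for illustration only.
    total_used = 0
    all_bins, current_bin, remaining = [], [], c
    for idx, size in enumerate(lengths):
        if size <= remaining:
            current_bin.append(idx)
            remaining -= size
        else:
            if current_bin:
                all_bins.append(current_bin)
            current_bin, remaining = [idx], c - size
        total_used += size
    if current_bin:
        all_bins.append(current_bin)
    result = [all_bins[b] for b in range(rank, len(all_bins), n)]
    return result, total_used, len(all_bins) * c

lengths = np.array([5, 3, 4, 6, 2])
print(allocate_sequentially_py(lengths, rank=0, c=8, n=2))
# bins are [[0, 1], [2], [3, 4]]; rank 0 of 2 gets ([[0, 1], [3, 4]], 20, 24)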
|
||||
|
||||
|
||||
class MultipackBatchSampler(BatchSampler):
|
||||
"""Batch sampler class for multipack"""
|
||||
|
||||
@@ -166,7 +115,6 @@ class MultipackBatchSampler(BatchSampler):
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
drop_last: bool = False,
|
||||
num_count_samples: int = 16,
|
||||
sequential: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
@@ -174,7 +122,6 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.sequential = sequential
|
||||
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
@@ -189,11 +136,6 @@ class MultipackBatchSampler(BatchSampler):
|
||||
# the minimum packed dataset length across all ranks determined by a gather/broadcast
|
||||
self.len_across_ranks = None
|
||||
|
||||
if self.sequential and not isinstance(sampler, SequentialSampler):
|
||||
LOG.warn(
|
||||
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
|
||||
)
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
self.epoch = epoch
|
||||
|
||||
@@ -203,21 +145,13 @@ class MultipackBatchSampler(BatchSampler):
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
if self.sequential:
|
||||
batches, total_used, total_slots = allocate_sequentially(
|
||||
lengths=lengths,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
else:
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
|
||||
batches = [
|
||||
[
|
||||
|
||||
@@ -46,7 +46,6 @@ from axolotl.utils.schemas.multimodal import MultiModalConfig
|
||||
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
||||
from axolotl.utils.schemas.training import HyperparametersConfig
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
from axolotl.utils.schemas.vllm import VllmConfig
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -87,9 +86,6 @@ class AxolotlInputConfig(
|
||||
trl: TRLConfig | None = Field(
|
||||
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
vllm: VllmConfig | None = Field(
|
||||
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
reward_model: bool | None = None
|
||||
process_reward_model: bool | None = None
|
||||
num_labels: int | None = None
|
||||
@@ -192,7 +188,6 @@ class AxolotlInputConfig(
|
||||
sample_packing: bool | None = None
|
||||
sample_packing_group_size: int | None = 100_000
|
||||
sample_packing_bin_size: int | None = 200
|
||||
sample_packing_sequentially: bool | None = None
|
||||
eval_sample_packing: bool | None = None
|
||||
pad_to_sequence_len: bool | None = None
|
||||
curriculum_sampling: bool | None = None
|
||||
@@ -253,7 +248,6 @@ class AxolotlInputConfig(
|
||||
val_set_size: float | None = Field(default=0.0)
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
heads_k_stride: int | None = None
|
||||
|
||||
special_tokens: SpecialTokensConfig | None = None
|
||||
tokens: list[str] | None = None
|
||||
@@ -1114,7 +1108,7 @@ class AxolotlInputConfig(
|
||||
|
||||
@field_validator("sequence_parallel_degree", mode="before")
|
||||
@classmethod
|
||||
def check_sequence_parallel_degree(cls, value, info):
|
||||
def check_sequence_parallel_config(cls, value, info):
|
||||
if not value:
|
||||
value = 1
|
||||
|
||||
@@ -1135,17 +1129,6 @@ class AxolotlInputConfig(
|
||||
|
||||
return value
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_muon_deepspeed_fsdp(cls, data):
|
||||
if data.get("optimizer") == "muon" and (
|
||||
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
|
||||
):
|
||||
raise ValueError(
|
||||
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
@@ -1281,12 +1264,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
if data["beta"] != data["trl"]["beta"]:
|
||||
raise ValueError("beta and trl.beta must match or one must be removed")
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_min_torch_version(self):
|
||||
if self.env_capabilities and self.env_capabilities.torch_version:
|
||||
torch_version = self.env_capabilities.torch_version
|
||||
if version.parse(torch_version) < version.parse("2.5.1"):
|
||||
LOG.warning(
|
||||
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
|
||||
)
|
||||
|
||||
@@ -52,3 +52,4 @@ class CustomSupportedOptimizers(str, Enum):
|
||||
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||
muon = "muon" # pylint: disable=invalid-name
|
||||
soap = "soap" # pylint: disable=invalid-name
|
||||
|
||||
@@ -20,30 +20,27 @@ class TRLConfig(BaseModel):
|
||||
)
|
||||
|
||||
# GRPO specific args
|
||||
# Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
|
||||
use_vllm: bool = Field(
|
||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||
use_vllm: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||
)
|
||||
vllm_server_host: str | None = Field(
|
||||
default="0.0.0.0", # nosec B104
|
||||
json_schema_extra={"description": "Host of the vLLM server to connect to"},
|
||||
vllm_device: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Device to use for VLLM"},
|
||||
)
|
||||
vllm_server_port: int | None = Field(
|
||||
default=8000,
|
||||
json_schema_extra={"description": "Port of the vLLM server to connect to"},
|
||||
vllm_gpu_memory_utilization: float | None = Field(
|
||||
default=0.9,
|
||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||
)
|
||||
vllm_server_timeout: int | None = Field(
|
||||
vllm_dtype: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Data type for VLLM"},
|
||||
)
|
||||
vllm_max_model_len: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
|
||||
"after the timeout, a `ConnectionError` is raised."
|
||||
},
|
||||
)
|
||||
vllm_guided_decoding_regex: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
|
||||
"description": "Maximum length of the model context for VLLM"
|
||||
},
|
||||
)
|
||||
|
||||
@@ -88,48 +85,3 @@ class TRLConfig(BaseModel):
            "description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
        },
    )
    scale_rewards: bool = Field(
        default=True,
        json_schema_extra={
            "description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
        },
    )

    temperature: float | None = Field(
        default=None,
        json_schema_extra={"description": "Sampling temperature for the GRPO policy."},
    )
    top_p: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Top-p sampling probability for the generation policy."
        },
    )
    top_k: int | None = Field(
        default=None,
        json_schema_extra={"description": "Top-k sampling for the generation policy."},
    )
    min_p: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Minimum probability for the generation policy."
        },
    )
    repetition_penalty: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
        },
    )
    num_iterations: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
        },
    )
    epsilon: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Epsilon value for clipping in the GRPO algorithm."
        },
    )

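Taken together, these fields control generation behavior for a GRPO run from the `trl` block. A hedged sketch of how they might be set in a config (all values illustrative, not tuning recommendations):

# illustrative values only, not tuning recommendations
trl_generation_cfg = {
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 50,
    "min_p": 0.05,
    "repetition_penalty": 1.1,
    "num_iterations": 1,
    "epsilon": 0.2,
    "scale_rewards": True,
}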
@@ -1,38 +0,0 @@
"""
Pydantic models for VLLM configuration, used primarily for RL training with TRL + grpo
"""

from pydantic import BaseModel, Field


class VllmConfig(BaseModel):
    """
    Configuration for VLLM server
    """

    device: str | None = Field(
        default="auto",
        json_schema_extra={"description": "Device to use for VLLM"},
    )
    tensor_parallel_size: int | None = Field(
        default=None,
        json_schema_extra={"description": "Tensor parallel size for VLLM"},
    )
    gpu_memory_utilization: float | None = Field(
        default=0.9,
        json_schema_extra={"description": "GPU memory utilization for VLLM"},
    )
    dtype: str | None = Field(
        default="auto",
        json_schema_extra={"description": "Data type for VLLM"},
    )
    max_model_len: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Maximum length of the model context for VLLM"
        },
    )
    enable_prefix_caching: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Enable prefix caching for VLLM"},
    )
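The two sides of this compare place vLLM settings differently: one routes them through the standalone `VllmConfig` above under a top-level `vllm` key (with server host/port in `trl`), the other keeps `vllm_*` fields inline in the `trl` block. A hedged sketch of both shapes as plain dicts, with values mirroring the GRPO tests later in this diff (illustrative only):

# shape A: standalone vllm block plus trl server fields (illustrative values)
cfg_vllm_server = {
    "trl": {"use_vllm": True, "vllm_server_host": "0.0.0.0", "vllm_server_port": 8000},
    "vllm": {"max_model_len": 800, "enable_prefix_caching": True},
}

# shape B: vLLM options inlined in the trl block (colocated vLLM, illustrative values)
cfg_vllm_inline = {
    "trl": {
        "use_vllm": True,
        "vllm_device": "auto",
        "vllm_gpu_memory_utilization": 0.15,
    },
}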
@@ -13,7 +13,7 @@ import torch
|
||||
import torch.cuda
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import IterableDataset, disable_caching, enable_caching
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
@@ -235,7 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if cfg.model_config_type in ["mamba", "gemma3"]:
|
||||
if cfg.model_config_type == "mamba":
|
||||
LOG.info("dropping attention_mask column")
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
@@ -456,18 +456,13 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
else:
|
||||
sampler_batch_size = cfg.micro_batch_size
|
||||
batch_max_len = cfg.sequence_len
|
||||
if cfg.curriculum_sampling:
|
||||
sampler = SequentialSampler(train_dataset)
|
||||
else:
|
||||
sampler = RandomSampler(train_dataset)
|
||||
sampler = MultipackBatchSampler(
|
||||
sampler=sampler,
|
||||
sampler=RandomSampler(train_dataset),
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
batch_size=sampler_batch_size,
|
||||
batch_max_len=batch_max_len,
|
||||
group_size=cfg.sample_packing_group_size,
|
||||
bin_size=cfg.sample_packing_bin_size,
|
||||
sequential=cfg.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -8,13 +8,11 @@ import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
import requests
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
|
||||
@@ -50,14 +48,6 @@ def snapshot_download_w_retry(*args, **kwargs):
|
||||
return snapshot_download(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_ds_fixture_bundle():
|
||||
ds_dir = snapshot_download_w_retry(
|
||||
"axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
|
||||
)
|
||||
return Path(ds_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_smollm2_135m_model():
|
||||
# download the model
|
||||
@@ -111,50 +101,42 @@ def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_argilla_distilabel_intel_orca_dpo_dataset():
|
||||
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
||||
# download the dataset
|
||||
snapshot_download_w_retry(
|
||||
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
|
||||
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
|
||||
)
|
||||
|
||||
|
||||
# @pytest.fixture(scope="session", autouse=True)
|
||||
# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
||||
# # download the dataset
|
||||
# snapshot_download_w_retry(
|
||||
# "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
|
||||
# )
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_fozzie_alpaca_dpo_dataset():
|
||||
# download the dataset
|
||||
snapshot_download_w_retry(
|
||||
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
||||
)
|
||||
snapshot_download_w_retry(
|
||||
"fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
repo_type="dataset",
|
||||
revision="ea82cff",
|
||||
)
|
||||
|
||||
|
||||
# @pytest.fixture(scope="session", autouse=True)
|
||||
# def download_fozzie_alpaca_dpo_dataset():
|
||||
# # download the dataset
|
||||
# snapshot_download_w_retry(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
||||
# )
|
||||
# snapshot_download_w_retry(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
# repo_type="dataset",
|
||||
# revision="ea82cff",
|
||||
# )
|
||||
@pytest.fixture(scope="session")
|
||||
@disable_hf_offline
|
||||
def dataset_fozzie_alpaca_dpo_dataset(
|
||||
download_fozzie_alpaca_dpo_dataset,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||
|
||||
|
||||
# @pytest.fixture(scope="session")
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||
#
|
||||
#
|
||||
# @pytest.fixture(scope="session")
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# return load_dataset(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||
# )
|
||||
@pytest.fixture(scope="session")
|
||||
@disable_hf_offline
|
||||
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||
download_fozzie_alpaca_dpo_dataset,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
return load_dataset(
|
||||
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@@ -281,7 +263,7 @@ def download_mlx_mistral_7b_model_fixture():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_llama2_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
@@ -291,7 +273,7 @@ def download_llama2_model_fixture():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama(
|
||||
download_huggyllama_model_fixture,
|
||||
@@ -302,57 +284,6 @@ def tokenizer_huggyllama(
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama_w_special_tokens(
|
||||
tokenizer_huggyllama,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
tokenizer_huggyllama.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
return tokenizer_huggyllama
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_llama2_7b(
|
||||
download_llama2_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_mistral_7b_instruct(
|
||||
download_mlx_mistral_7b_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
|
||||
tokenizer_mistral_7b_instruct.add_special_tokens(
|
||||
{
|
||||
"eos_token": AddedToken(
|
||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||
)
|
||||
}
|
||||
)
|
||||
tokenizer_mistral_7b_instruct.add_tokens(
|
||||
[
|
||||
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
|
||||
]
|
||||
)
|
||||
return tokenizer_mistral_7b_instruct
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
# Create a temporary directory
|
||||
@@ -418,60 +349,6 @@ def cleanup_monkeypatches():
|
||||
globals().pop(module_global, None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_winglian_tiny_shakespeare(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
|
||||
return datasets.load_from_disk(ds_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_tatsu_lab_alpaca(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_mhenrichsen_alpaca_2k_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
|
||||
)
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
|
||||
)
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
# # pylint: disable=redefined-outer-name,unused-argument
|
||||
# def test_load_fixtures(
|
||||
# download_smollm2_135m_model,
|
||||
|
||||
@@ -1,294 +0,0 @@
|
||||
"""
|
||||
GRPO test suite
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import subprocess # nosec B404
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import require_vllm
|
||||
|
||||
|
||||
def start_vllm(
|
||||
model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
|
||||
) -> int:
|
||||
"""
|
||||
helper function to start the VLLM server in the background, mostly for testing purposes
|
||||
"""
|
||||
cmd = [sys.executable, "-m", "trl.scripts.vllm_serve", "--model", model]
|
||||
|
||||
if tensor_parallel_size := kwargs.get("tensor_parallel_size"):
|
||||
cmd.extend(["--tensor-parallel-size", str(tensor_parallel_size)])
|
||||
if host := kwargs.get("host"):
|
||||
cmd.extend(["--host", host])
|
||||
if port := kwargs.get("port"):
|
||||
cmd.extend(["--port", str(port)])
|
||||
if gpu_memory_utilization := kwargs.get("gpu_memory_utilization"):
|
||||
cmd.extend(["--gpu-memory-utilization", str(gpu_memory_utilization)])
|
||||
if dtype := kwargs.get("dtype"):
|
||||
cmd.extend(["--dtype", dtype])
|
||||
if max_model_len := kwargs.get("max_model_len"):
|
||||
cmd.extend(["--max-model-len", str(max_model_len)])
|
||||
if kwargs.get("enable_prefix_caching"):
|
||||
cmd.extend(["--enable-prefix-caching", "True"])
|
||||
|
||||
# print out the command to be executed
|
||||
print(" ".join(cmd))
|
||||
|
||||
# start `trl vllm-serve` command in the background and capture the process id
|
||||
process = subprocess.Popen( # pylint: disable=consider-using-with
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
|
||||
) # nosec B603
|
||||
|
||||
# print out the process id so the user can easily kill it later
|
||||
print(f"VLLM server process started (PID: {process.pid})")
|
||||
|
||||
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds
|
||||
started = False
|
||||
if wait and host and port:
|
||||
for _ in range(int(wait)):
|
||||
try:
|
||||
response = requests.get(f"http://{host}:{port}", timeout=1)
|
||||
if int(response.status_code) in [200, 404]:
|
||||
started = True
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
|
||||
# also check if the process.pid is still running
|
||||
if not process.poll() is None:
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
if wait and not started:
|
||||
print(
|
||||
f"VLLM server process did not start within {wait} seconds. Please check your server logs."
|
||||
)
|
||||
process.kill()
|
||||
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
|
||||
|
||||
# return the process id
|
||||
return process.pid
|
||||
|
||||
|
||||
class TestGRPO:
|
||||
"""
|
||||
Test case for GRPO training using multiple GPUs
|
||||
"""
|
||||
|
||||
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
|
||||
fout.write(
|
||||
"""import random
|
||||
def rand_reward_func(completions, **kwargs) -> list[float]:
|
||||
return [random.uniform(0, 1) for _ in completions]
|
||||
|
||||
def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
def transform_fn(example, tokenizer=None):
|
||||
label = example["answer"].split("####")[-1].strip().replace(",", "")
|
||||
return {
|
||||
"prompt": [{"role": "user", "content": example["question"]},],
|
||||
"answer": label,
|
||||
}
|
||||
return transform_fn, {"remove_columns": ["question"]}
|
||||
"""
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
)
|
||||
@require_vllm
|
||||
def test_llama_dora(self, temp_dir, num_gpus):
|
||||
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"chat_template": "llama3",
|
||||
"rl": "grpo",
|
||||
"trl": {
|
||||
"beta": 0.001,
|
||||
"max_completion_length": 256,
|
||||
"use_vllm": True,
|
||||
"num_generations": 4,
|
||||
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||
},
|
||||
"vllm": {
|
||||
"max_model_len": 800,
|
||||
"enable_prefix_caching": True,
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "openai/gsm8k",
|
||||
"name": "main",
|
||||
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||
},
|
||||
],
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"peft_use_dora": True,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"max_steps": 3,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||
|
||||
current_env = os.environ.copy()
|
||||
env = {
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
}
|
||||
vllm_process_id = start_vllm(
|
||||
cfg.base_model,
|
||||
env=env,
|
||||
quiet=True,
|
||||
wait=120,
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=cfg.vllm.max_model_len,
|
||||
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
)
|
||||
|
||||
try:
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
str(num_gpus),
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
],
|
||||
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
|
||||
)
|
||||
finally:
|
||||
os.kill(vllm_process_id, 9)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
)
|
||||
@require_vllm
|
||||
def test_llama_fft(self, temp_dir, num_gpus):
|
||||
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"chat_template": "llama3",
|
||||
"rl": "grpo",
|
||||
"trl": {
|
||||
"beta": 0.001,
|
||||
"max_completion_length": 256,
|
||||
"use_vllm": True,
|
||||
"num_generations": 4,
|
||||
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||
},
|
||||
"vllm": {
|
||||
"max_model_len": 800,
|
||||
"enable_prefix_caching": True,
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "openai/gsm8k",
|
||||
"name": "main",
|
||||
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||
},
|
||||
],
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"max_steps": 3,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||
|
||||
current_env = os.environ.copy()
|
||||
env = {
|
||||
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
}
|
||||
vllm_process_id = start_vllm(
|
||||
cfg.base_model,
|
||||
env=env,
|
||||
quiet=True,
|
||||
wait=120,
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=cfg.vllm.max_model_len,
|
||||
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
)
|
||||
|
||||
try:
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
str(num_gpus),
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
],
|
||||
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
|
||||
)
|
||||
finally:
|
||||
os.kill(vllm_process_id, 9)
|
||||
@@ -52,9 +52,9 @@ class TestMultiGPUEval:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -121,9 +121,9 @@ class TestMultiGPUEval:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
"""
|
||||
E2E tests for multigpu lora tinyllama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_model():
|
||||
# download the model
|
||||
snapshot_download("axolotl-mirrors/gemma-3-4b-pt", repo_type="model")
|
||||
|
||||
|
||||
class TestMultiGPUGemma3:
|
||||
"""
|
||||
Test case for Gemma3 models using LoRA
|
||||
"""
|
||||
|
||||
def test_lora_ddp_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-mirrors/gemma-3-4b-pt",
|
||||
"sequence_len": 2048,
|
||||
"ddp_find_unused_parameters": True,
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.0,
|
||||
"chat_template": "gemma3",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"split": "train[:10%]",
|
||||
"field_messages": "conversations",
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {
|
||||
"use_reentrant": False,
|
||||
},
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
|
||||
)
|
||||
175
tests/e2e/multigpu/test_grpo.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
GRPO test suite
|
||||
"""
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import require_vllm
|
||||
|
||||
|
||||
class TestGRPO:
|
||||
"""
|
||||
Test case for GRPO training using multiple GPUs
|
||||
"""
|
||||
|
||||
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
|
||||
fout.write(
|
||||
"""import random
|
||||
def rand_reward_func(completions, **kwargs) -> list[float]:
|
||||
return [random.uniform(0, 1) for _ in completions]
|
||||
|
||||
def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
def transform_fn(example, tokenizer=None):
|
||||
label = example["answer"].split("####")[-1].strip().replace(",", "")
|
||||
return {
|
||||
"prompt": [{"role": "user", "content": example["question"]},],
|
||||
"answer": label,
|
||||
}
|
||||
return transform_fn, {"remove_columns": ["question"]}
|
||||
"""
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
)
|
||||
@require_vllm
|
||||
def test_llama_dora(self, temp_dir, num_gpus):
|
||||
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"chat_template": "llama3",
|
||||
"rl": "grpo",
|
||||
"trl": {
|
||||
"beta": 0.001,
|
||||
"max_completion_length": 256,
|
||||
"use_vllm": True,
|
||||
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
|
||||
"vllm_gpu_memory_utilization": 0.15,
|
||||
"num_generations": 4,
|
||||
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "openai/gsm8k",
|
||||
"name": "main",
|
||||
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||
},
|
||||
],
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"peft_use_dora": True,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"max_steps": 5,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
str(num_gpus),
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
)
|
||||
@require_vllm
|
||||
def test_llama_fft(self, temp_dir, num_gpus):
|
||||
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"chat_template": "llama3",
|
||||
"rl": "grpo",
|
||||
"trl": {
|
||||
"beta": 0.001,
|
||||
"max_completion_length": 256,
|
||||
"use_vllm": True,
|
||||
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
|
||||
"vllm_gpu_memory_utilization": 0.15,
|
||||
"num_generations": 4,
|
||||
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "openai/gsm8k",
|
||||
"name": "main",
|
||||
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||
},
|
||||
],
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"max_steps": 5,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
str(num_gpus),
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
@@ -58,7 +58,6 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -122,7 +121,6 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -195,7 +193,6 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
@@ -273,7 +270,6 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
@@ -334,7 +330,6 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
@@ -404,8 +399,7 @@ class TestMultiGPULlama:
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
@@ -484,8 +478,7 @@ class TestMultiGPULlama:
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
@@ -785,10 +778,9 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
|
||||
@@ -46,7 +46,7 @@ class TestMultiGPUQwen2:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 5,
|
||||
"warmup_steps": 20,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
|
||||
@@ -50,7 +50,7 @@ class TestMultiGPURay:
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
|
||||
@@ -110,7 +110,7 @@ class TestRingAttention:
|
||||
mock_new_group.return_value = mock_group
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
|
||||
register_ring_attn(sequence_parallel_degree=4)
|
||||
|
||||
# Verify the number of calls without examining the arguments
|
||||
assert mock_new_group.call_count == 2
|
||||
|
||||
@@ -201,3 +201,46 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_soap(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "vicgalle/alpaca-gpt4",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "soap",
|
||||
"adam_beta1": 0.9,
|
||||
"adam_beta2": 0.95,
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -324,7 +324,7 @@ class TestDatasetPreparation:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_hub_with_revision_with_dpo(
|
||||
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||
):
|
||||
"""Verify that processing dpo data from the hub works with a specific revision"""
|
||||
|
||||
@@ -339,10 +339,12 @@ class TestDatasetPreparation:
|
||||
)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
|
||||
with patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset:
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
@@ -352,9 +354,7 @@ class TestDatasetPreparation:
|
||||
|
||||
@enable_hf_offline
|
||||
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||
def test_load_local_hub_with_revision(
|
||||
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer
|
||||
):
|
||||
def test_load_local_hub_with_revision(self, tokenizer):
|
||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||
@@ -386,23 +386,13 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset:
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, prepared_path
|
||||
)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
@enable_hf_offline
|
||||
def test_loading_local_dataset_folder(self, tokenizer):
|
||||
|
||||
@@ -238,22 +238,21 @@ class TestDeduplicateRLDataset:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_with_deduplication(
|
||||
self,
|
||||
cfg,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
tokenizer_huggyllama,
|
||||
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||
):
|
||||
"""Verify that loading with deduplication removes duplicates."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.side_effect = [
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
@@ -264,20 +263,19 @@ class TestDeduplicateRLDataset:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_without_deduplication(
|
||||
self,
|
||||
cfg,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
tokenizer_huggyllama,
|
||||
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.side_effect = [
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Module for testing streaming dataset sequence packing"""
|
||||
|
||||
import pytest
|
||||
from datasets import concatenate_datasets
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@@ -27,6 +27,7 @@ class TestBatchedSamplerPacking:
|
||||
Test class for packing streaming dataset sequences
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size, num_workers",
|
||||
[
|
||||
@@ -37,20 +38,14 @@ class TestBatchedSamplerPacking:
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("max_seq_length", [4096, 512])
|
||||
@pytest.mark.parametrize("sequential", [True, False])
|
||||
@enable_hf_offline
|
||||
def test_packing(
|
||||
self,
|
||||
dataset_winglian_tiny_shakespeare,
|
||||
batch_size,
|
||||
num_workers,
|
||||
tokenizer,
|
||||
max_seq_length,
|
||||
sequential,
|
||||
):
|
||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
dataset = dataset_winglian_tiny_shakespeare["train"]
|
||||
dataset = load_dataset(
|
||||
"winglian/tiny-shakespeare",
|
||||
split="train",
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -60,7 +55,7 @@ class TestBatchedSamplerPacking:
|
||||
)
|
||||
ds_cfg = DictDefault(
|
||||
{
|
||||
"field": "text",
|
||||
"field": "Text",
|
||||
}
|
||||
)
|
||||
completion_strategy = load(tokenizer, cfg, ds_cfg)
|
||||
@@ -80,7 +75,6 @@ class TestBatchedSamplerPacking:
|
||||
batch_max_len=max_seq_length,
|
||||
group_size=100000,
|
||||
bin_size=200,
|
||||
sequential=sequential,
|
||||
)
|
||||
|
||||
loader = DataLoader(
|
||||
|
||||
@@ -2,8 +2,13 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||
InstructionWSystemPromptTokenizingStrategy,
|
||||
@@ -56,13 +61,24 @@ test_data = {
|
||||
}
|
||||
|
||||
|
||||
class TestPromptTokenizationStrategies:
|
||||
class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
"""
|
||||
Test class for prompt tokenization strategies.
|
||||
"""
|
||||
|
||||
@enable_hf_offline
|
||||
def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens):
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
@@ -70,7 +86,7 @@ class TestPromptTokenizationStrategies:
|
||||
# pylint: disable=duplicate-code
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
tokenizer_huggyllama_w_special_tokens,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
@@ -83,8 +99,7 @@ class TestPromptTokenizationStrategies:
|
||||
assert example["labels"][world_idx] == 3186
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
@enable_hf_offline
|
||||
def test_alpaca(self, tokenizer_huggyllama_w_special_tokens):
|
||||
def test_alpaca(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
@@ -92,7 +107,7 @@ class TestPromptTokenizationStrategies:
|
||||
prompter = AlpacaPrompter()
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
tokenizer_huggyllama_w_special_tokens,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
@@ -103,17 +118,28 @@ class TestPromptTokenizationStrategies:
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
|
||||
class TestInstructionWSystemPromptTokenizingStrategy:
|
||||
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
||||
"""
|
||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||
"""
|
||||
|
||||
@enable_hf_offline
|
||||
def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens):
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
def test_system_alpaca(self):
|
||||
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
|
||||
strat = InstructionWSystemPromptTokenizingStrategy(
|
||||
prompter,
|
||||
tokenizer_huggyllama_w_special_tokens,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
@@ -134,13 +160,18 @@ class TestInstructionWSystemPromptTokenizingStrategy:
|
||||
assert example["input_ids"][8] == 11889 # USER
|
||||
|
||||
|
||||
class Llama2ChatTokenizationTest:
|
||||
class Llama2ChatTokenizationTest(unittest.TestCase):
|
||||
"""
|
||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||
"""
|
||||
|
||||
@enable_hf_offline
|
||||
def test_llama2_chat_integration(self, tokenizer_llama2_7b):
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
# workaround because official Meta repos are not open
|
||||
|
||||
def test_llama2_chat_integration(self):
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||
) as fin:
|
||||
@@ -155,18 +186,16 @@ class Llama2ChatTokenizationTest:
|
||||
prompter = Llama2ChatPrompter()
|
||||
strat = LLama2ChatTokenizingStrategy(
|
||||
prompter,
|
||||
tokenizer_llama2_7b,
|
||||
self.tokenizer,
|
||||
False,
|
||||
4096,
|
||||
)
|
||||
example = strat.tokenize_prompt(conversation)
|
||||
for fields in ["input_ids", "attention_mask", "labels"]:
|
||||
# pytest assert equals
|
||||
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
||||
self.assertEqual(example[fields], tokenized_conversation[fields])
|
||||
|
||||
assert len(example[fields]) == len(tokenized_conversation[fields])
|
||||
assert example[fields] == tokenized_conversation[fields]
|
||||
|
||||
def compare_with_transformers_integration(self, tokenizer_llama2_7b):
|
||||
def compare_with_transformers_integration(self):
|
||||
# this needs transformers >= v4.31.0
|
||||
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
@@ -205,27 +234,49 @@ If a question does not make any sense, or is not factually coherent, explain why
|
||||
generated_responses=answers,
|
||||
)
|
||||
# pylint: disable=W0212
|
||||
hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)
|
||||
hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
|
||||
|
||||
assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]
|
||||
self.assertEqual(
|
||||
hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
|
||||
)
|
||||
|
||||
|
||||
class OrpoTokenizationTest:
|
||||
class OrpoTokenizationTest(unittest.TestCase):
|
||||
"""test case for the ORPO tokenization"""
|
||||
|
||||
@enable_hf_offline
|
||||
def test_orpo_integration(
|
||||
self,
|
||||
tokenizer_mistral_7b_instruct_chatml,
|
||||
dataset_argilla_ultrafeedback_binarized_preferences_cleaned,
|
||||
):
|
||||
ds = dataset_argilla_ultrafeedback_binarized_preferences_cleaned.select([0])
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
"casperhansen/mistral-7b-instruct-v0.1-awq"
|
||||
)
|
||||
tokenizer.add_special_tokens(
|
||||
{
|
||||
"eos_token": AddedToken(
|
||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||
)
|
||||
}
|
||||
)
|
||||
tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken(
|
||||
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
|
||||
),
|
||||
]
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
self.dataset = load_dataset(
|
||||
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
|
||||
).select([0])
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||
def test_orpo_integration(self):
|
||||
strat = load(
|
||||
tokenizer_mistral_7b_instruct_chatml,
|
||||
self.tokenizer,
|
||||
DictDefault({"train_on_inputs": False}),
|
||||
DictDefault({"chat_template": "chatml"}),
|
||||
)
|
||||
res = strat.tokenize_prompt(ds[0])
|
||||
res = strat.tokenize_prompt(self.dataset[0])
|
||||
assert "rejected_input_ids" in res
|
||||
assert "rejected_labels" in res
|
||||
assert "input_ids" in res
|
||||
@@ -244,3 +295,7 @@ class OrpoTokenizationTest:
|
||||
|
||||
assert res["prompt_attention_mask"][0] == 1
|
||||
assert res["prompt_attention_mask"][-1] == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -321,48 +321,3 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
|
||||
class TestOptimizerValidation(BaseValidation):
|
||||
"""
|
||||
Test muon optimizer validation
|
||||
"""
|
||||
|
||||
def test_muon_deepspeed(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"optimizer": "muon",
|
||||
"deepspeed": "deepspeed_configs/zero3.json",
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_muon_fsdp(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"optimizer": "muon",
|
||||
"fsdp": ["full_shard"],
|
||||
"fsdp_config": {
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
|
||||
validate_config(cfg)
|
||||
|
||||