Compare commits

..

22 Commits

Author SHA1 Message Date
Wing Lian
b708a1cc45 validate config to set defaults 2025-04-26 13:11:25 -04:00
Rahul Tuli
daa9a58f83 Add: line about further optimizations using llmcompressor
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
2025-04-24 14:06:25 -04:00
Rahul Tuli
ae7069e15b Merge branch 'main' into llmcompressor-sft 2025-04-24 12:37:14 -05:00
Rahul Tuli
20d48cd617 Address Review Comments:
* deleted redundant docs/llm_compressor.qmd
* incorporated feedback in integration README.md
* added llmcompressor integration to docs/custom_integrations.qmd

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
2025-04-24 13:36:09 -04:00
Rahul Tuli
e766a730ba Add: .qmd file 2025-04-24 12:45:57 -04:00
Rahul Tuli
7dc797860e Tests, Style, Updates 2025-04-24 12:45:57 -04:00
Rahul Tuli
ff4904c8c4 Rebase and updates! 2025-04-24 12:45:57 -04:00
Rahul Tuli
45b7293793 Add: llm_compressor integration documentation 2025-04-24 12:45:57 -04:00
Rahul Tuli
279c7178bc Move: LLMCompressorPlugin into it's own submodule 2025-04-24 12:45:57 -04:00
Rahul Tuli
e73c3709f9 Update model config 2025-04-24 12:45:57 -04:00
Rahul Tuli
33562189f8 Use: absolute import 2025-04-24 12:45:57 -04:00
Rahul Tuli
c057a2268f Rename: sft.yaml to sparse-finetuning.yaml 2025-04-24 12:45:57 -04:00
Rahul Tuli
9d7a3809b5 Add: llcompressor installable 2025-04-24 12:45:57 -04:00
Rahul Tuli
b7b24d6a64 Address review comments from @markurtz 2025-04-24 12:45:57 -04:00
Rahul Tuli
8b82b8f7a1 Apply suggestions from @markurtz
Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
2025-04-24 12:45:57 -04:00
Rahul Tuli
81da58c0a1 Update llmcompressor version to latest 2025-04-24 12:45:57 -04:00
Rahul Tuli
2cd5a234a7 Revert: TODO's 2025-04-24 12:45:57 -04:00
Rahul Tuli
8c1af0747d Use: warning over warn 2025-04-24 12:45:57 -04:00
Rahul Tuli
a06b360d99 pre commit hooks 2025-04-24 12:45:57 -04:00
Rahul Tuli
0f6456a14f Add:llmcompressor instalable 2025-04-24 12:45:57 -04:00
Rahul Tuli
47a333ce49 Update: review comments! 2025-04-24 12:45:57 -04:00
Rahul Tuli
f9d6776c28 Add: SFTPlugin with llmcompressor 2025-04-24 12:45:57 -04:00
42 changed files with 790 additions and 2250 deletions

View File

@@ -8,7 +8,6 @@ on:
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

161
.runpod/.gitignore vendored
View File

@@ -1,161 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
pod/scripts/config.yaml

View File

@@ -1,18 +0,0 @@
FROM runpod/pytorch:3.10-2.0.0-117
COPY .runpod/requirements.txt /requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade pip && \
python3 -m pip install --upgrade -r /requirements.txt
# Environment settings
ARG BASE_VOLUME="/runpod-volume"
ENV BASE_VOLUME=$BASE_VOLUME
ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
COPY .runpod/src /src
CMD ["python3", "/src/handler.py"]

View File

@@ -1,335 +0,0 @@
<h1>LLM Post Training- Full fine-tune, LoRA, QLoRa etc. Llama/Mistral/Gemma and more</h1>
# Configuration Options
This document outlines all available configuration options for training models. The configuration can be provided as a JSON request.
## Usage
You can use these configuration Options:
1. As a JSON request body:
```json
{
"input": {
"user_id": "user",
"model_id": "model-name",
"run_id": "run-id",
"credentials": {
"wandb_api_key": "", # add your Weights & biases key. TODO: you will be able to set this in Enviornment variables.
"hf_token": "", # add your HF_token. TODO: you will be able to set this in Enviornment variables.
},
"args": {
"base_model": "NousResearch/Llama-3.2-1B",
// ... other options
}
}
}
```
## Configuration Options
### Model Configuration
| Option | Description | Default |
| ------------------- | --------------------------------------------------------------------------------------------- | -------------------- |
| `base_model` | Path to the base model (local or HuggingFace) | Required |
| `base_model_config` | Configuration path for the base model | Same as base_model |
| `revision_of_model` | Specific model revision from HuggingFace hub | Latest |
| `tokenizer_config` | Custom tokenizer configuration path | Optional |
| `model_type` | Type of model to load | AutoModelForCausalLM |
| `tokenizer_type` | Type of tokenizer to use | AutoTokenizer |
| `hub_model_id` | Repository ID where the model will be pushed on Hugging Face Hub (format: username/repo-name) | Optional |
## Model Family Identification
| Option | Default | Description |
| -------------------------- | ------- | ------------------------------ |
| `is_falcon_derived_model` | `false` | Whether model is Falcon-based |
| `is_llama_derived_model` | `false` | Whether model is LLaMA-based |
| `is_qwen_derived_model` | `false` | Whether model is Qwen-based |
| `is_mistral_derived_model` | `false` | Whether model is Mistral-based |
## Model Configuration Overrides
| Option | Default | Description |
| ----------------------------------------------- | ---------- | ---------------------------------- |
| `overrides_of_model_config.rope_scaling.type` | `"linear"` | RoPE scaling type (linear/dynamic) |
| `overrides_of_model_config.rope_scaling.factor` | `1.0` | RoPE scaling factor |
### Model Loading Options
| Option | Description | Default |
| -------------- | ----------------------------- | ------- |
| `load_in_8bit` | Load model in 8-bit precision | false |
| `load_in_4bit` | Load model in 4-bit precision | false |
| `bf16` | Use bfloat16 precision | false |
| `fp16` | Use float16 precision | false |
| `tf32` | Use tensor float 32 precision | false |
## Memory and Device Settings
| Option | Default | Description |
| ------------------ | --------- | ----------------------- |
| `gpu_memory_limit` | `"20GiB"` | GPU memory limit |
| `lora_on_cpu` | `false` | Load LoRA on CPU |
| `device_map` | `"auto"` | Device mapping strategy |
| `max_memory` | `null` | Max memory per device |
## Training Hyperparameters
| Option | Default | Description |
| ----------------------------- | --------- | --------------------------- |
| `gradient_accumulation_steps` | `1` | Gradient accumulation steps |
| `micro_batch_size` | `2` | Batch size per GPU |
| `eval_batch_size` | `null` | Evaluation batch size |
| `num_epochs` | `4` | Number of training epochs |
| `warmup_steps` | `100` | Warmup steps |
| `warmup_ratio` | `0.05` | Warmup ratio |
| `learning_rate` | `0.00003` | Learning rate |
| `lr_quadratic_warmup` | `false` | Quadratic warmup |
| `logging_steps` | `null` | Logging frequency |
| `eval_steps` | `null` | Evaluation frequency |
| `evals_per_epoch` | `null` | Evaluations per epoch |
| `save_strategy` | `"epoch"` | Checkpoint saving strategy |
| `save_steps` | `null` | Saving frequency |
| `saves_per_epoch` | `null` | Saves per epoch |
| `save_total_limit` | `null` | Maximum checkpoints to keep |
| `max_steps` | `null` | Maximum training steps |
### Dataset Configuration
```yaml
datasets:
- path: vicgalle/alpaca-gpt4 # HuggingFace dataset or TODO: You will be able to add the local path.
type: alpaca # Format type (alpaca, gpteacher, oasst, etc.)
ds_type: json # Dataset type
data_files: path/to/data # Source data files
train_on_split: train # Dataset split to use
```
## Chat Template Settings
| Option | Default | Description |
| ------------------------ | -------------------------------- | ---------------------- |
| `chat_template` | `"tokenizer_default"` | Chat template type |
| `chat_template_jinja` | `null` | Custom Jinja template |
| `default_system_message` | `"You are a helpful assistant."` | Default system message |
## Dataset Processing
| Option | Default | Description |
| ----------------------------- | -------------------------- | --------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `dataset_exact_deduplication` | `true` | Deduplicate datasets |
## LoRA Configuration
| Option | Default | Description |
| -------------------------- | ---------------------- | ------------------------------ |
| `adapter` | `"lora"` | Adapter type (lora/qlora) |
| `lora_model_dir` | `""` | Directory with pretrained LoRA |
| `lora_r` | `8` | LoRA attention dimension |
| `lora_alpha` | `16` | LoRA alpha parameter |
| `lora_dropout` | `0.05` | LoRA dropout |
| `lora_target_modules` | `["q_proj", "v_proj"]` | Modules to apply LoRA |
| `lora_target_linear` | `false` | Target all linear modules |
| `peft_layers_to_transform` | `[]` | Layers to transform |
| `lora_modules_to_save` | `[]` | Modules to save |
| `lora_fan_in_fan_out` | `false` | Fan in/out structure |
## Optimization Settings
| Option | Default | Description |
| ------------------------- | ------- | -------------------------- |
| `train_on_inputs` | `false` | Train on input prompts |
| `group_by_length` | `false` | Group by sequence length |
| `gradient_checkpointing` | `false` | Use gradient checkpointing |
| `early_stopping_patience` | `3` | Early stopping patience |
## Learning Rate Scheduling
| Option | Default | Description |
| -------------------------- | ---------- | -------------------- |
| `lr_scheduler` | `"cosine"` | Scheduler type |
| `lr_scheduler_kwargs` | `{}` | Scheduler parameters |
| `cosine_min_lr_ratio` | `null` | Minimum LR ratio |
| `cosine_constant_lr_ratio` | `null` | Constant LR ratio |
| `lr_div_factor` | `null` | LR division factor |
## Optimizer Settings
| Option | Default | Description |
| ---------------------- | ------------ | ------------------- |
| `optimizer` | `"adamw_hf"` | Optimizer choice |
| `optim_args` | `{}` | Optimizer arguments |
| `optim_target_modules` | `[]` | Target modules |
| `weight_decay` | `null` | Weight decay |
| `adam_beta1` | `null` | Adam beta1 |
| `adam_beta2` | `null` | Adam beta2 |
| `adam_epsilon` | `null` | Adam epsilon |
| `max_grad_norm` | `null` | Gradient clipping |
## Attention Implementations
| Option | Default | Description |
| -------------------------- | ------- | ----------------------------- |
| `flash_optimum` | `false` | Use better transformers |
| `xformers_attention` | `false` | Use xformers |
| `flash_attention` | `false` | Use flash attention |
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
| `sdp_attention` | `false` | Use scaled dot product |
| `s2_attention` | `false` | Use shifted sparse attention |
## Tokenizer Modifications
| Option | Default | Description |
| ---------------- | ------- | ---------------------------- |
| `special_tokens` | - | Special tokens to add/modify |
| `tokens` | `[]` | Additional tokens |
## Distributed Training
| Option | Default | Description |
| ----------------------- | ------- | --------------------- |
| `fsdp` | `null` | FSDP configuration |
| `fsdp_config` | `null` | FSDP config options |
| `deepspeed` | `null` | Deepspeed config path |
| `ddp_timeout` | `null` | DDP timeout |
| `ddp_bucket_cap_mb` | `null` | DDP bucket capacity |
| `ddp_broadcast_buffers` | `null` | DDP broadcast buffers |
<details>
<summary><h3>Example Configuration Request:</h3></summary>
Here's a complete example for fine-tuning a LLaMA model using LoRA:
```json
{
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "test-run",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "NousResearch/Llama-3.2-1B",
"load_in_8bit": false,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca"
}
],
"dataset_prepared_path": "last_run_prepared",
"val_set_size": 0.1,
"output_dir": "./outputs/lora-out",
"adapter": "lora",
"sequence_len": 2048,
"sample_packing": true,
"eval_sample_packing": true,
"pad_to_sequence_len": true,
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_modules": [
"gate_proj",
"down_proj",
"up_proj",
"q_proj",
"v_proj",
"k_proj",
"o_proj"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": false,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"loss_watchdog_threshold": 5,
"loss_watchdog_patience": 3,
"warmup_steps": 10,
"evals_per_epoch": 4,
"saves_per_epoch": 1,
"weight_decay": 0,
"hub_model_id": "runpod/llama-fr-lora",
"wandb_name": "test-run-1",
"wandb_project": "test-run-1",
"wandb_entity": "axo-test",
"special_tokens": {
"pad_token": "<|end_of_text|>"
}
}
}
}
```
</details>
### Advanced Features
#### Wandb Integration
- `wandb_project`: Project name for Weights & Biases
- `wandb_entity`: Team name in W&B
- `wandb_watch`: Monitor model with W&B
- `wandb_name`: Name of the W&B run
- `wandb_run_id`: ID for the W&B run
#### Performance Optimization
- `sample_packing`: Enable efficient sequence packing
- `eval_sample_packing`: Use sequence packing during evaluation
- `torch_compile`: Enable PyTorch 2.0 compilation
- `flash_attention`: Use Flash Attention implementation
- `xformers_attention`: Use xFormers attention implementation
### Available Optimizers
The following optimizers are supported:
- `adamw_hf`: HuggingFace's AdamW implementation
- `adamw_torch`: PyTorch's AdamW
- `adamw_torch_fused`: Fused AdamW implementation
- `adamw_torch_xla`: XLA-optimized AdamW
- `adamw_apex_fused`: NVIDIA Apex fused AdamW
- `adafactor`: Adafactor optimizer
- `adamw_anyprecision`: Anyprecision AdamW
- `adamw_bnb_8bit`: 8-bit AdamW from bitsandbytes
- `lion_8bit`: 8-bit Lion optimizer
- `lion_32bit`: 32-bit Lion optimizer
- `sgd`: Stochastic Gradient Descent
- `adagrad`: Adagrad optimizer
## Notes
- Set `load_in_8bit: true` or `load_in_4bit: true` for memory-efficient training
- Enable `flash_attention: true` for faster training on modern GPUs
- Use `gradient_checkpointing: true` to reduce memory usage
- Adjust `micro_batch_size` and `gradient_accumulation_steps` based on your GPU memory
For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html).
### Errors:
- if you face any issues with the Flash Attention-2, Delete yoor worker and Re-start.

View File

@@ -1,93 +0,0 @@
{
"title": "Axolotl Fine-Tuning",
"description": "Serverless fine-tuning of open-source LLMs with Axolotl. Supports LoRA, QLoRA, DPO, and more using Hugging Face models and datasets.",
"type": "serverless",
"category": "language",
"iconUrl": "https://avatars.githubusercontent.com/u/167502477",
"config": {
"runsOn": "GPU",
"containerDiskInGb": 200,
"gpuCount": 1,
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
],
"presets": [],
"env": [
{
"key": "TOKENIZER",
"input": {
"name": "Tokenizer",
"type": "string",
"description": "Name or path of the Hugging Face tokenizer to use.",
"default": "",
"advanced": true
}
},
{
"key": "MAX_NUM_SEQS",
"input": {
"name": "Max Num Seqs",
"type": "number",
"description": "Maximum number of sequences per iteration.",
"default": 256,
"advanced": true
}
},
{
"key": "DISABLE_LOG_STATS",
"input": {
"name": "Disable Log Stats",
"type": "boolean",
"description": "Disable logging statistics.",
"default": false,
"trueValue": "true",
"falseValue": "false"
}
},
{
"key": "LOAD_FORMAT",
"input": {
"name": "Load Format",
"type": "string",
"description": "The format of the model weights to load.",
"default": "auto",
"options": [
{
"label": "auto",
"value": "auto"
},
{
"label": "pt",
"value": "pt"
},
{
"label": "safetensors",
"value": "safetensors"
},
{
"label": "npcache",
"value": "npcache"
},
{
"label": "dummy",
"value": "dummy"
},
{
"label": "tensorizer",
"value": "tensorizer"
},
{
"label": "bitsandbytes",
"value": "bitsandbytes"
}
],
"advanced": true
}
}
]
}
}

View File

@@ -1,15 +0,0 @@
# Required Python packages get listed here, one per line.
# Reccomended to lock the version number to avoid unexpected changes.
# You can also install packages from a git repository, e.g.:
# git+https://github.com/runpod/runpod-python.git
# To learn more, see https://pip.pypa.io/en/stable/reference/requirements-file-format/
runpod~=1.7.0
huggingface_hub
typing-extensions
pydantic
pydantic-settings
hf-transfer
setuptools
numpy==2.0.0
axolotl[flash-attn,deepspeed]

View File

@@ -1,577 +0,0 @@
# # This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# # This can also be a relative path to a model on disk
# base_model: ./llama-7b-hf
# # You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
# base_model_ignore_patterns:
# # If the base_model repo on hf hub doesn't include configuration .json files,
# # You can set that here, or leave this empty to default to base_model
# base_model_config: ./llama-7b-hf
# # You can specify to choose a specific model revision from huggingface hub
# model_revision:
# # Optional tokenizer configuration override in case you want to use a different tokenizer
# # than the one defined in the base model
# tokenizer_config:
# # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
# model_type: AutoModelForCausalLM
# # Corresponding tokenizer for the model AutoTokenizer is a good choice
# tokenizer_type: AutoTokenizer
# # Trust remote code for untrusted source
# trust_remote_code:
# # use_fast option for tokenizer loading from_pretrained, default to True
# tokenizer_use_fast:
# # Whether to use the legacy tokenizer setting, defaults to True
# tokenizer_legacy:
# # Resize the model embeddings when new tokens are added to multiples of 32
# # This is reported to improve training speed on some models
# resize_token_embeddings_to_32x:
# # Used to identify which the model is based on
# is_falcon_derived_model:
# is_llama_derived_model:
# # Please note that if you set this to true, `padding_side` will be set to "left" by default
# is_mistral_derived_model:
# is_qwen_derived_model:
# # optional overrides to the base model configuration
# model_config:
# # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
# rope_scaling:
# type: # linear | dynamic
# factor: # float
# # Whether you are training a 4-bit GPTQ quantized model
# gptq: true
# gptq_groupsize: 128 # group size
# gptq_model_v1: false # v1 or v2
# # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
# load_in_8bit: true
# # Use bitsandbytes 4 bit
# load_in_4bit:
# # Use CUDA bf16
# bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
# # Use CUDA fp16
# fp16: true
# # Use CUDA tf32
# tf32: true # require >=ampere
# # No AMP (automatic mixed precision)
# bfloat16: true # require >=ampere
# float16: true
# # A list of one or more datasets to finetune the model with
# datasets:
# # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
# - path: vicgalle/alpaca-gpt4
# # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
# type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
# ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
# data_files: # Optional[str] path to source data files
# shards: # Optional[int] number of shards to split data into
# name: # Optional[str] name of dataset configuration to load
# train_on_split: train # Optional[str] name of dataset split to load from
# # Optional[str] fastchat conversation type, only used with type: sharegpt
# conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# field_human: # Optional[str]. Human key to use for conversation.
# field_model: # Optional[str]. Assistant key to use for conversation.
# # Custom user prompt
# - path: repo
# type:
# # The below are defaults. only set what's needed.
# system_prompt: ""
# system_format: "{system}"
# field_system: system
# field_instruction: instruction
# field_input: input
# field_output: output
# # Customizable to be single line or multi-line
# # 'format' can include {input}
# format: |-
# User: {instruction} {input}
# Assistant:
# # 'no_input_format' cannot include {input}
# no_input_format: "{instruction} "
# # For `completion` datsets only, uses the provided field instead of `text` column
# field:
# # Axolotl attempts to save the dataset as an arrow after packing the data together so
# # subsequent training attempts load faster, relative path
# dataset_prepared_path: data/last_run_prepared
# # Push prepared dataset to hub
# push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set.
# dataset_processes: # defaults to os.cpu_count() if not set
# # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub
# # https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
# hub_strategy:
# # Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# # Required to be true when used in combination with `push_dataset_to_hub`
# hf_use_auth_token: # boolean
# # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
# val_set_size: 0.04
# # Num shards for whole dataset
# dataset_shard_num:
# # Index of shard to use for whole dataset
# dataset_shard_idx:
# # The maximum length of an input to train with, this should typically be less than 2048
# # as most models have a token/context limit of 2048
# sequence_len: 2048
# # Pad inputs so each step uses constant sized buffers
# # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
# pad_to_sequence_len:
# # Max sequence length to concatenate training samples together up to
# # Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# # FutureWarning: This will soon be DEPRECATED
# max_packed_sequence_len: 1024
# # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
# sample_packing:
# # Set to 'false' if getting errors during eval with sample_packing on.
# eval_sample_packing:
# # You can set these packing optimizations AFTER starting a training at least once.
# # The trainer will provide recommended values for these values.
# sample_packing_eff_est:
# total_num_tokens:
# # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
# adapter: lora
# # If you already have a lora model trained that you want to load, put that here.
# # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
# lora_model_dir:
# # LoRA hyperparameters
# # For more details about the following options, see:
# # https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
# lora_r: 8
# lora_alpha: 16
# lora_dropout: 0.05
# lora_target_modules:
# - q_proj
# - v_proj
# # - k_proj
# # - o_proj
# # - gate_proj
# # - down_proj
# # - up_proj
# lora_target_linear: # If true, will target all linear layers
# # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# # `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# # https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
# lora_modules_to_save:
# # - embed_tokens
# # - lm_head
# # Once you complete training, the model will be saved to the following directory.
# # If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# # Make sure `lora_model_dir` points to this directory if you want to use the trained model.
# lora_out_dir:
# lora_fan_in_fan_out: false
# # ReLoRA configuration
# # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
# relora_steps: # Number of steps per ReLoRA restart
# relora_warmup_steps: # Number of per-restart warmup steps
# relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# # wandb configuration if you're using it
# wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
# wandb_project: # Your wandb project name
# wandb_entity: # A wandb Team name if using a Team
# wandb_watch:
# wandb_run_id: # Set the name of your wandb run
# wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# # Where to save the full-finetuned model to
# output_dir: ./completed-model
# # Whether to use torch.compile and which backend to use
# torch_compile: # bool
# torch_compile_backend: # Optional[str]
# # Training hyperparameters
# # If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
# gradient_accumulation_steps: 1
# # The number of samples to include in each batch. This is the number of samples sent to each GPU.
# micro_batch_size: 2
# eval_batch_size:
# num_epochs: 4
# warmup_steps: 100 # cannot use with warmup_ratio
# warmup_ratio: 0.05 # cannot use with warmup_steps
# learning_rate: 0.00003
# lr_quadratic_warmup:
# logging_steps:
# save_strategy: # Set to `no` to skip checkpoint saves
# save_steps: # Leave empty to save at each epoch
# eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
# save_total_limit: # Checkpoints saved at a time
# # Maximum number of iterations to train for. It precedes num_epochs which means that
# # if both are set, num_epochs will not be guaranteed.
# # e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
# max_steps:
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# # Save model as safetensors (require safetensors package)
# save_safetensors:
# # Whether to mask out or include the human's prompt from the training labels
# train_on_inputs: false
# # Group similarly sized data to minimize padding.
# # May be slower to start, as it must download and sort the entire dataset.
# # Note that training loss may have an oscillating pattern with this enabled.
# group_by_length: false
# # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
# gradient_checkpointing: false
# # Stop training after this many evaluation losses have increased in a row
# # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
# early_stopping_patience: 3
# # Specify a scheduler and kwargs to use with the optimizer
# lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
# lr_scheduler_kwargs:
# # For one_cycle optim
# lr_div_factor: # Learning rate div factor
# # For log_sweep optim
# log_sweep_min_lr:
# log_sweep_max_lr:
# # Specify optimizer
# # Valid values are driven by the Transformers OptimizerNames class, see:
# # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
# #
# # Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# # torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# # in the examples/ for your model and fine-tuning use case.
# #
# # Valid values for 'optimizer' include:
# # - adamw_hf
# # - adamw_torch
# # - adamw_torch_fused
# # - adamw_torch_xla
# # - adamw_apex_fused
# # - adafactor
# # - adamw_anyprecision
# # - sgd
# # - adagrad
# # - adamw_bnb_8bit
# # - lion_8bit
# # - lion_32bit
# # - paged_adamw_32bit
# # - paged_adamw_8bit
# # - paged_lion_32bit
# # - paged_lion_8bit
# optimizer:
# # Specify weight decay
# weight_decay:
# # adamw hyperparams
# adam_beta1:
# adam_beta2:
# adam_epsilon:
# # Gradient clipping max norm
# max_grad_norm:
# # Augmentation techniques
# # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# # currently only supported on Llama and Mistral
# noisy_embedding_alpha:
# # Whether to bettertransformers
# flash_optimum:
# # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
# xformers_attention:
# # Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
# flash_attention:
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# # Whether to use scaled-dot-product attention
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# sdp_attention:
# # Landmark attention (only llama)
# landmark_attention:
# # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# # LLaMA only
# xpos_rope:
# # Resume from a specific checkpoint dir
# resume_from_checkpoint:
# # If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# # Be careful with this being turned on between different models.
# auto_resume_from_checkpoints: false
# # Don't mess with this, it's here for accelerate and torchrun
# local_rank:
# # Add or change special tokens.
# # If you add tokens here, you don't need to add them to the `tokens` list.
# special_tokens:
# # bos_token: "<s>"
# # eos_token: "</s>"
# # unk_token: "<unk>"
# # Add extra tokens.
# tokens:
# # FSDP
# fsdp:
# fsdp_config:
# # Deepspeed config path. e.g., deepspeed/zero3.json
# deepspeed:
# # Advanced DDP Arguments
# ddp_timeout:
# ddp_bucket_cap_mb:
# ddp_broadcast_buffers:
# # Path to torch distx for optim 'adamw_anyprecision'
# torchdistx_path:
# # Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
# pretraining_dataset:
# # Debug mode
# debug:
# # Seed
# seed:
# # Allow overwrite yml config using from cli
# strict:
base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG}
revision_of_model: ${REVISION_OF_MODEL}
tokenizer_config: ${TOKENIZER_CONFIG}
model_type: ${MODEL_TYPE}
tokenizer_type: ${TOKENIZER_TYPE}
trust_remote_code: ${TRUST_REMOTE_CODE}
tokenizer_use_fast: ${TOKENIZER_USE_FAST}
tokenizer_legacy: ${TOKENIZER_LEGACY}
resize_token_embeddings_to_32x: ${RESIZE_TOKEN_EMBEDDINGS_TO_32X}
is_falcon_derived_model: ${IS_FALCON_DERIVED_MODEL}
is_llama_derived_model: ${IS_LLAMA_DERIVED_MODEL}
is_qwen_derived_model: ${IS_QWEN_DERIVED_MODEL}
is_mistral_derived_model: ${IS_MISTRAL_DERIVED_MODEL}
overrides_of_model_config:
rope_scaling:
type: ${ROPE_SCALING_TYPE}
factor: ${ROPE_SCALING_FACTOR}
bnb_config_kwargs:
llm_int8_has_fp16_weight: ${BNB_LLM_INT8_HAS_FP16_WEIGHT}
bnb_4bit_quant_type: ${BNB_4BIT_QUANT_TYPE}
bnb_4bit_use_double_quant: ${BNB_4BIT_USE_DOUBLE_QUANT}
gptq: ${GPTQ}
load_in_8bit: ${LOAD_IN_8BIT}
load_in_4bit: ${LOAD_IN_4BIT}
bf16: ${BF16}
fp16: ${FP16}
tf32: ${TF32}
bfloat16: ${BFLOAT16}
float16: ${FLOAT16}
gpu_memory_limit: ${GPU_MEMORY_LIMIT}
lora_on_cpu: ${LORA_ON_CPU}
datasets:
- path: ${DATASET_PATH}
type: ${DATASET_TYPE}
ds_type: ${DATASET_DS_TYPE}
data_files: ${DATASET_DATA_FILES}
shards: ${DATASET_SHARDS}
name: ${DATASET_NAME}
train_on_split: ${DATASET_TRAIN_ON_SPLIT}
revision: ${DATASET_REVISION}
trust_remote_code: ${DATASET_TRUST_REMOTE_CODE}
rl: ${RL}
dpo_use_weighting: ${DPO_USE_WEIGHTING}
chat_template: ${CHAT_TEMPLATE}
chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_processes: ${DATASET_PROCESSES}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY}
hf_use_auth_token: ${HF_USE_AUTH_TOKEN}
val_set_size: ${VAL_SET_SIZE}
dataset_shard_num: ${DATASET_SHARD_NUM}
dataset_shard_idx: ${DATASET_SHARD_IDX}
sequence_len: ${SEQUENCE_LEN}
pad_to_sequence_len: ${PAD_TO_SEQUENCE_LEN}
sample_packing: ${SAMPLE_PACKING}
eval_sample_packing: ${EVAL_SAMPLE_PACKING}
sample_packing_eff_est: ${SAMPLE_PACKING_EFF_EST}
total_num_tokens: ${TOTAL_NUM_TOKENS}
sample_packing_group_size: ${SAMPLE_PACKING_GROUP_SIZE}
sample_packing_bin_size: ${SAMPLE_PACKING_BIN_SIZE}
batch_flattening: ${BATCH_FLATTENING}
device_map: ${DEVICE_MAP}
max_memory: ${MAX_MEMORY}
adapter: ${ADAPTER}
lora_model_dir: ${LORA_MODEL_DIR}
lora_r: ${LORA_R}
lora_alpha: ${LORA_ALPHA}
lora_dropout: ${LORA_DROPOUT}
lora_target_modules:
- ${LORA_TARGET_MODULES}
lora_target_linear: ${LORA_TARGET_LINEAR}
peft_layers_to_transform: ${PEFT_LAYERS_TO_TRANSFORM}
lora_modules_to_save: ${LORA_MODULES_TO_SAVE}
lora_fan_in_fan_out: ${LORA_FAN_IN_FAN_OUT}
loraplus_lr_ratio: ${LORAPLUS_LR_RATIO}
loraplus_lr_embedding: ${LORAPLUS_LR_EMBEDDING}
peft:
loftq_config:
loftq_bits: ${LOFTQ_BITS}
relora_steps: ${RELORA_STEPS}
relora_warmup_steps: ${RELORA_WARMUP_STEPS}
relora_anneal_steps: ${RELORA_ANNEAL_STEPS}
relora_prune_ratio: ${RELORA_PRUNE_RATIO}
relora_cpu_offload: ${RELORA_CPU_OFFLOAD}
wandb_mode: ${WANDB_MODE}
wandb_project: ${WANDB_PROJECT}
wandb_entity: ${WANDB_ENTITY}
wandb_watch: ${WANDB_WATCH}
wandb_name: ${WANDB_NAME}
wandb_run_id: ${WANDB_RUN_ID}
wandb_log_model: ${WANDB_LOG_MODEL}
mlflow_tracking_uri: ${MLFLOW_TRACKING_URI}
mlflow_experiment_name: ${MLFLOW_EXPERIMENT_NAME}
mlflow_run_name: ${MLFLOW_RUN_NAME}
hf_mlflow_log_artifacts: ${HF_MLFLOW_LOG_ARTIFACTS}
use_comet: ${USE_COMET}
comet_api_key: ${COMET_API_KEY}
comet_workspace: ${COMET_WORKSPACE}
comet_project_name: ${COMET_PROJECT_NAME}
comet_experiment_key: ${COMET_EXPERIMENT_KEY}
comet_mode: ${COMET_MODE}
comet_online: ${COMET_ONLINE}
comet_experiment_config: ${COMET_EXPERIMENT_CONFIG}
output_dir: ${OUTPUT_DIR}
torch_compile: ${TORCH_COMPILE}
torch_compile_backend: ${TORCH_COMPILE_BACKEND}
gradient_accumulation_steps: ${GRADIENT_ACCUMULATION_STEPS}
micro_batch_size: ${MICRO_BATCH_SIZE}
eval_batch_size: ${EVAL_BATCH_SIZE}
num_epochs: ${NUM_EPOCHS}
warmup_steps: ${WARMUP_STEPS}
warmup_ratio: ${WARMUP_RATIO}
learning_rate: ${LEARNING_RATE}
lr_quadratic_warmup: ${LR_QUADRATIC_WARMUP}
logging_steps: ${LOGGING_STEPS}
eval_steps: ${EVAL_STEPS}
evals_per_epoch: ${EVALS_PER_EPOCH}
save_strategy: ${SAVE_STRATEGY}
save_steps: ${SAVE_STEPS}
saves_per_epoch: ${SAVES_PER_EPOCH}
save_total_limit: ${SAVE_TOTAL_LIMIT}
max_steps: ${MAX_STEPS}
eval_table_size: ${EVAL_TABLE_SIZE}
eval_max_new_tokens: ${EVAL_MAX_NEW_TOKENS}
eval_causal_lm_metrics: ${EVAL_CAUSAL_LM_METRICS}
profiler_steps: ${PROFILER_STEPS}
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
save_safetensors: ${SAVE_SAFETENSORS}
train_on_inputs: ${TRAIN_ON_INPUTS}
group_by_length: ${GROUP_BY_LENGTH}
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}
early_stopping_patience: ${EARLY_STOPPING_PATIENCE}
lr_scheduler: ${LR_SCHEDULER}
lr_scheduler_kwargs: ${LR_SCHEDULER_KWARGS}
cosine_min_lr_ratio: ${COSINE_MIN_LR_RATIO}
cosine_constant_lr_ratio: ${COSINE_CONSTANT_LR_RATIO}
lr_div_factor: ${LR_DIV_FACTOR}
optimizer: ${OPTIMIZER}
optim_args: ${OPTIM_ARGS}
optim_target_modules: ${OPTIM_TARGET_MODULES}
weight_decay: ${WEIGHT_DECAY}
adam_beta1: ${ADAM_BETA1}
adam_beta2: ${ADAM_BETA2}
adam_epsilon: ${ADAM_EPSILON}
max_grad_norm: ${MAX_GRAD_NORM}
neftune_noise_alpha: ${NEFTUNE_NOISE_ALPHA}
flash_optimum: ${FLASH_OPTIMUM}
xformers_attention: ${XFORMERS_ATTENTION}
flash_attention: ${FLASH_ATTENTION}
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
sdp_attention: ${SDP_ATTENTION}
s2_attention: ${S2_ATTENTION}
resume_from_checkpoint: ${RESUME_FROM_CHECKPOINT}
auto_resume_from_checkpoints: ${AUTO_RESUME_FROM_CHECKPOINTS}
local_rank: ${LOCAL_RANK}
special_tokens:
bos_token: ${SPECIAL_TOKEN_BOS}
eos_token: ${SPECIAL_TOKEN_EOS}
unk_token: ${SPECIAL_TOKEN_UNK}
pad_token: ${SPECIAL_TOKEN_PAD}
tokens: ${TOKENS}
fsdp: ${FSDP}
fsdp_config: ${FSDP_CONFIG}
deepspeed: ${DEEPSPEED}
ddp_timeout: ${DDP_TIMEOUT}
ddp_bucket_cap_mb: ${DDP_BUCKET_CAP_MB}
ddp_broadcast_buffers: ${DDP_BROADCAST_BUFFERS}
torchdistx_path: ${TORCHDISTX_PATH}
pretraining_dataset: ${PRETRAINING_DATASET}
debug: ${DEBUG}
seed: ${SEED}
strict: ${STRICT}

View File

@@ -1,64 +0,0 @@
"""
Runpod serverless entrypoint handler
"""
import os
import runpod
import yaml
from huggingface_hub._login import login
from train import train
from utils import get_output_dir
BASE_VOLUME = os.environ.get("BASE_VOLUME", "/runpod-volume")
if not os.path.exists(BASE_VOLUME):
os.makedirs(BASE_VOLUME)
logger = runpod.RunPodLogger()
async def handler(job):
runpod_job_id = job["id"]
inputs = job["input"]
run_id = inputs.get("run_id", "default_run_id")
args = inputs.get("args", {})
# Set output directory
output_dir = os.path.join(BASE_VOLUME, get_output_dir(run_id))
args["output_dir"] = output_dir
# First save args to a temporary config file
config_path = "/workspace/test_config.yaml"
# Add run_name and job_id to args before saving
args["run_name"] = run_id
args["runpod_job_id"] = runpod_job_id
yaml_data = yaml.dump(args, default_flow_style=False)
with open(config_path, "w", encoding="utf-8") as file:
file.write(yaml_data)
# Handle credentials
credentials = inputs.get("credentials", {})
if "wandb_api_key" in credentials:
os.environ["WANDB_API_KEY"] = credentials["wandb_api_key"]
if "hf_token" in credentials:
os.environ["HF_TOKEN"] = credentials["hf_token"]
if os.environ.get("HF_TOKEN"):
login(token=os.environ["HF_TOKEN"])
else:
logger.info("No HF_TOKEN provided. Skipping login.")
logger.info("Starting Training.")
async for result in train(config_path): # Pass the config path instead of args
logger.info(result)
logger.info("Training Complete.")
# Cleanup
del os.environ["WANDB_API_KEY"]
del os.environ["HF_TOKEN"]
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})

View File

@@ -1,61 +0,0 @@
{
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "NousResearch/Meta-Llama-3-8B",
"model_type": "LlamaForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_8bit": true,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca"
}
],
"val_set_size": 0.05,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 4,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": false,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|end_of_text|>"
}
}
}
}

View File

@@ -1,45 +0,0 @@
"""
Runpod train entrypoint
"""
import asyncio
async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
"""
Run preprocessing (if enabled) and training with the given config file
:param config_path: Path to the YAML config file
:param gpu_id: GPU ID to use (default: "0")
:param preprocess: Whether to run preprocessing (default: True)
"""
# First check if preprocessing is needed
if preprocess:
# Preprocess command
preprocess_cmd = (
f"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}"
)
process = await asyncio.create_subprocess_shell(
preprocess_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
if process.stdout is not None:
async for line in process.stdout:
yield f"Preprocessing: {line.decode().strip()}"
await process.wait()
yield "Preprocessing completed."
else:
yield "Skipping preprocessing step."
# Training command
train_cmd = f"axolotl train {config_path}"
process = await asyncio.create_subprocess_shell(
train_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
)
if process.stdout is not None:
async for line in process.stdout:
yield f"Training: {line.decode().strip()}"
await process.wait()

View File

@@ -1,89 +0,0 @@
"""
Runpod launcher utils
"""
import os
import yaml
def get_output_dir(run_id):
path = f"fine-tuning/{run_id}"
return path
def make_valid_config(input_args):
"""
Creates and saves updated config file, returns the path to the new config
:param input_args: dict of input args
:return: str, path to the updated config file
"""
# Load default config
with open("config/config.yaml", "r", encoding="utf-8") as fin:
all_args = yaml.safe_load(fin)
if not input_args:
print("No args provided, using defaults")
else:
all_args.update(input_args)
# Create updated config path
updated_config_path = "config/updated_config.yaml"
# Save updated config to new file
with open(updated_config_path, "w", encoding="utf-8") as f:
yaml.dump(all_args, f)
return updated_config_path
def set_config_env_vars(args: dict):
"""
Convert API arguments into environment variables.
Handles nested dictionaries, lists, and special values.
Args:
args (dict): The arguments dictionary from the API request
"""
def process_value(value):
"""Convert Python values to string format for environment variables"""
if value is None:
return ""
if isinstance(value, bool):
return str(value).lower()
if isinstance(value, (list, dict)):
return str(value)
return str(value)
def set_env_vars(data, prefix=""):
"""Recursively set environment variables from nested dictionary"""
for key, value in data.items():
env_key = prefix + key.upper()
# Handle special cases
if isinstance(value, dict):
# For nested dictionaries (like special_tokens)
set_env_vars(value, f"{env_key}_")
elif isinstance(value, list):
# Handle list of dictionaries (like datasets)
if value and isinstance(value[0], dict):
for i, item in enumerate(value):
set_env_vars(item, f"{env_key}_{i}_")
else:
# For simple lists (like lora_target_modules)
os.environ[env_key] = process_value(value)
else:
# Handle all other cases
os.environ[env_key] = process_value(value)
# Clear any existing related environment variables
# This prevents old values from persisting
for key in list(os.environ.keys()):
if key.startswith(
("BASE_MODEL", "MODEL_TYPE", "TOKENIZER_TYPE", "DATASET", "LORA_", "WANDB_")
):
del os.environ[key]
# Set new environment variables
set_env_vars(args)

View File

@@ -1,89 +0,0 @@
{
"tests": [
{
"name": "quick_smoke_test_sft",
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "NousResearch/Meta-Llama-3-8B",
"model_type": "LlamaForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_8bit": true,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca"
}
],
"val_set_size": 0.05,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 4,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": false,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|end_of_text|>"
}
}
},
"timeout": 100000
}
],
"config": {
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"gpuCount": 1,
"containerDiskInGb": 200,
"env": [
{
"key": "TOKENIZER",
"value": ""
},
{
"key": "DISABLE_LOG_STATS",
"value": "true"
}
],
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
]
}
}

View File

@@ -52,4 +52,4 @@ pytest -v --durations=10 \
--cov-append \
--cov-report=xml:e2e-coverage.xml
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}

View File

@@ -1,7 +1,5 @@
codecov:
require_ci_to_pass: yes
notify:
wait_for_ci: true
coverage:
precision: 2

View File

@@ -49,7 +49,8 @@ sections = [
("Knowledge Distillation (KD)", "kd"),
("Liger Kernels", "liger"),
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
("Spectrum", "spectrum")
("Spectrum", "spectrum"),
("LLMCompressor", "llm_compressor")
]
for section_name, folder_name in sections:

View File

@@ -28,8 +28,6 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.7.0`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
@@ -52,7 +50,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
main-latest
# nightly build
@@ -70,7 +68,6 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples:
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`

View File

@@ -0,0 +1,77 @@
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true

View File

@@ -10,6 +10,7 @@ plugins:
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
cut_cross_entropy: true
llama4_linearized_experts: true # needed with custom linearized experts model
load_in_4bit: true

View File

@@ -149,6 +149,9 @@ extras_require = {
"vllm": [
"vllm==0.7.2",
],
"llmcompressor": [
"llmcompressor==0.5.1",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(

View File

@@ -932,6 +932,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
kwargs["ring_attn_func"] = training_args.ring_attn_func
return collator(
*collator_args,
@@ -1048,9 +1051,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
@@ -1121,12 +1121,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args
def build(self, total_num_steps):

View File

@@ -371,15 +371,13 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
loss = super().compute_loss(
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return loss
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}

View File

@@ -6,4 +6,4 @@
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,86 +1,16 @@
"""
Module for Axolotl trainer sequence parallelism mixin and training context manager
"""
"""Module for Axolotl trainer sequence parallelism mixin"""
import functools
import logging
import torch
import torch.distributed as dist
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
update_ring_attn_params,
)
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
LOG = logging.getLogger(__name__)
def apply_sequence_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
ring_attn_func: RingAttnFunc,
) -> dict[str, torch.Tensor]:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
local_rank: Local rank in the sequence parallel group
local_world_size: World size of the sequence parallel group
ring_attn_func: The ring attention function to use
Returns:
Sliced batch dictionary.
"""
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
total_seq_len = batch["input_ids"].size(1)
for key in batch:
if (
key in batch
and isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
if ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
# Split in sequential fashion and grab this rank's chunk
batch[key] = (
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
)
elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[local_rank],
chunks[2 * local_world_size - local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, local_rank].contiguous()
return batch
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
@@ -157,157 +87,3 @@ class SequenceParallelMixin:
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
class SequenceParallelContextManager:
"""
Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook.
"""
def __init__(
self,
model: nn.Module,
sequence_parallel_degree: int,
ring_attn_func: RingAttnFunc,
):
self.model = model
self.sequence_parallel_degree = sequence_parallel_degree
self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group()
# Initialize sequence parallel group details
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
# Create a partially applied version of the apply_sequence_parallelism function
# with pre-configured params
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
ring_attn_func=self.ring_attn_func,
)
def __enter__(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs
kwargs = self.apply_sequence_parallelism(batch=kwargs)
return args, kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output):
# Gather the sharded outputs
return self.gather_outputs(output)
# Register both hooks
self.hook_handles.append(
self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
self.hook_handles.append(
self.model.register_forward_hook(sequence_parallel_post_hook)
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
def gather_outputs(self, output):
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
# Handle different output formats (dict, tensor, etc.)
if isinstance(output, dict):
gathered_output = {}
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
# Gather logits or other sequence-sharded tensors
gathered_value = self.gather_tensor(value)
gathered_output[key] = gathered_value
else:
gathered_value = value.clone()
dist.all_reduce(
gathered_value, op=dist.ReduceOp.SUM, group=self.process_group
)
gathered_output[key] = gathered_value
return gathered_output
if isinstance(output, torch.Tensor):
return self.gather_tensor(output)
return output
def gather_tensor(self, tensor):
"""Gather a sharded tensor from all ranks."""
# Prepare tensors for all_gather
world_size = self.local_world_size
# Create list to store tensors from all ranks
gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
# All-gather operation
dist.all_gather(gathered_tensors, tensor, group=self.process_group)
# Concatenate along sequence dimension (typically dim=1)
if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
# Simple concatenation for standard sharding
return torch.cat(gathered_tensors, dim=1)
if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
# Each rank has a pattern of (rank, world_size*2-rank-1)
reconstituted_tensors = [None] * (world_size * 2)
# First, split each gathered tensor into its two chunks
for rank, gathered_tensor in enumerate(gathered_tensors):
# Each tensor contains two chunks in the sequence dimension
chunk_size = gathered_tensor.size(1) // 2
chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1)
# Place chunks in their original positions
reconstituted_tensors[rank] = chunk1
reconstituted_tensors[world_size * 2 - rank - 1] = chunk2
# Concatenate the reconstituted tensors in the correct order
return torch.cat(reconstituted_tensors, dim=1)
# Otherwise, RingAttnFunc.BATCH_STRIPE
# In striping, each rank has every world_size-th slice
batch_size = tensor.size(0)
hidden_dim = tensor.size(-1)
# First, determine the full sequence length
total_seq_len = 0
for t in gathered_tensors:
total_seq_len += t.size(1)
# Create a tensor to hold the unstriped result
result = torch.zeros(
batch_size,
total_seq_len,
hidden_dim,
dtype=tensor.dtype,
device=tensor.device,
)
# For each rank's tensor, distribute its slices to the correct positions
for rank, gathered_tensor in enumerate(gathered_tensors):
# The rank's tensor contains every world_size-th slice
# starting from its rank position
seq_len = gathered_tensor.size(1)
for i in range(seq_len):
# Calculate the position in the full tensor
pos = i * world_size + rank
if pos < total_seq_len:
result[:, pos] = gathered_tensor[:, i]
return result

View File

@@ -27,6 +27,8 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
```
## Supported Models

View File

@@ -28,7 +28,7 @@ class CutCrossEntropyArgs(BaseModel):
Input args for Cut Cross Entropy.
"""
cut_cross_entropy: Optional[bool] = True
cut_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod

View File

@@ -0,0 +1,108 @@
# LLMCompressor Integration
Fine-tune sparsified models in Axolotl using Neural Magic's [LLMCompressor](https://github.com/vllm-project/llm-compressor).
This integration enables fine-tuning of models sparsified using LLMCompressor within the Axolotl training framework. By combining LLMCompressor's model compression capabilities with Axolotl's distributed training pipelines, users can efficiently fine-tune sparse models at scale.
It uses Axolotls plugin system to hook into the fine-tuning flows while maintaining sparsity throughout training.
---
## Requirements
- Axolotl with `llmcompressor` extras:
```bash
pip install "axolotl[llmcompressor]"
```
- Requires `llmcompressor >= 0.5.1`
This will install all necessary dependencies to fine-tune sparsified models using the integration.
---
## Usage
To enable sparse fine-tuning with this integration, include the plugin in your Axolotl config:
```yaml
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true
# ... (other training arguments)
```
This plugin **does not apply pruning or sparsification itself** — it is intended for **fine-tuning models that have already been sparsified**.
Pre-sparsified checkpoints can be:
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
- Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
- Any custom LLM with compatible sparsity patterns that you've created yourself
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
### Storage Optimization with save_compressed
Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
- Reduces disk space usage by approximately 40%
- Maintains compatibility with vLLM for accelerated inference
- Maintains compatibility with llmcompressor for further optimization (example: quantization)
This option is highly recommended when working with sparse models to maximize the benefits of model compression.
### Example Config
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
---
## Inference with vLLM
After fine-tuning your sparse model, you can leverage vLLM for efficient inference.
You can also use LLMCompressor to apply additional quantization to your fine-tuned
sparse model before inference for even greater performance benefits.:
```python
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM("path/to/your/sparse/model")
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
For more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).
## Learn More
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
[https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)

View File

@@ -0,0 +1,5 @@
"""Integration entry point for the LLMCompressor plugin."""
from .plugin import LLMCompressorPlugin
__all__ = ["LLMCompressorPlugin"]

View File

@@ -0,0 +1,40 @@
"""
LLMCompressor and Sparse Finetuning config models.
"""
from typing import Any
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class CompressionArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[
Any,
Field(
description="The recipe containing the compression algorithms and hyperparameters to apply."
),
]
save_compressed: Annotated[
bool,
Field(
default=False,
description="Whether to save the compressed model after training.",
),
]
class LLMCompressorArgs(BaseModel):
"""LLMCompressor configuration BaseModel."""
llmcompressor: Annotated[
CompressionArgs,
Field(
description="Arguments enabling compression pathways through the LLM Compressor plugins"
),
]

View File

@@ -0,0 +1,171 @@
"""
Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
by maintaining masks for zero weights during training.
"""
import logging
from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
from llmcompressor import active_session, create_session
from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe
from torch.nn import Module
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.integrations.base import BasePlugin
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
class LLMCompressorCallbackHandler(TrainerCallback):
"""
Trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps,
ensuring zero-weight updates are canceled out.
"""
def __init__(self, trainer: Trainer, recipe: Any):
"""
Initialize the Sparse Finetuning callback handler.
Args:
trainer (Trainer): Huggingface Trainer instance.
recipe (Recipe | dict): Sparse finetuning recipe to apply.
"""
super().__init__()
self.trainer = trainer
self.recipe = (
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
)
self.original_compute_loss = trainer.compute_loss
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
create_session()
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of training. Initializes the compression session.
Args:
args (TrainingArguments): Training arguments.
state (TrainerState): Trainer state.
control (TrainerControl): Trainer control.
"""
super().on_train_begin(args, state, control, **kwargs)
self.trainer.accelerator.wait_for_everyone()
active_session().initialize(
model=self.trainer.model,
optimizer=self.trainer.optimizer,
start=state.epoch,
recipe=self.recipe,
)
self.trainer.accelerator.wait_for_everyone()
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of a training step. Triggers batch_start callback.
"""
super().on_step_begin(args, state, control, **kwargs)
session_callbacks.batch_start()
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of a training step. Triggers optimizer and batch_end callbacks.
"""
super().on_step_end(args, state, control, **kwargs)
session_callbacks.optim_pre_step()
session_callbacks.optim_post_step()
session_callbacks.batch_end()
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of training. Finalizes the compression session.
"""
super().on_train_end(args, state, control, **kwargs)
active_session().finalize()
self.trainer.compute_loss_func = self.original_compute_loss
class LLMCompressorPlugin(BasePlugin):
"""
Sparse Finetuning plugin for Axolotl integration.
"""
def get_input_args(self) -> str:
"""
Returns the path to the plugin's argument definition.
Returns:
str: Dotted path to the LLMCompressorArgs class.
"""
return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds Sparse Finetuning callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
LOG.info("Adding Sparse Finetuning callback to the trainer")
callback = LLMCompressorCallbackHandler(
trainer=trainer,
recipe=cfg.llmcompressor.recipe,
)
return [callback]
def compute_loss_wrapper(
compute_loss_func: Callable[Concatenate[Module, P], R],
) -> Callable[Concatenate[Module, P], R]:
"""
Wraps the loss computation function to trigger the loss_calculated callback.
Args:
compute_loss_func (Callable): Original loss computation function.
Returns:
Callable: Wrapped function that also invokes the loss_calculated callback.
"""
@wraps(compute_loss_func)
def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R:
loss = compute_loss_func(model, *args, **kwargs)
if active_session().lifecycle.initialized_ and model.training:
session_callbacks.loss_calculated(loss=loss)
return loss
return compute_and_notify

View File

@@ -0,0 +1,40 @@
"""Utilities for llmcompressor integration with axolotl."""
from typing import Union
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
)
from transformers import PreTrainedModel, Trainer
def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
Synchronize processes, apply compression hooks, and save the model.
Args:
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
# Only the main process writes the files
if not trainer.accelerator.is_main_process:
return
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -6,7 +6,6 @@ import os
import signal
import sys
import weakref
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict
@@ -26,9 +25,6 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -189,28 +185,16 @@ def execute_training(
trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
# Define the context managers to use
flash_context = (
torch.backends.cuda.sdp_kernel(
LOG.info("Starting trainer...")
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
)
if cfg.flash_optimum
else nullcontext()
)
sequence_parallel_context = (
SequenceParallelContextManager(
model=trainer.model,
sequence_parallel_degree=cfg.sequence_parallel_degree,
ring_attn_func=cfg.ring_attn_func,
)
if cfg.sequence_parallel_degree > 1
else nullcontext()
)
LOG.info("Starting trainer...")
with flash_context, sequence_parallel_context:
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
@@ -287,6 +271,19 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
save_compressed_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
@@ -295,6 +292,7 @@ def save_trained_model(
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

View File

@@ -1,12 +1,20 @@
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
"""
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
includes logic for handling sequence parallelism collation.
"""
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
@dataclass
class DataCollatorForSeq2Seq:
@@ -41,6 +49,8 @@ class DataCollatorForSeq2Seq:
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
sequence_parallel_degree (`int`):
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
"""
tokenizer: PreTrainedTokenizerBase
@@ -51,6 +61,17 @@ class DataCollatorForSeq2Seq:
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
sequence_parallel_degree: int = 1
ring_attn_func: RingAttnFunc | None = None
def __post_init__(self):
if self.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
# Get information about our position in the SP group
sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=sp_group)
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
has_attn_mask = "attention_mask" in features[0].keys()
@@ -120,8 +141,62 @@ class DataCollatorForSeq2Seq:
)
features["decoder_input_ids"] = decoder_input_ids
if self.sequence_parallel_degree > 1:
features = self.apply_sequence_parallelism(features)
return features
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
) -> torch.Tensor:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary from parent collator.
Returns:
Sliced batch dictionary.
"""
# Get local (start, end) for sequence parallelism slicing
total_seq_len = batch["input_ids"].size(1)
# Update params for varlen ring attention calculation
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
for key in batch:
if batch[key].size(1) == total_seq_len:
if self.ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
batch[key] = (
batch[key]
.chunk(self.local_world_size, dim=1)[self.local_rank]
.contiguous()
)
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[self.local_rank],
chunks[2 * self.local_world_size - self.local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# TODO(djsaunde): This doesn't seem to work as expected
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(self.local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, self.local_rank].contiguous()
return batch
@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

View File

@@ -126,6 +126,9 @@ def normalize_config(cfg):
with open(ds_config_path, encoding="utf-8") as f:
cfg.deepspeed = json.load(f)
if cfg.sequence_parallel_degree is None:
cfg.sequence_parallel_degree = 1
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step

View File

@@ -134,9 +134,10 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
"csv", data_files=f.name, split="train", streaming=True
)
else:
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")

View File

@@ -1,7 +1,5 @@
"""custom checkpointing utils"""
from functools import partial
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
@@ -11,10 +9,6 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
decoder_layer.__self__,
*args,
)

View File

@@ -139,6 +139,22 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
hasattr(model_config, "quantization_config")
and model_config.quantization_config
)
# Detect compressed-tensors config
is_compressed_tensors_config = (
quant_config_exists
and model_config.quantization_config.get("quant_method") == "compressed-tensors"
)
if is_compressed_tensors_config:
if model_config.quantization_config.get("config_groups"):
LOG.warning(
"Found `config_groups` in a compressed-tensors config. "
"QAT integration with llmcompressor is not tested."
)
# Skip further quant checks for compressed-tensors
return
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config

View File

@@ -1149,17 +1149,22 @@ class AxolotlInputConfig(
return data
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
self.sequence_parallel_degree = 1
elif self.sequence_parallel_degree > 1:
if not self.flash_attention:
@field_validator("sequence_parallel_degree", mode="after")
@classmethod
def check_sequence_parallel_degree(cls, value, info):
if not value:
value = 1
if value > 1:
if not info.data.get("flash_attention"):
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
if self.sample_packing and self.micro_batch_size > 1:
if (
info.data.get("sample_packing")
and not info.data["micro_batch_size"] == 1
):
raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled"
"due to a `ring-flash-attn` requirement"
@@ -1179,40 +1184,42 @@ class AxolotlInputConfig(
# according to the proportion of non-padding tokens per rank.
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
f"sequence_parallel_degree={value}. Please note that logged losses may "
"differ slightly to the non-SP losses due to transformers Trainer "
"implementation details. Please see "
"https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
return self
return value
@model_validator(mode="after")
def validate_ring_attn_func(self):
if getattr(self, "sequence_parallel_degree", 1) == 1:
return self
@field_validator("ring_attn_func", mode="after")
@classmethod
def check_ring_attn_func(cls, value, info):
if not info.data.get("sequence_parallel_degree", 1) > 1:
return value
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
if self.ring_attn_func is not None:
if value is not None:
# Set the ring attention function if passed in config
valid_funcs = list(RingAttnFunc)
if self.ring_attn_func in valid_funcs:
self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
if value in valid_funcs:
value = RingAttnFunc(value)
else:
raise ValueError(
f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
f"ring_attn_func: {value} must be one of {valid_funcs}"
)
else:
# Default ring attention function selection
sample_packing = getattr(self, "sample_packing", False)
self.ring_attn_func = (
sample_packing = info.data.get("sample_packing")
value = (
RingAttnFunc.VARLEN_LLAMA3
if sample_packing
else RingAttnFunc.BATCH_RING
)
return self
return value
@model_validator(mode="before")
@classmethod

View File

@@ -348,7 +348,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -358,7 +358,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
**filter_map_kwargs,
**drop_long_kwargs,
)
if cfg.eval_sample_packing:
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
@@ -528,13 +528,6 @@ def setup_torch_compile_env(cfg):
def setup_deepspeed_env(cfg, stage=None):
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
from axolotl.utils.distributed import distributed_state
if distributed_state and distributed_state.initialized:
raise RuntimeError(
"Distributed State already initialized before Deepspeed setup"
)
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage:

View File

@@ -0,0 +1,104 @@
"""
E2E smoke tests for LLMCompressorPlugin integration
"""
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
MODELS = [
"nm-testing/llama2.c-stories42M-pruned2.4-compressed",
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
]
@pytest.mark.parametrize(
"base_model", MODELS, ids=["no-checkpoint-recipe", "with-checkpoint-recipe"]
)
@pytest.mark.parametrize(
"save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"]
)
class TestLLMCompressorIntegration:
"""
e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin
"""
@require_torch_2_4_1
def test_llmcompressor_plugin(
self, temp_dir, base_model: str, save_compressed: bool
):
# core cfg
cfg = DictDefault(
{
"base_model": base_model,
"plugins": ["axolotl.integrations.llm_compressor.LLMCompressorPlugin"],
"sequence_len": 1024,
"val_set_size": 0.05,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 1e-5,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"llmcompressor": {
"recipe": {
"finetuning_stage": {
"finetuning_modifiers": {
"ConstantPruningModifier": {
"targets": [
"re:.*q_proj.weight",
"re:.*k_proj.weight",
"re:.*v_proj.weight",
"re:.*o_proj.weight",
"re:.*gate_proj.weight",
"re:.*up_proj.weight",
"re:.*down_proj.weight",
],
"start": 0,
},
},
},
},
"save_compressed": save_compressed,
},
}
)
prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
_check_llmcompressor_model_outputs(temp_dir, save_compressed)
def _check_llmcompressor_model_outputs(temp_dir, save_compressed):
# recipe.yaml should exist
assert (Path(temp_dir) / "recipe.yaml").exists()
# sparsity config exists if save_compressed
if save_compressed:
from compressed_tensors import ModelCompressor
from compressed_tensors.config import Sparse24BitMaskConfig
compressor = ModelCompressor.from_pretrained(temp_dir)
assert compressor is not None
assert isinstance(compressor.sparsity_config, Sparse24BitMaskConfig)

View File

@@ -1,77 +0,0 @@
"""
E2E tests for activation checkpointing
"""
import pytest
import transformers
from torch.utils.checkpoint import checkpoint
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
@pytest.fixture()
def fix_checkpoint_after_test():
yield
transformers.modeling_utils.checkpoint = checkpoint
class TestActivationCheckpointing:
"""
E2E tests for activation checkpointing
"""
def test_activation_checkpointing_offload(
self,
temp_dir,
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"eos_token": "<|im_end|>",
},
"datasets": [
{
"chat_template": "chatml",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": "offload",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -99,7 +99,6 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -2,19 +2,14 @@
# pylint: disable=redefined-outer-name,unused-argument
import functools
import sys
from unittest.mock import MagicMock, patch
import pytest
import torch
from accelerate.state import PartialState
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.dict import DictDefault
@@ -52,27 +47,6 @@ def fixture_cfg():
return cfg
@pytest.fixture
def sequence_parallel_batch():
"""Create a test batch for sequence parallelism tests."""
batch_size = 1
seq_len = 8
# Create test tensors
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
attention_mask = torch.ones(batch_size, seq_len)
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
# Create test batch
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
return batch
class TestRingAttention:
"""Tests for the ring attention functionality."""
@@ -99,6 +73,11 @@ class TestRingAttention:
self, mock_world_size, mock_rank, mock_new_group, partial_state
):
"""Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
register_ring_attn,
)
# Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total
mock_rank.return_value = 3 # GPU #3
@@ -122,303 +101,88 @@ class TestRingAttention:
set_ring_attn_group(None)
class TestConfigValidation:
"""Tests for validating sequence parallelism configurations."""
# Mock a simplified DataCollator test
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_sequence_parallel_slicing(
mock_world_size, mock_rank, mock_get_group, partial_state
):
"""Test the basic sequence slicing logic without full collator instantiation."""
# Setup mocks
mock_get_group.return_value = MagicMock()
mock_rank.return_value = 1 # Second GPU
mock_world_size.return_value = 4 # 4 GPUs total
@pytest.fixture(autouse=True)
def setup_mocks(self, monkeypatch):
"""Set up mocks for all tests in this class."""
# Mock the ring_flash_attn module
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
# Create a sample batch
batch = {
"input_ids": torch.tensor(
[
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
]
),
"attention_mask": torch.ones(2, 12),
}
@pytest.fixture
def base_cfg(self):
"""Create a base configuration for testing."""
return DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-3,
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {"pad_token": "<|endoftext|>"},
}
)
# Simplified slicing logic from SequenceParallelDataCollator
def slice_batch(batch, rank, world_size):
result = {}
for key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // world_size
start_idx = rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
result[key] = batch[key][:, start_idx:end_idx]
return result
@pytest.mark.parametrize(
"config_updates, expected_values, should_pass, error_msg",
[
# Valid configuration
(
{"sequence_parallel_degree": 2, "flash_attention": True},
{"sequence_parallel_degree": 2, "flash_attention": True},
True,
None,
),
# Default sequence_parallel_degree
({}, {"sequence_parallel_degree": 1}, True, None),
# Invalid: sequence_parallel_degree > 1 without flash_attention
(
{"sequence_parallel_degree": 2, "flash_attention": False},
None,
False,
"flash_attention: true must be set",
),
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"micro_batch_size": 2,
"pad_to_sequence_len": True,
},
None,
False,
"micro_batch_size must be set to 1",
),
],
ids=[
"valid_config",
"default_sp_degree",
"without_flash_attention",
"sample_packing_with_large_batch",
],
# Slice the batch
result = slice_batch(
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
)
def test_sequence_parallel_config_validation(
self, base_cfg, config_updates, expected_values, should_pass, error_msg
):
"""Test various sequence parallelism configuration scenarios."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg
cfg.update(config_updates)
if should_pass:
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check expected values
for key, value in expected_values.items():
assert getattr(config, key) == value
else:
# Should raise exception
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
assert error_msg in str(excinfo.value)
@pytest.mark.parametrize(
"ring_attn_func, sample_packing, expected_func",
# Check slicing
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
expected_input_ids = torch.tensor(
[
(None, True, RingAttnFunc.VARLEN_LLAMA3),
(None, False, RingAttnFunc.BATCH_RING),
],
ids=["default_with_sample_packing", "default_without_sample_packing"],
[104, 105, 106], # Second slice of first sequence
[204, 205, 206], # Second slice of second sequence
]
)
def test_ring_attn_func_validation(
self, base_cfg, ring_attn_func, sample_packing, expected_func
):
"""Test ring_attn_func validation and defaults."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": sample_packing,
}
if ring_attn_func is not None:
cfg["ring_attn_func"] = ring_attn_func
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check ring_attn_func value
assert config.ring_attn_func.value == expected_func
def test_invalid_ring_attn_func(self, base_cfg):
"""Test that an invalid ring_attn_func is rejected."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Invalid configuration with invalid ring_attn_func
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"ring_attn_func": "INVALID_FUNC",
}
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
# Verify error message
assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value)
assert torch.all(result["input_ids"] == expected_input_ids)
class TestApplySequenceParallelism:
"""Tests for the apply_sequence_parallelism function."""
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
def test_config_validation_with_valid_inputs(cfg):
"""Test that valid sequence parallelism configurations pass validation."""
# Import the actual model class with appropriate mocks
from axolotl.utils.schemas.config import AxolotlInputConfig
@pytest.fixture(autouse=True)
def mock_distributed(self, monkeypatch):
"""Mock torch.distributed functions for testing."""
# Mock is_initialized to return True
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
}
# Mock get_rank to return 0 by default
monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
# Should validate without errors
config = AxolotlInputConfig(**cfg)
assert config.sequence_parallel_degree == 2
assert config.flash_attention is True
# Mock get_world_size to return 2 by default
monkeypatch.setattr(
torch.distributed, "get_world_size", lambda *args, **kwargs: 2
)
# Mock the process group
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
MagicMock,
)
def test_config_validation_with_invalid_inputs(cfg):
"""Test that invalid sequence parallelism configurations fail validation."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Mock update_ring_attn_params
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
lambda **kwargs: None,
)
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": False,
}
def test_world_size_one(self, sequence_parallel_batch):
"""Test that function returns original batch when world size is 1."""
result = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
# Should return the original batch unchanged
assert result == sequence_parallel_batch
def test_batch_ring_rank0(self, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
result = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Check that sequence dimension was sharded correctly
assert result["input_ids"].shape[1] == seq_len // 2
assert result["attention_mask"].shape[1] == seq_len // 2
# Verify content: rank 0 should get the first half of the sequence
assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
assert torch.equal(
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
)
def test_batch_ring_rank1(self, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
result = apply_sequence_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify content: rank 1 should get the second half of the sequence
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
def test_batch_zigzag(self, sequence_parallel_batch):
"""Test BATCH_ZIGZAG sharding pattern."""
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
seq_len = batch["input_ids"].size(1)
# Test rank 0
result_rank0 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# Test rank 1
result_rank1 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=1,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# Checks for both ranks
assert result_rank0["input_ids"].shape[1] == seq_len // 2
assert result_rank1["input_ids"].shape[1] == seq_len // 2
# For a 2-rank system with 8 tokens, check specific zigzag pattern
# Rank 0 should get chunks [0, 1] and [6, 7]
# Rank 1 should get chunks [2, 3] and [4, 5]
if seq_len == 8:
# Create expected tensors for comparison
rank0_expected = torch.cat(
[original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
)
rank1_expected = torch.cat(
[original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
)
assert torch.equal(result_rank0["input_ids"], rank0_expected)
assert torch.equal(result_rank1["input_ids"], rank1_expected)
def test_partial_application(self, sequence_parallel_batch):
"""Test that we can create a partially applied version of the function."""
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
# Create a partially applied function
rank0_ring_parallel = functools.partial(
apply_sequence_parallelism,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Use the partially applied function
result = rank0_ring_parallel(batch=batch)
# Verify it works as expected
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
assert torch.equal(
result["input_ids"],
original_input_ids[:, : original_input_ids.shape[1] // 2],
)
def test_missing_position_ids(self, sequence_parallel_batch):
"""Test handling of batch without position_ids."""
# Create a batch without position_ids
batch = {
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
}
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
result = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verification should pass
assert "position_ids" not in result
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
# Verify error message
assert "flash_attention: true must be set" in str(excinfo.value)