Compare commits

...

16 Commits

Author SHA1 Message Date
NanoCode012
2b9a2dde4b chore: update title 2025-04-26 16:21:31 -04:00
Wing Lian
388e950016 restore dockerfile 2025-04-26 16:21:30 -04:00
NanoCode012
fb4adbb311 fix: trim allowed cuda versions 2025-04-26 16:21:30 -04:00
Wing Lian
5e8abca54f use axolotl cloud image as base and various fixes 2025-04-26 16:21:30 -04:00
Wing Lian
168ec339e5 chore: lint 2025-04-26 16:21:30 -04:00
zeke
cb7185998b remove LICENSE and fix README 2025-04-26 16:21:30 -04:00
zeke
c2fc35f520 Add runpod sls handler 2025-04-26 16:21:30 -04:00
Wing Lian
f9c7c3bb72 don't use is_main_process during config validation (#2569) 2025-04-26 14:14:52 -04:00
Wing Lian
caf5cb63ea add e2e smoke test for using activation/gradient checkpointing with offload (#2565)
* add e2e smoke test for using activation/gradient checkpointing with offload

* disable duplicate code check for the test

* fix relative import

* seq len too small to test this dataset with packing

* Fix checkpoint ptaching for tests
2025-04-25 21:11:17 -04:00
Wing Lian
5dba5c82a8 fix support for wandb run_name for rl trainers (#2566) [skip ci]
* fix support for wandb run_name for rl trainers

* prefer to use wandb random names for run_name
2025-04-25 21:10:54 -04:00
Chiwan Park
e3c9d541a7 fix: crash when pretraining_dataset with dispatch_batches is false (#2558) 2025-04-25 17:15:03 -04:00
NanoCode012
9eba0ad118 chore(doc): update docker tags on doc (#2559) [skip ci] 2025-04-25 17:14:48 -04:00
Wing Lian
53dbf97d85 make cce default to true when using the plugin (#2562) [skip ci] 2025-04-25 17:14:26 -04:00
Eko Julianto Salim
2c2563bc34 fix: gradient checkpointing functools.partial object has no attribute __self__ (#2563) [skip ci]
* fix: gradient checkpointing causing functools.partial error

* lint

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-04-25 17:02:37 -04:00
Wing Lian
5cb3398460 don't fail on codecov upload for external contributor PRs (#2564) [skip ci] 2025-04-25 15:10:55 -04:00
Dan Saunders
ae1c7ace63 Sequence parallel training context manager (#2553)
* ctx manager for SP

* updates

* update

* further simplifying

* accommodate both training context managers

* simplifying

* simplifying

* nit

* reorg

* tweak codecov yaml

* add gather post hook, simplify, fixes

* pytest

* pytest fix
2025-04-25 10:33:54 -04:00
32 changed files with 2251 additions and 212 deletions

View File

@@ -8,6 +8,7 @@ on:
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

161
.runpod/.gitignore vendored Normal file
View File

@@ -0,0 +1,161 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
pod/scripts/config.yaml

18
.runpod/Dockerfile Normal file
View File

@@ -0,0 +1,18 @@
FROM runpod/pytorch:3.10-2.0.0-117
COPY .runpod/requirements.txt /requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade pip && \
python3 -m pip install --upgrade -r /requirements.txt
# Environment settings
ARG BASE_VOLUME="/runpod-volume"
ENV BASE_VOLUME=$BASE_VOLUME
ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
COPY .runpod/src /src
CMD ["python3", "/src/handler.py"]

335
.runpod/README.md Normal file
View File

@@ -0,0 +1,335 @@
<h1>LLM Post Training- Full fine-tune, LoRA, QLoRa etc. Llama/Mistral/Gemma and more</h1>
# Configuration Options
This document outlines all available configuration options for training models. The configuration can be provided as a JSON request.
## Usage
You can use these configuration Options:
1. As a JSON request body:
```json
{
"input": {
"user_id": "user",
"model_id": "model-name",
"run_id": "run-id",
"credentials": {
"wandb_api_key": "", # add your Weights & biases key. TODO: you will be able to set this in Enviornment variables.
"hf_token": "", # add your HF_token. TODO: you will be able to set this in Enviornment variables.
},
"args": {
"base_model": "NousResearch/Llama-3.2-1B",
// ... other options
}
}
}
```
## Configuration Options
### Model Configuration
| Option | Description | Default |
| ------------------- | --------------------------------------------------------------------------------------------- | -------------------- |
| `base_model` | Path to the base model (local or HuggingFace) | Required |
| `base_model_config` | Configuration path for the base model | Same as base_model |
| `revision_of_model` | Specific model revision from HuggingFace hub | Latest |
| `tokenizer_config` | Custom tokenizer configuration path | Optional |
| `model_type` | Type of model to load | AutoModelForCausalLM |
| `tokenizer_type` | Type of tokenizer to use | AutoTokenizer |
| `hub_model_id` | Repository ID where the model will be pushed on Hugging Face Hub (format: username/repo-name) | Optional |
## Model Family Identification
| Option | Default | Description |
| -------------------------- | ------- | ------------------------------ |
| `is_falcon_derived_model` | `false` | Whether model is Falcon-based |
| `is_llama_derived_model` | `false` | Whether model is LLaMA-based |
| `is_qwen_derived_model` | `false` | Whether model is Qwen-based |
| `is_mistral_derived_model` | `false` | Whether model is Mistral-based |
## Model Configuration Overrides
| Option | Default | Description |
| ----------------------------------------------- | ---------- | ---------------------------------- |
| `overrides_of_model_config.rope_scaling.type` | `"linear"` | RoPE scaling type (linear/dynamic) |
| `overrides_of_model_config.rope_scaling.factor` | `1.0` | RoPE scaling factor |
### Model Loading Options
| Option | Description | Default |
| -------------- | ----------------------------- | ------- |
| `load_in_8bit` | Load model in 8-bit precision | false |
| `load_in_4bit` | Load model in 4-bit precision | false |
| `bf16` | Use bfloat16 precision | false |
| `fp16` | Use float16 precision | false |
| `tf32` | Use tensor float 32 precision | false |
## Memory and Device Settings
| Option | Default | Description |
| ------------------ | --------- | ----------------------- |
| `gpu_memory_limit` | `"20GiB"` | GPU memory limit |
| `lora_on_cpu` | `false` | Load LoRA on CPU |
| `device_map` | `"auto"` | Device mapping strategy |
| `max_memory` | `null` | Max memory per device |
## Training Hyperparameters
| Option | Default | Description |
| ----------------------------- | --------- | --------------------------- |
| `gradient_accumulation_steps` | `1` | Gradient accumulation steps |
| `micro_batch_size` | `2` | Batch size per GPU |
| `eval_batch_size` | `null` | Evaluation batch size |
| `num_epochs` | `4` | Number of training epochs |
| `warmup_steps` | `100` | Warmup steps |
| `warmup_ratio` | `0.05` | Warmup ratio |
| `learning_rate` | `0.00003` | Learning rate |
| `lr_quadratic_warmup` | `false` | Quadratic warmup |
| `logging_steps` | `null` | Logging frequency |
| `eval_steps` | `null` | Evaluation frequency |
| `evals_per_epoch` | `null` | Evaluations per epoch |
| `save_strategy` | `"epoch"` | Checkpoint saving strategy |
| `save_steps` | `null` | Saving frequency |
| `saves_per_epoch` | `null` | Saves per epoch |
| `save_total_limit` | `null` | Maximum checkpoints to keep |
| `max_steps` | `null` | Maximum training steps |
### Dataset Configuration
```yaml
datasets:
- path: vicgalle/alpaca-gpt4 # HuggingFace dataset or TODO: You will be able to add the local path.
type: alpaca # Format type (alpaca, gpteacher, oasst, etc.)
ds_type: json # Dataset type
data_files: path/to/data # Source data files
train_on_split: train # Dataset split to use
```
## Chat Template Settings
| Option | Default | Description |
| ------------------------ | -------------------------------- | ---------------------- |
| `chat_template` | `"tokenizer_default"` | Chat template type |
| `chat_template_jinja` | `null` | Custom Jinja template |
| `default_system_message` | `"You are a helpful assistant."` | Default system message |
## Dataset Processing
| Option | Default | Description |
| ----------------------------- | -------------------------- | --------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `dataset_exact_deduplication` | `true` | Deduplicate datasets |
## LoRA Configuration
| Option | Default | Description |
| -------------------------- | ---------------------- | ------------------------------ |
| `adapter` | `"lora"` | Adapter type (lora/qlora) |
| `lora_model_dir` | `""` | Directory with pretrained LoRA |
| `lora_r` | `8` | LoRA attention dimension |
| `lora_alpha` | `16` | LoRA alpha parameter |
| `lora_dropout` | `0.05` | LoRA dropout |
| `lora_target_modules` | `["q_proj", "v_proj"]` | Modules to apply LoRA |
| `lora_target_linear` | `false` | Target all linear modules |
| `peft_layers_to_transform` | `[]` | Layers to transform |
| `lora_modules_to_save` | `[]` | Modules to save |
| `lora_fan_in_fan_out` | `false` | Fan in/out structure |
## Optimization Settings
| Option | Default | Description |
| ------------------------- | ------- | -------------------------- |
| `train_on_inputs` | `false` | Train on input prompts |
| `group_by_length` | `false` | Group by sequence length |
| `gradient_checkpointing` | `false` | Use gradient checkpointing |
| `early_stopping_patience` | `3` | Early stopping patience |
## Learning Rate Scheduling
| Option | Default | Description |
| -------------------------- | ---------- | -------------------- |
| `lr_scheduler` | `"cosine"` | Scheduler type |
| `lr_scheduler_kwargs` | `{}` | Scheduler parameters |
| `cosine_min_lr_ratio` | `null` | Minimum LR ratio |
| `cosine_constant_lr_ratio` | `null` | Constant LR ratio |
| `lr_div_factor` | `null` | LR division factor |
## Optimizer Settings
| Option | Default | Description |
| ---------------------- | ------------ | ------------------- |
| `optimizer` | `"adamw_hf"` | Optimizer choice |
| `optim_args` | `{}` | Optimizer arguments |
| `optim_target_modules` | `[]` | Target modules |
| `weight_decay` | `null` | Weight decay |
| `adam_beta1` | `null` | Adam beta1 |
| `adam_beta2` | `null` | Adam beta2 |
| `adam_epsilon` | `null` | Adam epsilon |
| `max_grad_norm` | `null` | Gradient clipping |
## Attention Implementations
| Option | Default | Description |
| -------------------------- | ------- | ----------------------------- |
| `flash_optimum` | `false` | Use better transformers |
| `xformers_attention` | `false` | Use xformers |
| `flash_attention` | `false` | Use flash attention |
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
| `sdp_attention` | `false` | Use scaled dot product |
| `s2_attention` | `false` | Use shifted sparse attention |
## Tokenizer Modifications
| Option | Default | Description |
| ---------------- | ------- | ---------------------------- |
| `special_tokens` | - | Special tokens to add/modify |
| `tokens` | `[]` | Additional tokens |
## Distributed Training
| Option | Default | Description |
| ----------------------- | ------- | --------------------- |
| `fsdp` | `null` | FSDP configuration |
| `fsdp_config` | `null` | FSDP config options |
| `deepspeed` | `null` | Deepspeed config path |
| `ddp_timeout` | `null` | DDP timeout |
| `ddp_bucket_cap_mb` | `null` | DDP bucket capacity |
| `ddp_broadcast_buffers` | `null` | DDP broadcast buffers |
<details>
<summary><h3>Example Configuration Request:</h3></summary>
Here's a complete example for fine-tuning a LLaMA model using LoRA:
```json
{
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "test-run",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "NousResearch/Llama-3.2-1B",
"load_in_8bit": false,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca"
}
],
"dataset_prepared_path": "last_run_prepared",
"val_set_size": 0.1,
"output_dir": "./outputs/lora-out",
"adapter": "lora",
"sequence_len": 2048,
"sample_packing": true,
"eval_sample_packing": true,
"pad_to_sequence_len": true,
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_modules": [
"gate_proj",
"down_proj",
"up_proj",
"q_proj",
"v_proj",
"k_proj",
"o_proj"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": false,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"loss_watchdog_threshold": 5,
"loss_watchdog_patience": 3,
"warmup_steps": 10,
"evals_per_epoch": 4,
"saves_per_epoch": 1,
"weight_decay": 0,
"hub_model_id": "runpod/llama-fr-lora",
"wandb_name": "test-run-1",
"wandb_project": "test-run-1",
"wandb_entity": "axo-test",
"special_tokens": {
"pad_token": "<|end_of_text|>"
}
}
}
}
```
</details>
### Advanced Features
#### Wandb Integration
- `wandb_project`: Project name for Weights & Biases
- `wandb_entity`: Team name in W&B
- `wandb_watch`: Monitor model with W&B
- `wandb_name`: Name of the W&B run
- `wandb_run_id`: ID for the W&B run
#### Performance Optimization
- `sample_packing`: Enable efficient sequence packing
- `eval_sample_packing`: Use sequence packing during evaluation
- `torch_compile`: Enable PyTorch 2.0 compilation
- `flash_attention`: Use Flash Attention implementation
- `xformers_attention`: Use xFormers attention implementation
### Available Optimizers
The following optimizers are supported:
- `adamw_hf`: HuggingFace's AdamW implementation
- `adamw_torch`: PyTorch's AdamW
- `adamw_torch_fused`: Fused AdamW implementation
- `adamw_torch_xla`: XLA-optimized AdamW
- `adamw_apex_fused`: NVIDIA Apex fused AdamW
- `adafactor`: Adafactor optimizer
- `adamw_anyprecision`: Anyprecision AdamW
- `adamw_bnb_8bit`: 8-bit AdamW from bitsandbytes
- `lion_8bit`: 8-bit Lion optimizer
- `lion_32bit`: 32-bit Lion optimizer
- `sgd`: Stochastic Gradient Descent
- `adagrad`: Adagrad optimizer
## Notes
- Set `load_in_8bit: true` or `load_in_4bit: true` for memory-efficient training
- Enable `flash_attention: true` for faster training on modern GPUs
- Use `gradient_checkpointing: true` to reduce memory usage
- Adjust `micro_batch_size` and `gradient_accumulation_steps` based on your GPU memory
For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html).
### Errors:
- if you face any issues with the Flash Attention-2, Delete yoor worker and Re-start.

93
.runpod/hub.json Normal file
View File

@@ -0,0 +1,93 @@
{
"title": "Axolotl Fine-Tuning",
"description": "Serverless fine-tuning of open-source LLMs with Axolotl. Supports LoRA, QLoRA, DPO, and more using Hugging Face models and datasets.",
"type": "serverless",
"category": "language",
"iconUrl": "https://avatars.githubusercontent.com/u/167502477",
"config": {
"runsOn": "GPU",
"containerDiskInGb": 200,
"gpuCount": 1,
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
],
"presets": [],
"env": [
{
"key": "TOKENIZER",
"input": {
"name": "Tokenizer",
"type": "string",
"description": "Name or path of the Hugging Face tokenizer to use.",
"default": "",
"advanced": true
}
},
{
"key": "MAX_NUM_SEQS",
"input": {
"name": "Max Num Seqs",
"type": "number",
"description": "Maximum number of sequences per iteration.",
"default": 256,
"advanced": true
}
},
{
"key": "DISABLE_LOG_STATS",
"input": {
"name": "Disable Log Stats",
"type": "boolean",
"description": "Disable logging statistics.",
"default": false,
"trueValue": "true",
"falseValue": "false"
}
},
{
"key": "LOAD_FORMAT",
"input": {
"name": "Load Format",
"type": "string",
"description": "The format of the model weights to load.",
"default": "auto",
"options": [
{
"label": "auto",
"value": "auto"
},
{
"label": "pt",
"value": "pt"
},
{
"label": "safetensors",
"value": "safetensors"
},
{
"label": "npcache",
"value": "npcache"
},
{
"label": "dummy",
"value": "dummy"
},
{
"label": "tensorizer",
"value": "tensorizer"
},
{
"label": "bitsandbytes",
"value": "bitsandbytes"
}
],
"advanced": true
}
}
]
}
}

15
.runpod/requirements.txt Normal file
View File

@@ -0,0 +1,15 @@
# Required Python packages get listed here, one per line.
# Reccomended to lock the version number to avoid unexpected changes.
# You can also install packages from a git repository, e.g.:
# git+https://github.com/runpod/runpod-python.git
# To learn more, see https://pip.pypa.io/en/stable/reference/requirements-file-format/
runpod~=1.7.0
huggingface_hub
typing-extensions
pydantic
pydantic-settings
hf-transfer
setuptools
numpy==2.0.0
axolotl[flash-attn,deepspeed]

View File

@@ -0,0 +1,577 @@
# # This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# # This can also be a relative path to a model on disk
# base_model: ./llama-7b-hf
# # You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
# base_model_ignore_patterns:
# # If the base_model repo on hf hub doesn't include configuration .json files,
# # You can set that here, or leave this empty to default to base_model
# base_model_config: ./llama-7b-hf
# # You can specify to choose a specific model revision from huggingface hub
# model_revision:
# # Optional tokenizer configuration override in case you want to use a different tokenizer
# # than the one defined in the base model
# tokenizer_config:
# # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
# model_type: AutoModelForCausalLM
# # Corresponding tokenizer for the model AutoTokenizer is a good choice
# tokenizer_type: AutoTokenizer
# # Trust remote code for untrusted source
# trust_remote_code:
# # use_fast option for tokenizer loading from_pretrained, default to True
# tokenizer_use_fast:
# # Whether to use the legacy tokenizer setting, defaults to True
# tokenizer_legacy:
# # Resize the model embeddings when new tokens are added to multiples of 32
# # This is reported to improve training speed on some models
# resize_token_embeddings_to_32x:
# # Used to identify which the model is based on
# is_falcon_derived_model:
# is_llama_derived_model:
# # Please note that if you set this to true, `padding_side` will be set to "left" by default
# is_mistral_derived_model:
# is_qwen_derived_model:
# # optional overrides to the base model configuration
# model_config:
# # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
# rope_scaling:
# type: # linear | dynamic
# factor: # float
# # Whether you are training a 4-bit GPTQ quantized model
# gptq: true
# gptq_groupsize: 128 # group size
# gptq_model_v1: false # v1 or v2
# # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
# load_in_8bit: true
# # Use bitsandbytes 4 bit
# load_in_4bit:
# # Use CUDA bf16
# bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
# # Use CUDA fp16
# fp16: true
# # Use CUDA tf32
# tf32: true # require >=ampere
# # No AMP (automatic mixed precision)
# bfloat16: true # require >=ampere
# float16: true
# # A list of one or more datasets to finetune the model with
# datasets:
# # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
# - path: vicgalle/alpaca-gpt4
# # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
# type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
# ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
# data_files: # Optional[str] path to source data files
# shards: # Optional[int] number of shards to split data into
# name: # Optional[str] name of dataset configuration to load
# train_on_split: train # Optional[str] name of dataset split to load from
# # Optional[str] fastchat conversation type, only used with type: sharegpt
# conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# field_human: # Optional[str]. Human key to use for conversation.
# field_model: # Optional[str]. Assistant key to use for conversation.
# # Custom user prompt
# - path: repo
# type:
# # The below are defaults. only set what's needed.
# system_prompt: ""
# system_format: "{system}"
# field_system: system
# field_instruction: instruction
# field_input: input
# field_output: output
# # Customizable to be single line or multi-line
# # 'format' can include {input}
# format: |-
# User: {instruction} {input}
# Assistant:
# # 'no_input_format' cannot include {input}
# no_input_format: "{instruction} "
# # For `completion` datsets only, uses the provided field instead of `text` column
# field:
# # Axolotl attempts to save the dataset as an arrow after packing the data together so
# # subsequent training attempts load faster, relative path
# dataset_prepared_path: data/last_run_prepared
# # Push prepared dataset to hub
# push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set.
# dataset_processes: # defaults to os.cpu_count() if not set
# # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub
# # https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
# hub_strategy:
# # Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# # Required to be true when used in combination with `push_dataset_to_hub`
# hf_use_auth_token: # boolean
# # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
# val_set_size: 0.04
# # Num shards for whole dataset
# dataset_shard_num:
# # Index of shard to use for whole dataset
# dataset_shard_idx:
# # The maximum length of an input to train with, this should typically be less than 2048
# # as most models have a token/context limit of 2048
# sequence_len: 2048
# # Pad inputs so each step uses constant sized buffers
# # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
# pad_to_sequence_len:
# # Max sequence length to concatenate training samples together up to
# # Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# # FutureWarning: This will soon be DEPRECATED
# max_packed_sequence_len: 1024
# # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
# sample_packing:
# # Set to 'false' if getting errors during eval with sample_packing on.
# eval_sample_packing:
# # You can set these packing optimizations AFTER starting a training at least once.
# # The trainer will provide recommended values for these values.
# sample_packing_eff_est:
# total_num_tokens:
# # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
# adapter: lora
# # If you already have a lora model trained that you want to load, put that here.
# # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
# lora_model_dir:
# # LoRA hyperparameters
# # For more details about the following options, see:
# # https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
# lora_r: 8
# lora_alpha: 16
# lora_dropout: 0.05
# lora_target_modules:
# - q_proj
# - v_proj
# # - k_proj
# # - o_proj
# # - gate_proj
# # - down_proj
# # - up_proj
# lora_target_linear: # If true, will target all linear layers
# # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# # `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# # https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
# lora_modules_to_save:
# # - embed_tokens
# # - lm_head
# # Once you complete training, the model will be saved to the following directory.
# # If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# # Make sure `lora_model_dir` points to this directory if you want to use the trained model.
# lora_out_dir:
# lora_fan_in_fan_out: false
# # ReLoRA configuration
# # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
# relora_steps: # Number of steps per ReLoRA restart
# relora_warmup_steps: # Number of per-restart warmup steps
# relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# # wandb configuration if you're using it
# wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
# wandb_project: # Your wandb project name
# wandb_entity: # A wandb Team name if using a Team
# wandb_watch:
# wandb_run_id: # Set the name of your wandb run
# wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# # Where to save the full-finetuned model to
# output_dir: ./completed-model
# # Whether to use torch.compile and which backend to use
# torch_compile: # bool
# torch_compile_backend: # Optional[str]
# # Training hyperparameters
# # If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
# gradient_accumulation_steps: 1
# # The number of samples to include in each batch. This is the number of samples sent to each GPU.
# micro_batch_size: 2
# eval_batch_size:
# num_epochs: 4
# warmup_steps: 100 # cannot use with warmup_ratio
# warmup_ratio: 0.05 # cannot use with warmup_steps
# learning_rate: 0.00003
# lr_quadratic_warmup:
# logging_steps:
# save_strategy: # Set to `no` to skip checkpoint saves
# save_steps: # Leave empty to save at each epoch
# eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
# save_total_limit: # Checkpoints saved at a time
# # Maximum number of iterations to train for. It precedes num_epochs which means that
# # if both are set, num_epochs will not be guaranteed.
# # e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
# max_steps:
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# # Save model as safetensors (require safetensors package)
# save_safetensors:
# # Whether to mask out or include the human's prompt from the training labels
# train_on_inputs: false
# # Group similarly sized data to minimize padding.
# # May be slower to start, as it must download and sort the entire dataset.
# # Note that training loss may have an oscillating pattern with this enabled.
# group_by_length: false
# # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
# gradient_checkpointing: false
# # Stop training after this many evaluation losses have increased in a row
# # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
# early_stopping_patience: 3
# # Specify a scheduler and kwargs to use with the optimizer
# lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
# lr_scheduler_kwargs:
# # For one_cycle optim
# lr_div_factor: # Learning rate div factor
# # For log_sweep optim
# log_sweep_min_lr:
# log_sweep_max_lr:
# # Specify optimizer
# # Valid values are driven by the Transformers OptimizerNames class, see:
# # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
# #
# # Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# # torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# # in the examples/ for your model and fine-tuning use case.
# #
# # Valid values for 'optimizer' include:
# # - adamw_hf
# # - adamw_torch
# # - adamw_torch_fused
# # - adamw_torch_xla
# # - adamw_apex_fused
# # - adafactor
# # - adamw_anyprecision
# # - sgd
# # - adagrad
# # - adamw_bnb_8bit
# # - lion_8bit
# # - lion_32bit
# # - paged_adamw_32bit
# # - paged_adamw_8bit
# # - paged_lion_32bit
# # - paged_lion_8bit
# optimizer:
# # Specify weight decay
# weight_decay:
# # adamw hyperparams
# adam_beta1:
# adam_beta2:
# adam_epsilon:
# # Gradient clipping max norm
# max_grad_norm:
# # Augmentation techniques
# # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# # currently only supported on Llama and Mistral
# noisy_embedding_alpha:
# # Whether to bettertransformers
# flash_optimum:
# # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
# xformers_attention:
# # Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
# flash_attention:
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# # Whether to use scaled-dot-product attention
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# sdp_attention:
# # Landmark attention (only llama)
# landmark_attention:
# # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# # LLaMA only
# xpos_rope:
# # Resume from a specific checkpoint dir
# resume_from_checkpoint:
# # If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# # Be careful with this being turned on between different models.
# auto_resume_from_checkpoints: false
# # Don't mess with this, it's here for accelerate and torchrun
# local_rank:
# # Add or change special tokens.
# # If you add tokens here, you don't need to add them to the `tokens` list.
# special_tokens:
# # bos_token: "<s>"
# # eos_token: "</s>"
# # unk_token: "<unk>"
# # Add extra tokens.
# tokens:
# # FSDP
# fsdp:
# fsdp_config:
# # Deepspeed config path. e.g., deepspeed/zero3.json
# deepspeed:
# # Advanced DDP Arguments
# ddp_timeout:
# ddp_bucket_cap_mb:
# ddp_broadcast_buffers:
# # Path to torch distx for optim 'adamw_anyprecision'
# torchdistx_path:
# # Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
# pretraining_dataset:
# # Debug mode
# debug:
# # Seed
# seed:
# # Allow overwrite yml config using from cli
# strict:
base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG}
revision_of_model: ${REVISION_OF_MODEL}
tokenizer_config: ${TOKENIZER_CONFIG}
model_type: ${MODEL_TYPE}
tokenizer_type: ${TOKENIZER_TYPE}
trust_remote_code: ${TRUST_REMOTE_CODE}
tokenizer_use_fast: ${TOKENIZER_USE_FAST}
tokenizer_legacy: ${TOKENIZER_LEGACY}
resize_token_embeddings_to_32x: ${RESIZE_TOKEN_EMBEDDINGS_TO_32X}
is_falcon_derived_model: ${IS_FALCON_DERIVED_MODEL}
is_llama_derived_model: ${IS_LLAMA_DERIVED_MODEL}
is_qwen_derived_model: ${IS_QWEN_DERIVED_MODEL}
is_mistral_derived_model: ${IS_MISTRAL_DERIVED_MODEL}
overrides_of_model_config:
rope_scaling:
type: ${ROPE_SCALING_TYPE}
factor: ${ROPE_SCALING_FACTOR}
bnb_config_kwargs:
llm_int8_has_fp16_weight: ${BNB_LLM_INT8_HAS_FP16_WEIGHT}
bnb_4bit_quant_type: ${BNB_4BIT_QUANT_TYPE}
bnb_4bit_use_double_quant: ${BNB_4BIT_USE_DOUBLE_QUANT}
gptq: ${GPTQ}
load_in_8bit: ${LOAD_IN_8BIT}
load_in_4bit: ${LOAD_IN_4BIT}
bf16: ${BF16}
fp16: ${FP16}
tf32: ${TF32}
bfloat16: ${BFLOAT16}
float16: ${FLOAT16}
gpu_memory_limit: ${GPU_MEMORY_LIMIT}
lora_on_cpu: ${LORA_ON_CPU}
datasets:
- path: ${DATASET_PATH}
type: ${DATASET_TYPE}
ds_type: ${DATASET_DS_TYPE}
data_files: ${DATASET_DATA_FILES}
shards: ${DATASET_SHARDS}
name: ${DATASET_NAME}
train_on_split: ${DATASET_TRAIN_ON_SPLIT}
revision: ${DATASET_REVISION}
trust_remote_code: ${DATASET_TRUST_REMOTE_CODE}
rl: ${RL}
dpo_use_weighting: ${DPO_USE_WEIGHTING}
chat_template: ${CHAT_TEMPLATE}
chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_processes: ${DATASET_PROCESSES}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY}
hf_use_auth_token: ${HF_USE_AUTH_TOKEN}
val_set_size: ${VAL_SET_SIZE}
dataset_shard_num: ${DATASET_SHARD_NUM}
dataset_shard_idx: ${DATASET_SHARD_IDX}
sequence_len: ${SEQUENCE_LEN}
pad_to_sequence_len: ${PAD_TO_SEQUENCE_LEN}
sample_packing: ${SAMPLE_PACKING}
eval_sample_packing: ${EVAL_SAMPLE_PACKING}
sample_packing_eff_est: ${SAMPLE_PACKING_EFF_EST}
total_num_tokens: ${TOTAL_NUM_TOKENS}
sample_packing_group_size: ${SAMPLE_PACKING_GROUP_SIZE}
sample_packing_bin_size: ${SAMPLE_PACKING_BIN_SIZE}
batch_flattening: ${BATCH_FLATTENING}
device_map: ${DEVICE_MAP}
max_memory: ${MAX_MEMORY}
adapter: ${ADAPTER}
lora_model_dir: ${LORA_MODEL_DIR}
lora_r: ${LORA_R}
lora_alpha: ${LORA_ALPHA}
lora_dropout: ${LORA_DROPOUT}
lora_target_modules:
- ${LORA_TARGET_MODULES}
lora_target_linear: ${LORA_TARGET_LINEAR}
peft_layers_to_transform: ${PEFT_LAYERS_TO_TRANSFORM}
lora_modules_to_save: ${LORA_MODULES_TO_SAVE}
lora_fan_in_fan_out: ${LORA_FAN_IN_FAN_OUT}
loraplus_lr_ratio: ${LORAPLUS_LR_RATIO}
loraplus_lr_embedding: ${LORAPLUS_LR_EMBEDDING}
peft:
loftq_config:
loftq_bits: ${LOFTQ_BITS}
relora_steps: ${RELORA_STEPS}
relora_warmup_steps: ${RELORA_WARMUP_STEPS}
relora_anneal_steps: ${RELORA_ANNEAL_STEPS}
relora_prune_ratio: ${RELORA_PRUNE_RATIO}
relora_cpu_offload: ${RELORA_CPU_OFFLOAD}
wandb_mode: ${WANDB_MODE}
wandb_project: ${WANDB_PROJECT}
wandb_entity: ${WANDB_ENTITY}
wandb_watch: ${WANDB_WATCH}
wandb_name: ${WANDB_NAME}
wandb_run_id: ${WANDB_RUN_ID}
wandb_log_model: ${WANDB_LOG_MODEL}
mlflow_tracking_uri: ${MLFLOW_TRACKING_URI}
mlflow_experiment_name: ${MLFLOW_EXPERIMENT_NAME}
mlflow_run_name: ${MLFLOW_RUN_NAME}
hf_mlflow_log_artifacts: ${HF_MLFLOW_LOG_ARTIFACTS}
use_comet: ${USE_COMET}
comet_api_key: ${COMET_API_KEY}
comet_workspace: ${COMET_WORKSPACE}
comet_project_name: ${COMET_PROJECT_NAME}
comet_experiment_key: ${COMET_EXPERIMENT_KEY}
comet_mode: ${COMET_MODE}
comet_online: ${COMET_ONLINE}
comet_experiment_config: ${COMET_EXPERIMENT_CONFIG}
output_dir: ${OUTPUT_DIR}
torch_compile: ${TORCH_COMPILE}
torch_compile_backend: ${TORCH_COMPILE_BACKEND}
gradient_accumulation_steps: ${GRADIENT_ACCUMULATION_STEPS}
micro_batch_size: ${MICRO_BATCH_SIZE}
eval_batch_size: ${EVAL_BATCH_SIZE}
num_epochs: ${NUM_EPOCHS}
warmup_steps: ${WARMUP_STEPS}
warmup_ratio: ${WARMUP_RATIO}
learning_rate: ${LEARNING_RATE}
lr_quadratic_warmup: ${LR_QUADRATIC_WARMUP}
logging_steps: ${LOGGING_STEPS}
eval_steps: ${EVAL_STEPS}
evals_per_epoch: ${EVALS_PER_EPOCH}
save_strategy: ${SAVE_STRATEGY}
save_steps: ${SAVE_STEPS}
saves_per_epoch: ${SAVES_PER_EPOCH}
save_total_limit: ${SAVE_TOTAL_LIMIT}
max_steps: ${MAX_STEPS}
eval_table_size: ${EVAL_TABLE_SIZE}
eval_max_new_tokens: ${EVAL_MAX_NEW_TOKENS}
eval_causal_lm_metrics: ${EVAL_CAUSAL_LM_METRICS}
profiler_steps: ${PROFILER_STEPS}
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
save_safetensors: ${SAVE_SAFETENSORS}
train_on_inputs: ${TRAIN_ON_INPUTS}
group_by_length: ${GROUP_BY_LENGTH}
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}
early_stopping_patience: ${EARLY_STOPPING_PATIENCE}
lr_scheduler: ${LR_SCHEDULER}
lr_scheduler_kwargs: ${LR_SCHEDULER_KWARGS}
cosine_min_lr_ratio: ${COSINE_MIN_LR_RATIO}
cosine_constant_lr_ratio: ${COSINE_CONSTANT_LR_RATIO}
lr_div_factor: ${LR_DIV_FACTOR}
optimizer: ${OPTIMIZER}
optim_args: ${OPTIM_ARGS}
optim_target_modules: ${OPTIM_TARGET_MODULES}
weight_decay: ${WEIGHT_DECAY}
adam_beta1: ${ADAM_BETA1}
adam_beta2: ${ADAM_BETA2}
adam_epsilon: ${ADAM_EPSILON}
max_grad_norm: ${MAX_GRAD_NORM}
neftune_noise_alpha: ${NEFTUNE_NOISE_ALPHA}
flash_optimum: ${FLASH_OPTIMUM}
xformers_attention: ${XFORMERS_ATTENTION}
flash_attention: ${FLASH_ATTENTION}
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
sdp_attention: ${SDP_ATTENTION}
s2_attention: ${S2_ATTENTION}
resume_from_checkpoint: ${RESUME_FROM_CHECKPOINT}
auto_resume_from_checkpoints: ${AUTO_RESUME_FROM_CHECKPOINTS}
local_rank: ${LOCAL_RANK}
special_tokens:
bos_token: ${SPECIAL_TOKEN_BOS}
eos_token: ${SPECIAL_TOKEN_EOS}
unk_token: ${SPECIAL_TOKEN_UNK}
pad_token: ${SPECIAL_TOKEN_PAD}
tokens: ${TOKENS}
fsdp: ${FSDP}
fsdp_config: ${FSDP_CONFIG}
deepspeed: ${DEEPSPEED}
ddp_timeout: ${DDP_TIMEOUT}
ddp_bucket_cap_mb: ${DDP_BUCKET_CAP_MB}
ddp_broadcast_buffers: ${DDP_BROADCAST_BUFFERS}
torchdistx_path: ${TORCHDISTX_PATH}
pretraining_dataset: ${PRETRAINING_DATASET}
debug: ${DEBUG}
seed: ${SEED}
strict: ${STRICT}

64
.runpod/src/handler.py Normal file
View File

@@ -0,0 +1,64 @@
"""
Runpod serverless entrypoint handler
"""
import os
import runpod
import yaml
from huggingface_hub._login import login
from train import train
from utils import get_output_dir
BASE_VOLUME = os.environ.get("BASE_VOLUME", "/runpod-volume")
if not os.path.exists(BASE_VOLUME):
os.makedirs(BASE_VOLUME)
logger = runpod.RunPodLogger()
async def handler(job):
runpod_job_id = job["id"]
inputs = job["input"]
run_id = inputs.get("run_id", "default_run_id")
args = inputs.get("args", {})
# Set output directory
output_dir = os.path.join(BASE_VOLUME, get_output_dir(run_id))
args["output_dir"] = output_dir
# First save args to a temporary config file
config_path = "/workspace/test_config.yaml"
# Add run_name and job_id to args before saving
args["run_name"] = run_id
args["runpod_job_id"] = runpod_job_id
yaml_data = yaml.dump(args, default_flow_style=False)
with open(config_path, "w", encoding="utf-8") as file:
file.write(yaml_data)
# Handle credentials
credentials = inputs.get("credentials", {})
if "wandb_api_key" in credentials:
os.environ["WANDB_API_KEY"] = credentials["wandb_api_key"]
if "hf_token" in credentials:
os.environ["HF_TOKEN"] = credentials["hf_token"]
if os.environ.get("HF_TOKEN"):
login(token=os.environ["HF_TOKEN"])
else:
logger.info("No HF_TOKEN provided. Skipping login.")
logger.info("Starting Training.")
async for result in train(config_path): # Pass the config path instead of args
logger.info(result)
logger.info("Training Complete.")
# Cleanup
del os.environ["WANDB_API_KEY"]
del os.environ["HF_TOKEN"]
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})

View File

@@ -0,0 +1,61 @@
{
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "NousResearch/Meta-Llama-3-8B",
"model_type": "LlamaForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_8bit": true,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca"
}
],
"val_set_size": 0.05,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 4,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": false,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|end_of_text|>"
}
}
}
}

45
.runpod/src/train.py Normal file
View File

@@ -0,0 +1,45 @@
"""
Runpod train entrypoint
"""
import asyncio
async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
"""
Run preprocessing (if enabled) and training with the given config file
:param config_path: Path to the YAML config file
:param gpu_id: GPU ID to use (default: "0")
:param preprocess: Whether to run preprocessing (default: True)
"""
# First check if preprocessing is needed
if preprocess:
# Preprocess command
preprocess_cmd = (
f"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}"
)
process = await asyncio.create_subprocess_shell(
preprocess_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
if process.stdout is not None:
async for line in process.stdout:
yield f"Preprocessing: {line.decode().strip()}"
await process.wait()
yield "Preprocessing completed."
else:
yield "Skipping preprocessing step."
# Training command
train_cmd = f"axolotl train {config_path}"
process = await asyncio.create_subprocess_shell(
train_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
)
if process.stdout is not None:
async for line in process.stdout:
yield f"Training: {line.decode().strip()}"
await process.wait()

89
.runpod/src/utils.py Normal file
View File

@@ -0,0 +1,89 @@
"""
Runpod launcher utils
"""
import os
import yaml
def get_output_dir(run_id):
path = f"fine-tuning/{run_id}"
return path
def make_valid_config(input_args):
"""
Creates and saves updated config file, returns the path to the new config
:param input_args: dict of input args
:return: str, path to the updated config file
"""
# Load default config
with open("config/config.yaml", "r", encoding="utf-8") as fin:
all_args = yaml.safe_load(fin)
if not input_args:
print("No args provided, using defaults")
else:
all_args.update(input_args)
# Create updated config path
updated_config_path = "config/updated_config.yaml"
# Save updated config to new file
with open(updated_config_path, "w", encoding="utf-8") as f:
yaml.dump(all_args, f)
return updated_config_path
def set_config_env_vars(args: dict):
"""
Convert API arguments into environment variables.
Handles nested dictionaries, lists, and special values.
Args:
args (dict): The arguments dictionary from the API request
"""
def process_value(value):
"""Convert Python values to string format for environment variables"""
if value is None:
return ""
if isinstance(value, bool):
return str(value).lower()
if isinstance(value, (list, dict)):
return str(value)
return str(value)
def set_env_vars(data, prefix=""):
"""Recursively set environment variables from nested dictionary"""
for key, value in data.items():
env_key = prefix + key.upper()
# Handle special cases
if isinstance(value, dict):
# For nested dictionaries (like special_tokens)
set_env_vars(value, f"{env_key}_")
elif isinstance(value, list):
# Handle list of dictionaries (like datasets)
if value and isinstance(value[0], dict):
for i, item in enumerate(value):
set_env_vars(item, f"{env_key}_{i}_")
else:
# For simple lists (like lora_target_modules)
os.environ[env_key] = process_value(value)
else:
# Handle all other cases
os.environ[env_key] = process_value(value)
# Clear any existing related environment variables
# This prevents old values from persisting
for key in list(os.environ.keys()):
if key.startswith(
("BASE_MODEL", "MODEL_TYPE", "TOKENIZER_TYPE", "DATASET", "LORA_", "WANDB_")
):
del os.environ[key]
# Set new environment variables
set_env_vars(args)

89
.runpod/tests.json Normal file
View File

@@ -0,0 +1,89 @@
{
"tests": [
{
"name": "quick_smoke_test_sft",
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "NousResearch/Meta-Llama-3-8B",
"model_type": "LlamaForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_8bit": true,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca"
}
],
"val_set_size": 0.05,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 4,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": false,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|end_of_text|>"
}
}
},
"timeout": 100000
}
],
"config": {
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"gpuCount": 1,
"containerDiskInGb": 200,
"env": [
{
"key": "TOKENIZER",
"value": ""
},
{
"key": "DISABLE_LOG_STATS",
"value": "true"
}
],
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
]
}
}

View File

@@ -52,4 +52,4 @@ pytest -v --durations=10 \
--cov-append \
--cov-report=xml:e2e-coverage.xml
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true

View File

@@ -1,5 +1,7 @@
codecov:
require_ci_to_pass: yes
notify:
wait_for_ci: true
coverage:
precision: 2

View File

@@ -28,6 +28,8 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.7.0`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
@@ -50,7 +52,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
main-latest
# nightly build
@@ -68,6 +70,7 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples:
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`

View File

@@ -10,7 +10,6 @@ plugins:
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
cut_cross_entropy: true
llama4_linearized_experts: true # needed with custom linearized experts model
load_in_4bit: true

View File

@@ -932,9 +932,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
kwargs["ring_attn_func"] = training_args.ring_attn_func
return collator(
*collator_args,
@@ -1051,6 +1048,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
@@ -1121,6 +1121,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args
def build(self, total_num_steps):

View File

@@ -371,13 +371,15 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
loss = super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return loss
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}

View File

@@ -6,4 +6,4 @@
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin
from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin

View File

@@ -1,16 +1,86 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
"""
Module for Axolotl trainer sequence parallelism mixin and training context manager
"""
import functools
import logging
import torch
import torch.distributed as dist
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
update_ring_attn_params,
)
LOG = logging.getLogger(__name__)
def apply_sequence_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
ring_attn_func: RingAttnFunc,
) -> dict[str, torch.Tensor]:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
local_rank: Local rank in the sequence parallel group
local_world_size: World size of the sequence parallel group
ring_attn_func: The ring attention function to use
Returns:
Sliced batch dictionary.
"""
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
total_seq_len = batch["input_ids"].size(1)
for key in batch:
if (
key in batch
and isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
if ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
# Split in sequential fashion and grab this rank's chunk
batch[key] = (
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
)
elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[local_rank],
chunks[2 * local_world_size - local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, local_rank].contiguous()
return batch
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
@@ -87,3 +157,157 @@ class SequenceParallelMixin:
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
class SequenceParallelContextManager:
"""
Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook.
"""
def __init__(
self,
model: nn.Module,
sequence_parallel_degree: int,
ring_attn_func: RingAttnFunc,
):
self.model = model
self.sequence_parallel_degree = sequence_parallel_degree
self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group()
# Initialize sequence parallel group details
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
# Create a partially applied version of the apply_sequence_parallelism function
# with pre-configured params
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
ring_attn_func=self.ring_attn_func,
)
def __enter__(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs
kwargs = self.apply_sequence_parallelism(batch=kwargs)
return args, kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output):
# Gather the sharded outputs
return self.gather_outputs(output)
# Register both hooks
self.hook_handles.append(
self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
self.hook_handles.append(
self.model.register_forward_hook(sequence_parallel_post_hook)
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
def gather_outputs(self, output):
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
# Handle different output formats (dict, tensor, etc.)
if isinstance(output, dict):
gathered_output = {}
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
# Gather logits or other sequence-sharded tensors
gathered_value = self.gather_tensor(value)
gathered_output[key] = gathered_value
else:
gathered_value = value.clone()
dist.all_reduce(
gathered_value, op=dist.ReduceOp.SUM, group=self.process_group
)
gathered_output[key] = gathered_value
return gathered_output
if isinstance(output, torch.Tensor):
return self.gather_tensor(output)
return output
def gather_tensor(self, tensor):
"""Gather a sharded tensor from all ranks."""
# Prepare tensors for all_gather
world_size = self.local_world_size
# Create list to store tensors from all ranks
gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
# All-gather operation
dist.all_gather(gathered_tensors, tensor, group=self.process_group)
# Concatenate along sequence dimension (typically dim=1)
if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
# Simple concatenation for standard sharding
return torch.cat(gathered_tensors, dim=1)
if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
# Each rank has a pattern of (rank, world_size*2-rank-1)
reconstituted_tensors = [None] * (world_size * 2)
# First, split each gathered tensor into its two chunks
for rank, gathered_tensor in enumerate(gathered_tensors):
# Each tensor contains two chunks in the sequence dimension
chunk_size = gathered_tensor.size(1) // 2
chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1)
# Place chunks in their original positions
reconstituted_tensors[rank] = chunk1
reconstituted_tensors[world_size * 2 - rank - 1] = chunk2
# Concatenate the reconstituted tensors in the correct order
return torch.cat(reconstituted_tensors, dim=1)
# Otherwise, RingAttnFunc.BATCH_STRIPE
# In striping, each rank has every world_size-th slice
batch_size = tensor.size(0)
hidden_dim = tensor.size(-1)
# First, determine the full sequence length
total_seq_len = 0
for t in gathered_tensors:
total_seq_len += t.size(1)
# Create a tensor to hold the unstriped result
result = torch.zeros(
batch_size,
total_seq_len,
hidden_dim,
dtype=tensor.dtype,
device=tensor.device,
)
# For each rank's tensor, distribute its slices to the correct positions
for rank, gathered_tensor in enumerate(gathered_tensors):
# The rank's tensor contains every world_size-th slice
# starting from its rank position
seq_len = gathered_tensor.size(1)
for i in range(seq_len):
# Calculate the position in the full tensor
pos = i * world_size + rank
if pos < total_seq_len:
result[:, pos] = gathered_tensor[:, i]
return result

View File

@@ -27,8 +27,6 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
```
## Supported Models

View File

@@ -28,7 +28,7 @@ class CutCrossEntropyArgs(BaseModel):
Input args for Cut Cross Entropy.
"""
cut_cross_entropy: Optional[bool] = None
cut_cross_entropy: Optional[bool] = True
@model_validator(mode="before")
@classmethod

View File

@@ -6,6 +6,7 @@ import os
import signal
import sys
import weakref
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict
@@ -25,6 +26,9 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -185,16 +189,28 @@ def execute_training(
trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
# Define the context managers to use
flash_context = (
torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
)
if cfg.flash_optimum
else nullcontext()
)
sequence_parallel_context = (
SequenceParallelContextManager(
model=trainer.model,
sequence_parallel_degree=cfg.sequence_parallel_degree,
ring_attn_func=cfg.ring_attn_func,
)
if cfg.sequence_parallel_degree > 1
else nullcontext()
)
LOG.info("Starting trainer...")
with flash_context, sequence_parallel_context:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -1,20 +1,12 @@
"""
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
includes logic for handling sequence parallelism collation.
"""
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
@dataclass
class DataCollatorForSeq2Seq:
@@ -49,8 +41,6 @@ class DataCollatorForSeq2Seq:
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
sequence_parallel_degree (`int`):
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
"""
tokenizer: PreTrainedTokenizerBase
@@ -61,17 +51,6 @@ class DataCollatorForSeq2Seq:
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
sequence_parallel_degree: int = 1
ring_attn_func: RingAttnFunc | None = None
def __post_init__(self):
if self.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
# Get information about our position in the SP group
sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=sp_group)
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
has_attn_mask = "attention_mask" in features[0].keys()
@@ -141,62 +120,8 @@ class DataCollatorForSeq2Seq:
)
features["decoder_input_ids"] = decoder_input_ids
if self.sequence_parallel_degree > 1:
features = self.apply_sequence_parallelism(features)
return features
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
) -> torch.Tensor:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary from parent collator.
Returns:
Sliced batch dictionary.
"""
# Get local (start, end) for sequence parallelism slicing
total_seq_len = batch["input_ids"].size(1)
# Update params for varlen ring attention calculation
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
for key in batch:
if batch[key].size(1) == total_seq_len:
if self.ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
batch[key] = (
batch[key]
.chunk(self.local_world_size, dim=1)[self.local_rank]
.contiguous()
)
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[self.local_rank],
chunks[2 * self.local_world_size - self.local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# TODO(djsaunde): This doesn't seem to work as expected
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(self.local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, self.local_rank].contiguous()
return batch
@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

View File

@@ -126,9 +126,6 @@ def normalize_config(cfg):
with open(ds_config_path, encoding="utf-8") as f:
cfg.deepspeed = json.load(f)
if cfg.sequence_parallel_degree is None:
cfg.sequence_parallel_degree = 1
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step

View File

@@ -134,10 +134,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")

View File

@@ -1,5 +1,7 @@
"""custom checkpointing utils"""
from functools import partial
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
@@ -9,6 +11,10 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer.__self__,
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
*args,
)

View File

@@ -1149,22 +1149,17 @@ class AxolotlInputConfig(
return data
@field_validator("sequence_parallel_degree", mode="after")
@classmethod
def check_sequence_parallel_degree(cls, value, info):
if not value:
value = 1
if value > 1:
if not info.data.get("flash_attention"):
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
self.sequence_parallel_degree = 1
elif self.sequence_parallel_degree > 1:
if not self.flash_attention:
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
if (
info.data.get("sample_packing")
and not info.data["micro_batch_size"] == 1
):
if self.sample_packing and self.micro_batch_size > 1:
raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled"
"due to a `ring-flash-attn` requirement"
@@ -1184,42 +1179,40 @@ class AxolotlInputConfig(
# according to the proportion of non-padding tokens per rank.
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={value}. Please note that logged losses may "
"differ slightly to the non-SP losses due to transformers Trainer "
"implementation details. Please see "
"https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
return value
return self
@field_validator("ring_attn_func", mode="after")
@classmethod
def check_ring_attn_func(cls, value, info):
if not info.data.get("sequence_parallel_degree", 1) > 1:
return value
@model_validator(mode="after")
def validate_ring_attn_func(self):
if getattr(self, "sequence_parallel_degree", 1) == 1:
return self
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
if value is not None:
# Set the ring attention function if passed in config
if self.ring_attn_func is not None:
valid_funcs = list(RingAttnFunc)
if value in valid_funcs:
value = RingAttnFunc(value)
if self.ring_attn_func in valid_funcs:
self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
else:
raise ValueError(
f"ring_attn_func: {value} must be one of {valid_funcs}"
f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
)
else:
# Default ring attention function selection
sample_packing = info.data.get("sample_packing")
value = (
sample_packing = getattr(self, "sample_packing", False)
self.ring_attn_func = (
RingAttnFunc.VARLEN_LLAMA3
if sample_packing
else RingAttnFunc.BATCH_RING
)
return value
return self
@model_validator(mode="before")
@classmethod

View File

@@ -348,7 +348,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
elif cfg.sample_packing:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -358,7 +358,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
**filter_map_kwargs,
**drop_long_kwargs,
)
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
if cfg.eval_sample_packing:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
@@ -528,6 +528,13 @@ def setup_torch_compile_env(cfg):
def setup_deepspeed_env(cfg, stage=None):
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
from axolotl.utils.distributed import distributed_state
if distributed_state and distributed_state.initialized:
raise RuntimeError(
"Distributed State already initialized before Deepspeed setup"
)
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage:

View File

@@ -0,0 +1,77 @@
"""
E2E tests for activation checkpointing
"""
import pytest
import transformers
from torch.utils.checkpoint import checkpoint
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
@pytest.fixture()
def fix_checkpoint_after_test():
yield
transformers.modeling_utils.checkpoint = checkpoint
class TestActivationCheckpointing:
"""
E2E tests for activation checkpointing
"""
def test_activation_checkpointing_offload(
self,
temp_dir,
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"eos_token": "<|im_end|>",
},
"datasets": [
{
"chat_template": "chatml",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": "offload",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -99,6 +99,7 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -2,14 +2,19 @@
# pylint: disable=redefined-outer-name,unused-argument
import functools
import sys
from unittest.mock import MagicMock, patch
import pytest
import torch
from accelerate.state import PartialState
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.dict import DictDefault
@@ -47,6 +52,27 @@ def fixture_cfg():
return cfg
@pytest.fixture
def sequence_parallel_batch():
"""Create a test batch for sequence parallelism tests."""
batch_size = 1
seq_len = 8
# Create test tensors
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
attention_mask = torch.ones(batch_size, seq_len)
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
# Create test batch
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
return batch
class TestRingAttention:
"""Tests for the ring attention functionality."""
@@ -73,11 +99,6 @@ class TestRingAttention:
self, mock_world_size, mock_rank, mock_new_group, partial_state
):
"""Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
register_ring_attn,
)
# Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total
mock_rank.return_value = 3 # GPU #3
@@ -101,88 +122,303 @@ class TestRingAttention:
set_ring_attn_group(None)
# Mock a simplified DataCollator test
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_sequence_parallel_slicing(
mock_world_size, mock_rank, mock_get_group, partial_state
):
"""Test the basic sequence slicing logic without full collator instantiation."""
# Setup mocks
mock_get_group.return_value = MagicMock()
mock_rank.return_value = 1 # Second GPU
mock_world_size.return_value = 4 # 4 GPUs total
class TestConfigValidation:
"""Tests for validating sequence parallelism configurations."""
# Create a sample batch
batch = {
"input_ids": torch.tensor(
[
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
]
),
"attention_mask": torch.ones(2, 12),
}
@pytest.fixture(autouse=True)
def setup_mocks(self, monkeypatch):
"""Set up mocks for all tests in this class."""
# Mock the ring_flash_attn module
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
# Simplified slicing logic from SequenceParallelDataCollator
def slice_batch(batch, rank, world_size):
result = {}
for key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // world_size
start_idx = rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
result[key] = batch[key][:, start_idx:end_idx]
return result
@pytest.fixture
def base_cfg(self):
"""Create a base configuration for testing."""
return DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-3,
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {"pad_token": "<|endoftext|>"},
}
)
# Slice the batch
result = slice_batch(
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
)
# Check slicing
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
expected_input_ids = torch.tensor(
@pytest.mark.parametrize(
"config_updates, expected_values, should_pass, error_msg",
[
[104, 105, 106], # Second slice of first sequence
[204, 205, 206], # Second slice of second sequence
]
# Valid configuration
(
{"sequence_parallel_degree": 2, "flash_attention": True},
{"sequence_parallel_degree": 2, "flash_attention": True},
True,
None,
),
# Default sequence_parallel_degree
({}, {"sequence_parallel_degree": 1}, True, None),
# Invalid: sequence_parallel_degree > 1 without flash_attention
(
{"sequence_parallel_degree": 2, "flash_attention": False},
None,
False,
"flash_attention: true must be set",
),
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"micro_batch_size": 2,
"pad_to_sequence_len": True,
},
None,
False,
"micro_batch_size must be set to 1",
),
],
ids=[
"valid_config",
"default_sp_degree",
"without_flash_attention",
"sample_packing_with_large_batch",
],
)
assert torch.all(result["input_ids"] == expected_input_ids)
def test_sequence_parallel_config_validation(
self, base_cfg, config_updates, expected_values, should_pass, error_msg
):
"""Test various sequence parallelism configuration scenarios."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg
cfg.update(config_updates)
if should_pass:
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check expected values
for key, value in expected_values.items():
assert getattr(config, key) == value
else:
# Should raise exception
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
assert error_msg in str(excinfo.value)
@pytest.mark.parametrize(
"ring_attn_func, sample_packing, expected_func",
[
(None, True, RingAttnFunc.VARLEN_LLAMA3),
(None, False, RingAttnFunc.BATCH_RING),
],
ids=["default_with_sample_packing", "default_without_sample_packing"],
)
def test_ring_attn_func_validation(
self, base_cfg, ring_attn_func, sample_packing, expected_func
):
"""Test ring_attn_func validation and defaults."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": sample_packing,
}
if ring_attn_func is not None:
cfg["ring_attn_func"] = ring_attn_func
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check ring_attn_func value
assert config.ring_attn_func.value == expected_func
def test_invalid_ring_attn_func(self, base_cfg):
"""Test that an invalid ring_attn_func is rejected."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Invalid configuration with invalid ring_attn_func
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"ring_attn_func": "INVALID_FUNC",
}
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
# Verify error message
assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value)
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
def test_config_validation_with_valid_inputs(cfg):
"""Test that valid sequence parallelism configurations pass validation."""
# Import the actual model class with appropriate mocks
from axolotl.utils.schemas.config import AxolotlInputConfig
class TestApplySequenceParallelism:
"""Tests for the apply_sequence_parallelism function."""
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
}
@pytest.fixture(autouse=True)
def mock_distributed(self, monkeypatch):
"""Mock torch.distributed functions for testing."""
# Mock is_initialized to return True
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
# Should validate without errors
config = AxolotlInputConfig(**cfg)
assert config.sequence_parallel_degree == 2
assert config.flash_attention is True
# Mock get_rank to return 0 by default
monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
# Mock get_world_size to return 2 by default
monkeypatch.setattr(
torch.distributed, "get_world_size", lambda *args, **kwargs: 2
)
def test_config_validation_with_invalid_inputs(cfg):
"""Test that invalid sequence parallelism configurations fail validation."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Mock the process group
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
MagicMock,
)
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": False,
}
# Mock update_ring_attn_params
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
lambda **kwargs: None,
)
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
def test_world_size_one(self, sequence_parallel_batch):
"""Test that function returns original batch when world size is 1."""
result = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify error message
assert "flash_attention: true must be set" in str(excinfo.value)
# Should return the original batch unchanged
assert result == sequence_parallel_batch
def test_batch_ring_rank0(self, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
result = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Check that sequence dimension was sharded correctly
assert result["input_ids"].shape[1] == seq_len // 2
assert result["attention_mask"].shape[1] == seq_len // 2
# Verify content: rank 0 should get the first half of the sequence
assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
assert torch.equal(
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
)
def test_batch_ring_rank1(self, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
result = apply_sequence_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify content: rank 1 should get the second half of the sequence
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
def test_batch_zigzag(self, sequence_parallel_batch):
"""Test BATCH_ZIGZAG sharding pattern."""
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
seq_len = batch["input_ids"].size(1)
# Test rank 0
result_rank0 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# Test rank 1
result_rank1 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=1,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# Checks for both ranks
assert result_rank0["input_ids"].shape[1] == seq_len // 2
assert result_rank1["input_ids"].shape[1] == seq_len // 2
# For a 2-rank system with 8 tokens, check specific zigzag pattern
# Rank 0 should get chunks [0, 1] and [6, 7]
# Rank 1 should get chunks [2, 3] and [4, 5]
if seq_len == 8:
# Create expected tensors for comparison
rank0_expected = torch.cat(
[original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
)
rank1_expected = torch.cat(
[original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
)
assert torch.equal(result_rank0["input_ids"], rank0_expected)
assert torch.equal(result_rank1["input_ids"], rank1_expected)
def test_partial_application(self, sequence_parallel_batch):
"""Test that we can create a partially applied version of the function."""
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
# Create a partially applied function
rank0_ring_parallel = functools.partial(
apply_sequence_parallelism,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Use the partially applied function
result = rank0_ring_parallel(batch=batch)
# Verify it works as expected
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
assert torch.equal(
result["input_ids"],
original_input_ids[:, : original_input_ids.shape[1] // 2],
)
def test_missing_position_ids(self, sequence_parallel_batch):
"""Test handling of batch without position_ids."""
# Create a batch without position_ids
batch = {
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
}
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
result = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verification should pass
assert "position_ids" not in result
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2