Compare commits


17 Commits

Author SHA1 Message Date
Wing Lian
b5198d8734 granite chat multipack support and example 2025-08-02 20:57:00 -04:00
Wing Lian
4ab6a1bd7e add support for granite chat templates 2025-08-02 11:29:03 -04:00
Wing Lian
5639552064 prevent usage of low bit ao optimizers with configurations that use parameter groups (#3003)
* prevent usage of low bit ao optimizers with configurations that use parameter groups

* use optimizer enum value

* fix validation
2025-08-01 17:54:04 -04:00
Wing Lian
cda3c82351 move ib/rdma libs into base image (#3002)
* move ib/rdma libs into base image

* use --no-install-recommends
2025-08-01 16:10:37 -04:00
Wing Lian
7c3b428f23 Add validation for TP with models with tied embeddings (#2999)
* add validation for tp + tied embeddings models

* fix logic and messaging

* add additional guard for null tp size
2025-08-01 13:58:16 -04:00
Wing Lian
01a6bd1a0e use CCE fix for TP using vocab parallel for CEL (#3000) 2025-08-01 13:21:58 -04:00
NanoCode012
41709822a7 fix: move memory usage log to trainer.log (#2996) [skip ci] 2025-08-01 13:21:43 -04:00
Wing Lian
02a37199ee prevent empty value for vllm_mode (#2998) 2025-08-01 09:59:45 -04:00
NanoCode012
7026cd5e9e Feat: Add N-D parallelism docs (#2989)
* fix: remove non-existent file

* feat: add n-d parallel docs

* fix: comments

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-08-01 13:18:31 +07:00
NanoCode012
eb0a8a7775 feat: upgrade cce commit to include smollm3, granite, granitemoe (#2993) 2025-07-31 18:18:44 -04:00
salman
294c7fe7a6 Distributed/ND-Parallel (#2977) 2025-07-31 15:25:02 -04:00
Wing Lian
7b68dfafd7 jagged lr restart scheduler (#1680) [skip ci]
* jagged lr restart scheduler

var name fix
make sure to create scheduler first

* wire things together

* more fixes

* fix for nesting scheduler and first anneal phase

* no need for relora trainer anymore since we've generalized the relora scheduler

* remove redundant relora scheduler and lint

* update relora e2e test for updated params

* need restart steps for relora test

* update quarto docs for dropped relora trainer

* update example yaml

* drop verbose arg

* min lr scale support for jagged lr

* don't let min_lr be nonetype

* cleanup args
2025-07-31 13:50:03 -04:00
salman
32a7890231 Revert test update to index.qmd (#2995) [skip ci] 2025-07-31 11:46:31 -04:00
Wing Lian
563f5eed7a update dependencies - liger + trl (#2987)
* update dependencies

* set dataset processes for tests

* add support for GSPO
2025-07-31 11:17:17 -04:00
Wing Lian
6ec282094d actually call the register method on plugins (#2991) [skip ci] 2025-07-31 11:13:15 -04:00
salman
09dda462ab Fix don't preview docs for contributors (#2994) [skip ci]
* checking against fork vs. main repo

* force doc preview
2025-07-31 11:12:41 -04:00
Dan Saunders
bb1cae1a20 CLI: add --launcher option, support launcher args, cleanup, refactor (#2924)
* add --launcher option; explicit True/False bool args; small cleanup

* refactor

* add torchrun, accelerate cli args

* add rdzv arg default + tests

* update _quarto

* coderabbit

* fix

* we can't set rdzv_id independently across nodes

* coderabbit

* fix tests
2025-07-30 15:46:56 -04:00
112 changed files with 2579 additions and 1799 deletions

View File

@@ -53,7 +53,7 @@ jobs:
- name: Netlify Publish
uses: nwtgck/actions-netlify@v3.0
if: ${{ secrets.NETLIFY_AUTH_TOKEN != '' }}
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
id: netlify
with:
publish-dir: './_site'
@@ -68,7 +68,7 @@ jobs:
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
- name: Update PR with preview link
if: ${{ steps.netlify.outcome == 'success' && secrets.NETLIFY_AUTH_TOKEN != '' }}
if: ${{ steps.netlify.outcome == 'success' }}
uses: marocchino/sticky-pull-request-comment@v2
with:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -35,25 +35,30 @@ quartodoc:
- cli.train
- cli.evaluate
- cli.args
- cli.art
- cli.checks
- cli.config
- cli.delinearize_llama4
- cli.inference
- cli.merge_lora
- cli.merge_sharded_fsdp_weights
- cli.preprocess
- cli.sweeps
- cli.utils
- cli.quantize
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- cli.quantize
- cli.utils
- cli.utils.args
- cli.utils.fetch
- cli.utils.load
- cli.utils.sweeps
- cli.utils.train
- title: Trainers
desc: Training implementations
contents:
- core.trainers.base
- core.trainers.trl
- core.trainers.mamba
- core.trainers.relora
- core.trainers.dpo.trainer
- core.trainers.grpo.trainer
- core.trainers.grpo.sampler
@@ -269,7 +274,6 @@ website:
- docs/dataset_preprocessing.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/gradient_accumulation.qmd
- section: "Advanced Features"
contents:
@@ -279,6 +283,7 @@ website:
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/nd_parallelism.qmd
- section: "Troubleshooting"
contents:

View File

@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \
pytest -v --durations=10 -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -65,6 +65,9 @@ GPU_CONFIG = f"L40S:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit

View File

@@ -16,7 +16,10 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
&& apt-get install -y --no-install-recommends \
wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
ibverbs-providers ibverbs-utils infiniband-diags \
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
&& rm -rf /var/cache/apt/archives \
&& rm -rf /var/lib/apt/lists/* \
&& wget \

View File

@@ -15,7 +15,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \

View File

@@ -23,6 +23,20 @@ axolotl <command> [config.yml] [options]
The config file can be local or a URL to a raw YAML file.
### Launcher Arguments
For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator:
```bash
# Pass torchrun arguments
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
# Pass accelerate arguments
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4
```
Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.).
## Command Reference
### fetch
@@ -80,7 +94,11 @@ axolotl train config.yml \
--num-epochs 3
# Training without accelerate
axolotl train config.yml --no-accelerate
axolotl train config.yml --launcher python
# Pass launcher-specific arguments using -- separator
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml
# Resume training from checkpoint
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
@@ -175,6 +193,9 @@ Evaluates a model's performance (loss, etc.) on the train and eval datasets.
```bash
# Basic evaluation
axolotl evaluate config.yml
# Evaluation with launcher arguments
axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2
```
### lm-eval
@@ -287,9 +308,6 @@ axolotl preprocess config.yml --cloud cloud_config.yml
# Train on cloud
axolotl train config.yml --cloud cloud_config.yml
# Train without accelerate on cloud
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
# Run lm-eval on cloud
axolotl lm-eval config.yml --cloud cloud_config.yml
```

View File

@@ -69,11 +69,19 @@ export NCCL_BUFFSIZE=2097152
Run the following on each node:
### Option 1: New Axolotl CLI with launcher args (Recommended)
```bash
axolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port"
```
### Option 2: Direct torchrun (Legacy)
```bash
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
```
Please make sure to substitute the placeholder variables.
Please make sure to substitute the placeholder variables:
- `num_nodes`: Number of nodes (containing GPUs)
- `gpu_per_node`: Number of GPUs per node
@@ -81,8 +89,6 @@ Please make sure to substitute the placeholder variables.
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
- `rdzv_id`: A unique job ID that is used by the job across nodes.
::: {.callout-note}
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
:::
The new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features.
More info on the available configs can be found in the PyTorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)

docs/nd_parallelism.qmd Normal file
View File

@@ -0,0 +1,102 @@
# N-D Parallelism
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
- A model's weights are too large to fit on a single GPU's memory.
- A model's activations, especially with very long contexts, are too large for a single GPU.
- You want to accelerate training by using multiple GPUs or nodes.
or combinations of the above!
## Core Concepts
Parallelism strategies can be combined. The key is understanding how each one divides the workload. PyTorch's `DeviceMesh` is the modern way to manage these combinations, creating a logical grid of your GPUs and assigning different parallel strategies to different dimensions of the grid.
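
For instance, a minimal sketch of carving 8 GPUs into named mesh dimensions (assuming recent PyTorch, 2.2+, launched via `torchrun --nproc_per_node 8`; `mesh_demo.py` is a hypothetical filename):

```python
# Sketch: a 3D device mesh over 8 GPUs (2-way shard x 2-way TP x 2-way CP).
# Run under `torchrun --nproc_per_node 8 mesh_demo.py`.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp_shard", "tp", "cp"))

# Each named slice is a sub-mesh (with its own process group) that one
# parallelism strategy can own.
dp_mesh, tp_mesh, cp_mesh = mesh["dp_shard"], mesh["tp"], mesh["cp"]
```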
### Data Parallelism {#sec-dp}
Data Parallelism focuses on splitting the global data batch across GPUs.
- Distributed Data Parallel (DDP): The classic approach. The full model is replicated on every GPU. Each GPU processes a different slice of the data batch. Gradients are then averaged across all GPUs after the backward pass to keep the models synchronized. This can substantially improve data throughput compared to single-device training, but requires that each GPU is able to hold the entire model, its gradients, and optimizer states.
- [Fully Sharded Data Parallel (FSDP)](multi-gpu.qmd#fully-sharded-data-parallel-(fsdp)): A highly memory-efficient form of data parallelism (inspired by DeepSpeed's ZeRO). Instead of replicating the model, FSDP shards the model's *parameters, gradients, and optimizer states* across the GPUs in the data-parallel group. During computation, each GPU receives the specific parameters it needs via an `all_gather` operation just before they are used, and they can be discarded immediately after (`reshard-after-forward`).
- FSDP maps to ZeRO stages (a minimal wrapper sketch follows this list):
- ZeRO-2 (`reshard_after_forward=False`): Shards gradients and optimizer states. Model weights are replicated on each GPU.
- ZeRO-3 (`reshard_after_forward=True`): Shards gradients, optimizer states, AND model parameters. This provides the most memory savings at the cost of more communication (re-gathering parameters for both forward and backward passes).
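
As a rough illustration of that mapping, here is a minimal sketch using PyTorch's classic FSDP wrapper. Axolotl's FSDP2 path is configured via YAML instead, so treat this as the underlying idea rather than Axolotl's code:

```python
# ZeRO-stage mapping via FSDP sharding strategies; run under torchrun so a
# process group exists (per-rank `torch.cuda.set_device` omitted for brevity).
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

dist.init_process_group("nccl")
layer = nn.TransformerEncoderLayer(d_model=1024, nhead=16).cuda()

# ~ZeRO-3 (reshard_after_forward=True): params, grads, optimizer state sharded
zero3 = FSDP(layer, sharding_strategy=ShardingStrategy.FULL_SHARD)

# ~ZeRO-2 (reshard_after_forward=False): params stay gathered after forward
# zero2 = FSDP(layer, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
```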
### [Experimental] Tensor Parallelism (TP) {#sec-tp}
Also known as "horizontal model parallelism," as described in the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Instead of splitting the batch, TP splits the model's layers themselves across GPUs.
- How it works: For a linear layer `Y = XA`, the weight matrix `A` is split column-wise (`A = [A_1, A_2]`). The computation becomes `Y_1 = XA_1` and `Y_2 = XA_2`, which can happen in parallel on different GPUs. The final output `Y` is simply the concatenation of `Y_1` and `Y_2`. Check [this comment](https://github.com/huggingface/transformers/issues/10321#issuecomment-783543530) for more detailed info. A numeric check of this identity follows this list.
- Requirement: TP involves frequent, small communications within a forward/backward pass. It requires a very fast interconnect between GPUs (e.g., NVLink) and is typically not recommended across different nodes.
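
To make the column split concrete, here is a single-process numeric check of the identity (plain PyTorch, no distributed setup; the two shards stand in for two GPUs):

```python
# Column-parallel linear: Y = XA equals the concatenation of the shard outputs.
import torch

X = torch.randn(4, 16)       # activations
A = torch.randn(16, 32)      # full weight matrix
A1, A2 = A.chunk(2, dim=1)   # column-wise split across two "GPUs"

Y_full = X @ A                              # single-device result
Y_tp = torch.cat([X @ A1, X @ A2], dim=1)   # shard results, concatenated

assert torch.allclose(Y_full, Y_tp, atol=1e-5)
```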
### Context Parallelism (CP) {#sec-cp}
Context Parallelism, also called [Sequence Parallelism](sequence_parallelism.qmd), addresses the memory bottleneck from long sequences. The input sequence itself is split along the sequence length dimension and distributed across GPUs.
- How it works: If you have a sequence of 8192 tokens and a `context_parallel_size` of 4, each GPU will only handle a chunk of 2048 tokens.
- The Challenge: Attention is not local; every token needs to "attend to" every other token. Splitting the sequence breaks this.
- The Solution (`ring-flash-attention`): An efficient communication protocol is used. To compute attention for its local sequence chunk, each GPU passes its Key-Value (KV) cache to its neighbor in a "ring." After `N-1` steps, every GPU has seen the KV-cache from all other GPUs, allowing it to compute the correct attention values for its chunk. This is implemented using the highly optimized `flash-attention` kernel at each step. A single-process simulation of this accumulation follows this list.
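
The exactness of the ring comes from online softmax accumulation. The sketch below simulates the math in a single process, with each loop step standing in for one KV hop around the ring; `ring_attention_sim` is an illustration, not Axolotl's implementation:

```python
# Exact attention computed chunk-by-chunk with online softmax, as in ring attention.
import torch

def ring_attention_sim(q, k, v, world_size):
    d = q.shape[-1]
    k_chunks = k.chunk(world_size, dim=0)  # per-"rank" KV shards
    v_chunks = v.chunk(world_size, dim=0)
    m = q.new_full((q.shape[0], 1), float("-inf"))  # running row max
    l = q.new_zeros(q.shape[0], 1)                  # running softmax denominator
    acc = torch.zeros_like(q)                       # running numerator
    for step in range(world_size):                  # local chunk + N-1 ring hops
        s = q @ k_chunks[step].T / d**0.5
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)
        scale = torch.exp(m - m_new)                # rescale old accumulators
        acc = acc * scale + p @ v_chunks[step]
        l = l * scale + p.sum(dim=-1, keepdim=True)
        m = m_new
    return acc / l

q, k, v = torch.randn(8, 64), torch.randn(32, 64), torch.randn(32, 64)
ref = torch.softmax(q @ k.T / 64**0.5, dim=-1) @ v
assert torch.allclose(ring_attention_sim(q, k, v, 4), ref, atol=1e-5)
```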
### Hybrid Sharding Data Parallel (HSDP) {#sec-hsdp}
HSDP is a 2D strategy that intelligently combines FSDP and DDP, typically for multi-node training.
- Intra-Node (within a machine): Use FSDP. This is efficient because GPUs on the same node have fast interconnects (NVLink), making the `all_gather` operations for sharded parameters fast.
- Inter-Node (across machines): Use DDP. The gradient synchronization between nodes is less frequent than FSDP's parameter gathering, making it a better fit for the slower node-to-node network (e.g., Ethernet/Infiniband).
- Example: With 2 nodes of 8 GPUs each (16 total), you could have `dp_shard_size=8` (FSDP within each node) and `dp_replicate_size=2` (DDP across the two nodes). A code sketch of this layout follows this list.
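
A rough sketch of that 16-GPU layout using a 2D mesh and PyTorch's built-in hybrid strategy (assumes torch 2.2+, where classic FSDP accepts a `device_mesh`; this shows the underlying mechanism, not necessarily how Axolotl wires it internally):

```python
# HSDP: the outer mesh dim replicates DDP-style across nodes,
# the inner dim shards FSDP-style within each node.
# Launch with torchrun across 2 nodes x 8 GPUs (16 ranks total).
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

dist.init_process_group("nccl")
mesh_2d = init_device_mesh("cuda", (2, 8), mesh_dim_names=("dp_replicate", "dp_shard"))
model = FSDP(
    nn.Linear(4096, 4096).cuda(),
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
```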
## Usage
```yaml
# FSDP config. See https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp
fsdp_version: 2
fsdp_config:
  # ...
# The number of GPUs to shard the model parameters across (FSDP dimension).
dp_shard_size: 4
# The number of times to replicate the sharded model (DDP dimension).
dp_replicate_size: 2
# Number of GPUs for Tensor Parallelism.
tensor_parallel_size: 1 # (default is 1, no TP)
# Number of GPUs for Context/Sequence Parallelism.
context_parallel_size: 1 # (default is 1, no CP)
```
Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size`.
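
One invariant worth checking before launching: the product of all parallelism degrees must equal the total number of GPUs. A tiny sanity-check helper (`world_size_needed` is hypothetical, not an Axolotl API):

```python
def world_size_needed(dp_replicate=1, dp_shard=1, tp=1, cp=1):
    """Total GPUs consumed by a given parallelism layout."""
    return dp_replicate * dp_shard * tp * cp

# Example 1 below: 2 replicas x 4-way shard -> 8 GPUs.
assert world_size_needed(dp_replicate=2, dp_shard=4) == 8
```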
## Examples
1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
- You want FSDP within each node and DDP across nodes.
- Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
2. FSDP + TP on a single 8-GPU node:
- You want to split the model across 4 GPUs using FSDP, and further split each layer across 2 GPUs with TP.
- Set `dp_shard_size: 4` and `tensor_parallel_size: 2`.
3. FSDP + CP on a single 8-GPU node for long context:
- You want to shard the model across all 8 GPUs and also split the sequence length across all 8 GPUs.
- Set `dp_shard_size: 8` and `context_parallel_size: 8`. Note: this means the data parallel group and context parallel group are the same. A more common setup might be to shard across a smaller group.
## Support Matrix
This matrix describes how different parallelism methods can be combined in Axolotl.
| Combination | `dp_replicate_size` | `dp_shard_size` | `tp_size` | `cp_size` | Status & Notes |
| --- | :---: | :---: |:---:|:---:|---|
| **FSDP** (ZeRO-3) | 1 | >1 | 1 | 1 | ✅ Fully supported. Shards model across all GPUs. |
| **HSDP** | >1 | >1 | 1 | 1 | ✅ Fully supported. FSDP intra-node, DDP inter-node. |
| **FSDP + TP** | 1 | >1 | >1 | 1 | ✅ **2D Parallelism**. Shards the model across a `dp_shard` group, and TP-splits layers within the `tp` group. |
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP/CP without FSDP is inefficient and complex. You should use FSDP instead (`dp_shard_size > 1`). |
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
- `tp_size` refers to `tensor_parallel_size`
- `cp_size` refers to `context_parallel_size`

View File

@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
context_parallel_size: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,7 +30,7 @@ heads_k_stride: 1
ring_attn_func:
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
@@ -66,7 +66,7 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4 (see the arithmetic check below)
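
A quick check of that arithmetic (illustrative only):

```python
gpus, cp_size, micro_batch_size = 8, 4, 2
batches_per_step = gpus // cp_size                   # 2 distinct batches per step
global_batch = batches_per_step * micro_batch_size   # 4, down from 8 * 2 = 16
assert (batches_per_step, global_batch) == (2, 4)
```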

View File

@@ -20,7 +20,7 @@ min_sample_len: 200_000
sample_packing: true
tiled_mlp: true
sequence_parallel_degree: 8
context_parallel_size: 8
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
]
},
{

View File

@@ -25,9 +25,12 @@ lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
relora_steps: 150
relora_warmup_ratio: 0.1
relora: true
relora_prune_ratio: 0.9
relora_cpu_offload: false
jagged_restart_steps: 150
jagged_restart_warmup_steps: 10
jagged_restart_anneal_steps: false
wandb_project:
wandb_entity:

View File

@@ -6,19 +6,19 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.6.0
liger-kernel==0.6.1
# END section
packaging==23.2
huggingface_hub>=0.33.0
peft==0.16.0
transformers==4.54.0
transformers==4.54.1
tokenizers>=0.21.1
accelerate==1.9.0
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
datasets==4.0.0
deepspeed>=0.17.0
trl==0.19.1
trl==0.20.0
hf_xet==1.1.5
optimum==1.16.2

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
)

View File

@@ -72,12 +72,13 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm>=0.10.0"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:

View File

@@ -30,8 +30,6 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=0)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None)
num_processes: Optional[int] = field(default=None)
@dataclass

View File

@@ -3,7 +3,7 @@ launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union
from typing import Literal
import yaml
@@ -11,7 +11,7 @@ from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
def load_cloud_cfg(cloud_config: Path | str) -> DictDefault:
"""Load and validate cloud configuration."""
# Load cloud configuration.
with open(cloud_config, encoding="utf-8") as file:
@@ -20,8 +20,8 @@ def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
def do_cli_preprocess(
cloud_config: Union[Path, str],
config: Union[Path, str],
cloud_config: Path | str,
config: Path | str,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
@@ -31,9 +31,10 @@ def do_cli_preprocess(
def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
cloud_config: Path | str,
config: Path | str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
cwd=None,
**kwargs,
) -> None:
@@ -44,12 +45,18 @@ def do_cli_train(
local_dirs = {}
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
local_dirs = {"/workspace/mounts": cwd}
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
cloud.train(
config_yaml,
launcher=launcher,
launcher_args=launcher_args,
local_dirs=local_dirs,
**kwargs,
)
def do_cli_lm_eval(
cloud_config: Union[Path, str],
config: Union[Path, str],
cloud_config: Path | str,
config: Path | str,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)

View File

@@ -3,6 +3,7 @@ base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod
from typing import Literal
class Cloud(ABC):
@@ -15,5 +16,12 @@ class Cloud(ABC):
pass
@abstractmethod
def train(self, config_yaml: str, accelerate: bool = True) -> str:
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None,
**kwargs,
):
pass

View File

@@ -8,7 +8,7 @@ import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
from typing import Optional
from typing import Literal
import modal
@@ -230,8 +230,9 @@ class ModalCloud(Cloud):
def train(
self,
config_yaml: str,
accelerate: bool = True,
local_dirs: Optional[dict[str, str]] = None,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None,
**kwargs,
):
modal_fn = self.get_train_env(local_dirs)(_train)
@@ -239,7 +240,8 @@ class ModalCloud(Cloud):
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
launcher=launcher,
launcher_args=launcher_args,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
)
@@ -270,20 +272,35 @@ def _preprocess(config_yaml: str, volumes=None):
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _train(
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
volumes=None,
**kwargs, # pylint: disable=unused-argument
):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
if accelerate:
accelerate_args = "--accelerate"
launcher_args = launcher_args or []
# Build the base command
if launcher == "accelerate":
launcher_arg = "--launcher accelerate"
elif launcher == "torchrun":
launcher_arg = "--launcher torchrun"
else:
accelerate_args = "--no-accelerate"
num_processes_args = ""
if num_processes := kwargs.pop("num_processes", None):
num_processes_args = f"--num-processes {num_processes}"
launcher_arg = "--launcher python"
# Build launcher args string
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
run_cmd(
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
f"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}".strip(),
run_folder,
volumes,
)

View File

@@ -200,14 +200,13 @@ def load_cfg(
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
for key, value in kwargs.items():
# If not strict, allow writing to cfg even if it's not in the yml already
if key in cfg_keys or not cfg.strict:
if isinstance(cfg[key], bool):
cfg[key] = bool(value)
else:
cfg[k] = kwargs[k]
cfg[key] = value
try:
device_props = torch.cuda.get_device_properties("cuda")

View File

@@ -9,7 +9,6 @@ from typing import Generator, Union
import fire
import torch
from accelerate import init_empty_weights
from dotenv import load_dotenv
from transformers import AutoProcessor
@@ -152,5 +151,4 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -5,7 +5,6 @@ from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
@@ -13,7 +12,6 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -30,9 +28,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
# pylint: disable=duplicate-code
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
@@ -64,5 +59,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -9,7 +9,6 @@ from typing import Union
import fire
import torch
import transformers
from dotenv import load_dotenv
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs
@@ -268,5 +267,4 @@ def do_cli(
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -4,12 +4,9 @@
import os
import subprocess # nosec B404
import tempfile
from pathlib import Path
from typing import Optional
from typing import Literal, Optional
import click
import yaml
from dotenv import load_dotenv
import axolotl
@@ -21,13 +18,14 @@ from axolotl.cli.args import (
VllmServeCliArgs,
)
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
filter_none_kwargs,
generate_config_files,
launch_training,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
@@ -36,12 +34,19 @@ from axolotl.utils.schemas.config import AxolotlInputConfig
LOG = get_logger(__name__)
LAUNCHER_COMMAND_MAPPING = {
"accelerate": ["accelerate", "launch"],
"torchrun": ["torchrun"],
}
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""
print_axolotl_text_art()
load_dotenv()
patch_optimized_env()
@cli.command()
@@ -50,7 +55,7 @@ def cli():
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
"""
Preprocess datasets before training.
@@ -60,7 +65,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
patch_optimized_env()
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
@@ -72,12 +76,15 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
do_cli(config=config, **kwargs)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@click.option(
@@ -88,126 +95,81 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
def train(
ctx: click.Context,
config: str,
accelerate: bool,
cloud: Optional[str] = None,
sweep: Optional[str] = None,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
cloud: str | None = None,
sweep: str | None = None,
**kwargs,
) -> None:
):
"""
Train or fine-tune a model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python").
cloud: Path to a cloud accelerator configuration file
sweep: Path to YAML config for sweeping hyperparameters.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
if sweep:
# load the sweep configuration yaml file
with open(sweep, "r", encoding="utf-8") as fin:
sweep_config: dict[str, list] = yaml.safe_load(fin)
with open(config, "r", encoding="utf-8") as fin:
base_config: dict[str, list] = yaml.safe_load(fin)
# Handle Ray launcher override
_launcher = None if kwargs.get("use_ray") else launcher
# generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
def iter_configs():
for perm in permutations:
# open temp directory for temporary configurations
with tempfile.TemporaryDirectory() as temp_dir:
with open(
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
) as fout:
yaml.dump(perm, fout)
yield str(Path(temp_dir) / "config.yaml")
else:
def iter_configs():
yield config
for cfg_file in iter_configs():
# handle errors from subprocess so we can continue rest of sweeps
# Process each configuration
for cfg_file in generate_config_files(config, sweep):
try:
if accelerate:
if cloud:
from axolotl.cli.cloud import do_cli_train
cwd = os.getcwd()
do_cli_train(
cloud_config=cloud,
config=config,
accelerate=True,
cwd=cwd,
**kwargs,
)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num_processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
from axolotl.cli.cloud import do_cli_train
do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)
else:
from axolotl.cli.train import do_cli
do_cli(config=cfg_file, **kwargs)
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
except subprocess.CalledProcessError as exc:
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc
finally:
# Only delete temp files, not the original config
if cfg_file != config:
os.unlink(cfg_file)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU evaluation",
)
@add_options_from_dataclass(EvaluateCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
@click.pass_context
def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
"""
Evaluate a model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python").
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.evaluate"]
)
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
@@ -218,30 +180,42 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
do_cli(config=config, **kwargs)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU inference",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU inference",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
@click.pass_context
def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):
"""
Run inference with a trained model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python").
gradio: Whether to use Gradio browser interface or command line for inference.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.inference"]
)
if config:
base_cmd.append(config)
if gradio:
@@ -254,33 +228,42 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
do_cli(config=config, gradio=gradio, **kwargs)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for weight merging",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for weight merging",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
@click.pass_context
def merge_sharded_fsdp_weights(
ctx: click.Context, config: str, launcher: str, **kwargs
):
"""
Merge sharded FSDP model weights.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python").
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = [
"accelerate",
"launch",
"-m",
"axolotl.cli.merge_sharded_fsdp_weights",
]
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.merge_sharded_fsdp_weights"]
)
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
@@ -296,7 +279,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_lora(config: str, **kwargs) -> None:
def merge_lora(config: str, **kwargs):
"""
Merge trained LoRA adapters into a base model.
@@ -313,7 +296,7 @@ def merge_lora(config: str, **kwargs) -> None:
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]) -> None:
def fetch(directory: str, dest: Optional[str]):
"""
Fetch example configs or other resources.
@@ -351,7 +334,7 @@ def quantize(config: str, **cli_args: QuantizeCliArgs):
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))
def delinearize_llama4(model: str, output: str) -> None:
def delinearize_llama4(model: str, output: str):
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
do_delinearize_llama4(model, output)
@@ -365,5 +348,4 @@ def main():
if __name__ == "__main__":
load_dotenv()
main()

View File

@@ -4,7 +4,6 @@ from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
@@ -70,7 +69,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
sequence_parallel_degree=None,
context_parallel_size=None,
deepspeed=None,
fsdp=None,
fsdp_config=None,
@@ -88,5 +87,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -17,7 +17,6 @@ from accelerate.utils import (
WEIGHTS_NAME,
is_torch_version,
)
from dotenv import load_dotenv
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
@@ -204,5 +203,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -9,7 +9,6 @@ import fire
import transformers
from accelerate import init_empty_weights
from colorama import Fore
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM
from axolotl.cli.args import PreprocessCliArgs
@@ -109,5 +108,4 @@ def do_cli(
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -7,7 +7,6 @@ from typing import Union
import fire
from accelerate import Accelerator
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
@@ -16,7 +15,6 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
@@ -31,9 +29,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
@@ -122,5 +117,4 @@ def ray_train_func(kwargs: dict):
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,330 +0,0 @@
"""Utility methods for axolotl CLI."""
import concurrent.futures
import dataclasses
import hashlib
import json
from functools import wraps
from pathlib import Path
from types import NoneType
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
import requests
from pydantic import BaseModel
from transformers import (
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
)
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def strip_optional_type(field_type: type | str | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.
Args:
field_type: Type of field for Axolotl CLI command.
Returns:
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
return field_type
def filter_none_kwargs(func: Callable) -> Callable:
"""
Wraps function to remove `None`-valued `kwargs`.
Args:
func: Function to wrap.
Returns:
Wrapped function.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
"""Filters out `None`-valued `kwargs`."""
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
return func(*args, **filtered_kwargs)
return wrapper
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
"""
Create Click options from the fields of a dataclass.
Args:
config_class: Dataclass with fields to parse from the CLI.
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = strip_optional_type(field.type)
if field_type == bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name,
default=field.default,
help=field.metadata.get("description"),
)(function)
else:
option_name = f"--{field.name.replace('_', '-')}"
function = click.option(
option_name,
type=field_type,
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
Args:
config_class: Pydantic model with fields to parse from the CLI
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
field_type = strip_optional_type(field.annotation)
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name, default=None, help=field.description
)(function)
else:
option_name = f"--{name.replace('_', '-')}"
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
"""
Build command list from base command and options.
Args:
base_cmd: Command without options.
options: Options to parse and append to base command.
Returns:
List of strings giving shell command.
"""
cmd = base_cmd.copy()
for key, value in options.items():
if value is None:
continue
key = key.replace("_", "-")
if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.extend([f"--{key}", str(value)])
return cmd
def download_file(
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> tuple[str, str]:
"""
Download a single file and return its processing status.
Args:
file_info: Tuple of (file_path, remote_sha).
raw_base_url: Base URL for raw GitHub content.
dest_path: Local destination directory.
dir_prefix: Directory prefix to filter files.
Returns:
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
"""
file_path, remote_sha = file_info
raw_url = f"{raw_base_url}/{file_path}"
dest_file = dest_path / file_path.split(dir_prefix)[-1]
# Check if file exists and needs updating
if dest_file.exists():
with open(dest_file, "rb") as file:
content = file.read()
# Calculate git blob SHA
blob = b"blob " + str(len(content)).encode() + b"\0" + content
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
if local_sha == remote_sha:
print(f"Skipping {file_path} (unchanged)")
return file_path, "unchanged"
print(f"Updating {file_path}")
status = "new"
else:
print(f"Downloading {file_path}")
status = "new"
# Create directories if needed
dest_file.parent.mkdir(parents=True, exist_ok=True)
# Download and save file
try:
response = requests.get(raw_url, timeout=30)
response.raise_for_status()
with open(dest_file, "wb") as file:
file.write(response.content)
return file_path, status
except (requests.RequestException, IOError) as request_error:
print(f"Error downloading {file_path}: {str(request_error)}")
return file_path, "error"
def fetch_from_github(
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None:
"""
Sync files from a specific directory in the GitHub repository.
Only downloads files that don't exist locally or have changed.
Args:
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
'deepspeed_configs/').
dest_dir: Local destination directory.
max_workers: Maximum number of concurrent downloads.
"""
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
# Get repository tree with timeout
response = requests.get(api_url, timeout=30)
response.raise_for_status()
tree = json.loads(response.text)
# Filter for files and get their SHA
files = {
item["path"]: item["sha"]
for item in tree["tree"]
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
}
if not files:
raise click.ClickException(f"No files found in {dir_prefix}")
# Default destination directory is the last part of dir_prefix
default_dest = Path(dir_prefix.rstrip("/"))
dest_path = Path(dest_dir) if dest_dir else default_dest
# Keep track of processed files for summary
files_processed: dict[str, list[str]] = {
"new": [],
"updated": [],
"unchanged": [],
"error": [],
}
# Process files in parallel using ThreadPoolExecutor
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_file = {
executor.submit(
download_file,
(file_path, remote_sha),
raw_base_url,
dest_path,
dir_prefix,
): file_path
for file_path, remote_sha in files.items()
}
# Process completed tasks as they finish
for future in concurrent.futures.as_completed(future_to_file):
file_path = future_to_file[future]
try:
file_path, status = future.result()
files_processed[status].append(file_path)
except (requests.RequestException, IOError) as request_error:
print(f"Error processing {file_path}: {str(request_error)}")
files_processed["error"].append(file_path)
# Log summary
LOG.info("\nSync Summary:")
LOG.info(f"New files: {len(files_processed['new'])}")
LOG.info(f"Updated files: {len(files_processed['updated'])}")
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
if files_processed["error"]:
LOG.info(f"Failed files: {len(files_processed['error'])}")
def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> tuple[
PreTrainedModel,
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
ProcessorMixin | None,
]:
"""
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
inference: Boolean denoting inference mode.
Returns:
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
model, _ = model_loader.load()
processor = None
if cfg.is_multimodal:
LOG.info("loading processor...")
processor = load_processor(cfg, tokenizer)
return model, tokenizer, processor

View File

@@ -0,0 +1,23 @@
"""Init for axolotl.cli.utils module."""
from .args import (
add_options_from_config,
add_options_from_dataclass,
filter_none_kwargs,
)
from .fetch import fetch_from_github
from .load import load_model_and_tokenizer
from .sweeps import generate_sweep_configs
from .train import build_command, generate_config_files, launch_training
__all__ = [
"filter_none_kwargs",
"add_options_from_dataclass",
"add_options_from_config",
"build_command",
"generate_config_files",
"generate_sweep_configs",
"load_model_and_tokenizer",
"launch_training",
"fetch_from_github",
]

View File

@@ -0,0 +1,120 @@
"""Utilities for axolotl CLI args."""
import dataclasses
from functools import wraps
from types import NoneType
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
from pydantic import BaseModel
def _strip_optional_type(field_type: type | str | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.
Args:
field_type: Type of field for Axolotl CLI command.
Returns:
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
return field_type
def filter_none_kwargs(func: Callable) -> Callable:
"""
Wraps function to remove `None`-valued `kwargs`.
Args:
func: Function to wrap.
Returns:
Wrapped function.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
"""Filters out `None`-valued `kwargs`."""
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
return func(*args, **filtered_kwargs)
return wrapper
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
"""
Create Click options from the fields of a dataclass.
Args:
config_class: Dataclass with fields to parse from the CLI.
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = _strip_optional_type(field.type)
if field_type == bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name,
default=field.default,
help=field.metadata.get("description"),
)(function)
else:
option_name = f"--{field.name.replace('_', '-')}"
function = click.option(
option_name,
type=field_type,
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
Args:
config_class: Pydantic model with fields to parse from the CLI
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation)
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name, default=None, help=field.description
)(function)
else:
option_name = f"--{name.replace('_', '-')}"
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator

View File

@@ -0,0 +1,142 @@
"""Utilities for axolotl fetch CLI command."""
import concurrent.futures
import hashlib
import json
from pathlib import Path
import click
import requests
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _download_file(
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> tuple[str, str]:
"""
Download a single file and return its processing status.
Args:
file_info: Tuple of (file_path, remote_sha).
raw_base_url: Base URL for raw GitHub content.
dest_path: Local destination directory.
dir_prefix: Directory prefix to filter files.
Returns:
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
"""
file_path, remote_sha = file_info
raw_url = f"{raw_base_url}/{file_path}"
dest_file = dest_path / file_path.split(dir_prefix)[-1]
# Check if file exists and needs updating
if dest_file.exists():
with open(dest_file, "rb") as file:
content = file.read()
# Calculate git blob SHA
blob = b"blob " + str(len(content)).encode() + b"\0" + content
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
if local_sha == remote_sha:
print(f"Skipping {file_path} (unchanged)")
return file_path, "unchanged"
print(f"Updating {file_path}")
status = "updated"
else:
print(f"Downloading {file_path}")
status = "new"
# Create directories if needed
dest_file.parent.mkdir(parents=True, exist_ok=True)
# Download and save file
try:
response = requests.get(raw_url, timeout=30)
response.raise_for_status()
with open(dest_file, "wb") as file:
file.write(response.content)
return file_path, status
except (requests.RequestException, IOError) as request_error:
print(f"Error downloading {file_path}: {str(request_error)}")
return file_path, "error"
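The local-SHA comparison works because Git hashes blobs as `sha1(b"blob <size>\0" + content)`, which is exactly the SHA the GitHub tree API returns. A quick standalone check:

```python
import hashlib

content = b"hello\n"
blob = b"blob " + str(len(content)).encode() + b"\0" + content
print(hashlib.sha1(blob).hexdigest())
# ce013625030ba8dba906f756967f9e9ca394464a -- matches `git hash-object` on "hello\n"
```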
def fetch_from_github(
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None:
"""
Sync files from a specific directory in the GitHub repository.
Only downloads files that don't exist locally or have changed.
Args:
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
'deepspeed_configs/').
dest_dir: Local destination directory.
max_workers: Maximum number of concurrent downloads.
"""
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
# Get repository tree with timeout
response = requests.get(api_url, timeout=30)
response.raise_for_status()
tree = json.loads(response.text)
# Filter for files and get their SHA
files = {
item["path"]: item["sha"]
for item in tree["tree"]
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
}
if not files:
raise click.ClickException(f"No files found in {dir_prefix}")
# Default destination directory is the last part of dir_prefix
default_dest = Path(dir_prefix.rstrip("/"))
dest_path = Path(dest_dir) if dest_dir else default_dest
# Keep track of processed files for summary
files_processed: dict[str, list[str]] = {
"new": [],
"updated": [],
"unchanged": [],
"error": [],
}
# Process files in parallel using ThreadPoolExecutor
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_file = {
executor.submit(
_download_file,
(file_path, remote_sha),
raw_base_url,
dest_path,
dir_prefix,
): file_path
for file_path, remote_sha in files.items()
}
# Process completed tasks as they finish
for future in concurrent.futures.as_completed(future_to_file):
file_path = future_to_file[future]
try:
file_path, status = future.result()
files_processed[status].append(file_path)
except (requests.RequestException, IOError) as request_error:
print(f"Error processing {file_path}: {str(request_error)}")
files_processed["error"].append(file_path)
# Log summary
LOG.info("\nSync Summary:")
LOG.info(f"New files: {len(files_processed['new'])}")
LOG.info(f"Updated files: {len(files_processed['updated'])}")
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
if files_processed["error"]:
LOG.info(f"Failed files: {len(files_processed['error'])}")

View File

@@ -0,0 +1,52 @@
"""Utilities for model, tokenizer, etc. loading."""
from typing import Any
from transformers import (
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
)
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> tuple[
PreTrainedModel,
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
ProcessorMixin | None,
]:
"""
Helper function for loading a model, tokenizer, and processor specified in the
given `axolotl` config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
inference: Boolean denoting inference mode.
Returns:
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin | None).
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
model, _ = model_loader.load()
processor = None
if cfg.is_multimodal:
LOG.info("loading processor...")
processor = load_processor(cfg, tokenizer)
return model, tokenizer, processor
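A hedged usage sketch; the YAML path is a placeholder, and a real config needs the usual Axolotl keys (`base_model`, etc.) and is normalized and validated first when going through the CLI:

```python
import yaml

from axolotl.utils.dict import DictDefault

with open("config.yaml", encoding="utf-8") as f:
    cfg = DictDefault(yaml.safe_load(f))

model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg, inference=True)
```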

View File

@@ -0,0 +1,188 @@
"""Utilities for axolotl train CLI command."""
import os
import subprocess # nosec
import tempfile
from typing import Any, Iterator, Literal
import yaml
from axolotl.cli.utils.sweeps import generate_sweep_configs
def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]:
"""
Add default RDZV arguments if rdzv_endpoint is set but rdzv_backend/rdzv_id are missing.
Args:
launcher_args: List of launcher arguments
Returns:
Updated launcher args with defaults added if needed
"""
args = launcher_args.copy()
# Check if rdzv_endpoint is present
has_rdzv_endpoint = any("--rdzv_endpoint" in arg for arg in args)
if has_rdzv_endpoint:
# Check if rdzv_backend is already provided
has_rdzv_backend = any("--rdzv_backend" in arg for arg in args)
if not has_rdzv_backend:
args.extend(["--rdzv_backend", "c10d"])
# Check if rdzv_id is already provided
has_rdzv_id = any("--rdzv_id" in arg for arg in args)
if not has_rdzv_id:
import uuid
args.extend(["--rdzv_id", str(uuid.uuid4())[:8]])
return args
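For instance, supplying only an endpoint gets a `c10d` backend and a random run id filled in (the id shown is an example value):

```python
args = _add_default_rdzv_args(["--rdzv_endpoint=host0:29500"])
# e.g. ["--rdzv_endpoint=host0:29500", "--rdzv_backend", "c10d", "--rdzv_id", "1f3a9b2c"]
```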
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
"""
Build command list from base command and options.
Args:
base_cmd: Command without options.
options: Options to parse and append to base command.
Returns:
List of strings giving shell command.
"""
cmd = base_cmd.copy()
for key, value in options.items():
if value is None:
continue
key = key.replace("_", "-")
cmd.append(f"--{key}={value}")
return cmd
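For example, underscores become dashes and `None` values are dropped:

```python
cmd = build_command(
    ["python", "-m", "axolotl.cli.train", "config.yaml"],
    {"learning_rate": 2e-5, "micro_batch_size": None, "num_epochs": 1},
)
# ['python', '-m', 'axolotl.cli.train', 'config.yaml',
#  '--learning-rate=2e-05', '--num-epochs=1']
```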
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
"""Generate list of configuration files to process."""
if not sweep:
yield config
return
# Load sweep and base configurations
with open(sweep, "r", encoding="utf-8") as fin:
sweep_config: dict[str, list] = yaml.safe_load(fin)
with open(config, "r", encoding="utf-8") as fin:
base_config: dict[str, list] = yaml.safe_load(fin)
# Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
for permutation in permutations:
# pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile(
mode="w",
suffix=".yaml",
delete=False,
encoding="utf-8",
)
yaml.dump(permutation, temp_file)
temp_file.close()
yield temp_file.name
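A hypothetical sweep file maps config keys to lists of candidate values; each permutation is merged into the base config, written to a temp YAML, and yielded (the exact sweep schema is defined by `generate_sweep_configs`):

```python
# sweep.yaml (illustrative):
#   learning_rate: [1.0e-5, 2.0e-5]
#   micro_batch_size: [1, 2]
for cfg_file in generate_config_files("config.yaml", "sweep.yaml"):
    print(cfg_file)  # path to a temp YAML with one permutation merged in
```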
def launch_training(
cfg_file: str,
launcher: Literal["accelerate", "torchrun", "python"] | None,
cloud: str | None,
kwargs: dict,
launcher_args: list[str] | None = None,
) -> None:
"""Execute training with the given configuration."""
launcher_args = launcher_args or []
if cloud:
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
elif launcher:
if launcher == "accelerate":
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
elif launcher == "torchrun":
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
elif launcher == "python":
_launch_python_training(cfg_file, kwargs)
def _launch_cloud_training(
cloud: str,
cfg_file: str,
launcher: Literal["accelerate", "torchrun", "python"] | None,
kwargs: dict,
launcher_args: list[str] | None = None,
) -> None:
"""Execute training via cloud launcher."""
from axolotl.cli.cloud import do_cli_train
launcher_args = launcher_args or []
cwd = os.getcwd() if launcher else None
do_cli_train(
cloud_config=cloud,
config=cfg_file,
launcher=launcher or "accelerate",
launcher_args=launcher_args,
cwd=cwd,
**kwargs,
)
def _launch_accelerate_training(
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
) -> None:
"""Execute training via accelerate launcher."""
launcher_args = launcher_args or []
internal_launcher_args = []
# Extract launcher-specific arguments from kwargs (legacy support)
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port")
internal_launcher_args.extend(["--main_process_port", str(main_process_port)])
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes")
internal_launcher_args.extend(["--num_processes", str(num_processes)])
# Combine internal args with user-provided launcher args
all_launcher_args = internal_launcher_args + launcher_args
base_cmd = (
["accelerate", "launch"] + all_launcher_args + ["-m", "axolotl.cli.train"]
)
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
def _launch_torchrun_training(
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
) -> None:
"""Execute training via torchrun launcher."""
launcher_args = launcher_args or []
# Add default RDZV arguments if rdzv_endpoint is set
launcher_args = _add_default_rdzv_args(launcher_args)
base_cmd = ["torchrun"] + launcher_args + ["-m", "axolotl.cli.train"]
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
"""Execute training via python launcher."""
from axolotl.cli.train import do_cli
do_cli(config=cfg_file, **kwargs)
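Putting the pieces together, a direct call like the following (arguments illustrative) resolves to a `torchrun` command with the RDZV defaults applied:

```python
launch_training(
    cfg_file="config.yaml",
    launcher="torchrun",
    cloud=None,
    kwargs={"learning_rate": 2e-5},
    launcher_args=["--nnodes", "2", "--rdzv_endpoint", "host0:29500"],
)
# -> torchrun --nnodes 2 --rdzv_endpoint host0:29500 --rdzv_backend c10d \
#      --rdzv_id <random> -m axolotl.cli.train config.yaml --learning-rate=2e-05
```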

View File

@@ -24,9 +24,11 @@ from pathlib import Path
from typing import Any
import torch
from accelerate import PartialState
from transformers import (
TrainerCallback,
)
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.training_args import OptimizerNames
from axolotl.integrations.base import PluginManager
@@ -34,7 +36,6 @@ from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveModelOnFirstStepCallback,
)
@@ -139,8 +140,6 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.save_first_step:
callbacks.append(SaveModelOnFirstStepCallback())
callbacks.append(GPUStatsCallback(cfg=self.cfg))
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
@@ -434,8 +433,30 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_accelerator_config(self, training_args_kwargs: dict):
partial_state = PartialState()
has_pc_attr = (
hasattr(partial_state, "parallelism_config")
and partial_state.parallelism_config
)
has_pc_key = (
"parallelism_config"
in partial_state._shared_state # pylint: disable=protected-access
and partial_state._shared_state[ # pylint: disable=protected-access
"parallelism_config"
]
)
use_configured_state = has_pc_attr or has_pc_key
if self.cfg.accelerator_config:
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
use_configured_state = self.cfg.accelerator_config.pop(
"use_configured_state", use_configured_state
)
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state, **self.cfg.accelerator_config
)
else:
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state,
)
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.activation_offloading is True:

View File

@@ -19,7 +19,6 @@ from axolotl.core.trainers import (
AxolotlPRMTrainer,
AxolotlRewardTrainer,
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
@@ -58,7 +57,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
if self.cfg.relora_steps:
if self.cfg.relora:
callbacks.append(ReLoRACallback(self.cfg))
if (
@@ -131,8 +130,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
if trainer_cls:
return trainer_cls
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
@@ -271,20 +268,25 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.sample_packing_eff_est
)
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = (
self.cfg.relora_warmup_steps
)
if self.cfg.relora_anneal_steps:
training_arguments_kwargs["relora_anneal_steps"] = (
self.cfg.relora_anneal_steps
)
if self.cfg.relora and self.cfg.jagged_restart_steps:
if self.cfg.relora_prune_ratio:
training_arguments_kwargs["relora_prune_ratio"] = (
self.cfg.relora_prune_ratio
)
if self.cfg.jagged_restart_steps:
training_arguments_kwargs["jagged_restart_steps"] = (
self.cfg.jagged_restart_steps
)
if self.cfg.jagged_restart_warmup_steps:
training_arguments_kwargs["jagged_restart_warmup_steps"] = (
self.cfg.jagged_restart_warmup_steps
)
if self.cfg.jagged_restart_anneal_steps:
training_arguments_kwargs["jagged_restart_anneal_steps"] = (
self.cfg.jagged_restart_anneal_steps
)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
training_arguments_kwargs["lisa_step_interval"] = (

View File

@@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
sequence_parallel=self.cfg.context_parallel_size > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))

View File

@@ -7,7 +7,6 @@ from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,

View File

@@ -27,6 +27,7 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
DistributedParallelMixin,
OptimizerMixin,
PackingMixin,
RngLoaderMixin,
@@ -37,6 +38,8 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_tagging,
)
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -50,6 +53,7 @@ class AxolotlTrainer(
RngLoaderMixin,
CheckpointSaveMixin,
ActivationOffloadingMixin,
DistributedParallelMixin,
Trainer,
):
"""Extend the base Trainer for axolotl helpers"""
@@ -558,6 +562,17 @@ class AxolotlTrainer(
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
if is_main_process():
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_memory_active"] = active
logs["memory/max_memory_allocated"] = allocated
logs["memory/device_memory_reserved"] = reserved
except (ValueError, FileNotFoundError):
pass
del self._stored_metrics[train_eval]
return super().log(logs, start_time)

View File

@@ -8,7 +8,11 @@ import torch
from torch import nn
from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
RngLoaderMixin,
SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
@@ -17,7 +21,12 @@ from axolotl.core.trainers.utils import (
class AxolotlDPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DPOTrainer,
DistributedParallelMixin,
):
"""Extend the base DPOTrainer for axolotl helpers."""

View File

@@ -49,7 +49,8 @@ class GRPOStrategy:
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode:
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode == "colocate":
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
vllm_cfg.gpu_memory_utilization
@@ -82,8 +83,13 @@ class GRPOStrategy:
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if cfg.sequence_parallel_degree > 1:
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
if cfg.context_parallel_size > 1:
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
if trl.importance_sampling_level is not None:
grpo_args_kwargs["importance_sampling_level"] = (
trl.importance_sampling_level
)
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights

View File

@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
sequence_parallel_degree: int | None = None
context_parallel_size: int | None = None

View File

@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
- Data is properly distributed across SP groups.
In the table below, the values represent dataset indices. Each SP group has
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
`context_parallel_size = 2` GPUs working together on the same data. There are 2
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
Sequence Parallel Groups
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: Rank of current process.
batch_size: Number of samples per batch.
repeat_count: How many times to repeat the full sampling process.
sequence_parallel_degree: Number of ranks in a sequence parallel group.
context_parallel_size: Number of ranks in a sequence parallel group.
shuffle: Whether to shuffle the dataset.
seed: Random seed for shuffling.
drop_last: Whether to drop the last incomplete batch.
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: int,
batch_size: int = 1,
repeat_count: int = 1,
sequence_parallel_degree: int = 1,
context_parallel_size: int = 1,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
self.rank = rank
# Sequence parallelism parameters
self.sequence_parallel_degree = sequence_parallel_degree
self.num_sp_groups = world_size // sequence_parallel_degree
self.sp_group_id = rank // sequence_parallel_degree
self.context_parallel_size = context_parallel_size
self.num_sp_groups = world_size // context_parallel_size
self.sp_group_id = rank // context_parallel_size
# Adjust dataset size for distributed sampling
self.num_samples = len(self.dataset)

View File

@@ -43,7 +43,11 @@ from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
RngLoaderMixin,
SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
@@ -53,7 +57,12 @@ if is_peft_available():
class AxolotlGRPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
GRPOTrainer,
):
"""Extend the base GRPOTrainer for axolotl helpers"""
@@ -100,7 +109,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Get number of SP groups (number of processes divided by SP degree)
num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree
num_sp_groups = num_processes // self.args.context_parallel_size
# Calculate batch size per SP group (not per process)
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
@@ -130,7 +139,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
if self.num_generations not in possible_values:
raise ValueError(
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
f"With sequence parallelism (degree {self.args.context_parallel_size}), "
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"must be evenly divisible by the number of generations per prompt "
f"({self.num_generations}). Given the current eval batch size, "
@@ -167,9 +176,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
rank=self.rank,
batch_size=effective_batch_size
// self.num_generations
// self.args.sequence_parallel_degree,
// self.args.context_parallel_size,
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
sequence_parallel_degree=self.args.sequence_parallel_degree,
context_parallel_size=self.args.context_parallel_size,
shuffle=True,
seed=self.args.seed,
drop_last=True,
@@ -235,7 +244,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
return dataloader
# Otherwise prepare with accelerator
@@ -308,18 +317,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate sequence parallel group information
world_size = self.accelerator.num_processes
sequence_parallel_degree = self.args.sequence_parallel_degree
num_sp_groups = world_size // sequence_parallel_degree
context_parallel_size = self.args.context_parallel_size
num_sp_groups = world_size // context_parallel_size
# Since processes in the same SP group have the same prompts, we need to ensure
# we only take one copy of each prompt from each SP group
ordered_set_of_prompts = []
for sp_group_id in range(num_sp_groups):
# Get the first process from each SP group (typically the group leader)
group_leader_rank = sp_group_id * sequence_parallel_degree
group_leader_rank = sp_group_id * context_parallel_size
# Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group
@@ -335,7 +344,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[
:: self.num_generations * self.args.sequence_parallel_degree
:: self.num_generations * self.args.context_parallel_size
]
with profiling_context(self, "vLLM.generate"):
@@ -352,14 +361,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
)
else:
completion_ids = [None] * (
len(all_prompts_text) // self.args.sequence_parallel_degree
len(all_prompts_text) // self.args.context_parallel_size
)
# Broadcast the completions from the main process to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0)
# Determine the appropriate slice based on sequence parallelism
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size
@@ -583,7 +592,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
advantages = advantages / (std_grouped_rewards + 1e-4)
# Slice to keep only the local part of the data
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size

View File

@@ -5,6 +5,7 @@ import torch
from axolotl.core.trainers.base import AxolotlTrainer
# pylint: disable=too-many-ancestors
class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation"""

View File

@@ -5,6 +5,7 @@
from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .distributed_parallel import DistributedParallelMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin
from .rng_state_loader import RngLoaderMixin

View File

@@ -13,9 +13,11 @@ class CheckpointSaveMixin(Trainer):
def _save_optimizer_and_scheduler(self, output_dir):
try:
super()._save_optimizer_and_scheduler(output_dir)
except NotImplementedError as exc:
LOG.warning(
except (NotImplementedError, KeyError) as exc:
# TODO: fix fsdp2 optimizer saving
LOG.warning_once(
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
"Optimizer and scheduler states were not saved - resuming from checkpoints "
"for this training run will not be possible."
"for this training run will not be possible.",
main_process_only=True,
)

View File

@@ -0,0 +1,20 @@
"""
Mixin for correctly saving models trained with FSDP
"""
from transformers import Trainer
class DistributedParallelMixin(Trainer):
"""
Mixin for correctly saving models trained with FSDP
"""
def _save(self, output_dir: str | None = None, state_dict=None):
if (
state_dict is None
and self.accelerator.parallelism_config
and self.accelerator.parallelism_config.dp_shard_enabled
):
state_dict = self.accelerator.get_state_dict(self.model)
super()._save(output_dir, state_dict=state_dict)

View File

@@ -7,6 +7,7 @@ from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.logging import get_logger
from axolotl.utils.schedulers import (
JaggedLRRestartScheduler,
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
@@ -113,7 +114,7 @@ class SchedulerMixin(Trainer):
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning(
@@ -123,4 +124,22 @@ class SchedulerMixin(Trainer):
LOG.warning(
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
if self.args.jagged_restart_steps:
warmup_steps = (
self.args.jagged_restart_warmup_steps or 10
)
anneal_steps = (
self.args.jagged_restart_anneal_steps or 1
)
if not self.lr_scheduler:
super().create_scheduler(num_training_steps, optimizer)
self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init
optimizer,
self.lr_scheduler,
self.args.jagged_restart_steps,
warmup_steps,
anneal_steps,
min_lr_scale=self.args.cosine_min_lr_ratio or 0.001,
)
return self.lr_scheduler # type: ignore
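For intuition, the per-step scale such a jagged restart applies on top of the inner schedule looks like the following sketch, mirroring the `ReLoRAScheduler.get_lr` logic it generalizes (removed further down in this diff):

```python
def jagged_scale(step, restart_steps, warmup_steps, anneal_steps, min_lr_scale):
    """Scale in [min_lr_scale, 1]: warm up after each reset, anneal before the next."""
    if step < restart_steps - warmup_steps:
        return 1.0
    t = step % restart_steps
    if t < warmup_steps:
        cycle_t = min(1.0, t / warmup_steps)
    elif t > restart_steps - anneal_steps:
        cycle_t = min(1.0, (restart_steps - t) / anneal_steps)
    else:
        cycle_t = 1.0
    return cycle_t * (1 - min_lr_scale) + min_lr_scale
```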

View File

@@ -1,46 +0,0 @@
"""Module for ReLoRA trainer"""
import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler
class ReLoRATrainer(AxolotlTrainer):
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: torch.optim.Optimizer | None = None,
) -> LRScheduler:
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler: LRScheduler = super().create_scheduler(
num_training_steps, optimizer
)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler( # type: ignore
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler # type: ignore
return self.lr_scheduler # type: ignore

View File

@@ -8,13 +8,18 @@ from trl import (
RewardTrainer,
)
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class AxolotlORPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
ORPOTrainer,
):
"""
Extend the base ORPOTrainer for axolotl helpers
@@ -24,7 +29,12 @@ class AxolotlORPOTrainer(
class AxolotlKTOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
KTOTrainer,
):
"""
Extend the base KTOTrainer for axolotl helpers
@@ -34,7 +44,12 @@ class AxolotlKTOTrainer(
class AxolotlCPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
CPOTrainer,
):
"""
Extend the base CPOTrainer for axolotl helpers
@@ -44,7 +59,12 @@ class AxolotlCPOTrainer(
class AxolotlRewardTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
RewardTrainer,
):
"""
Extend the base RewardTrainer for axolotl helpers
@@ -54,7 +74,12 @@ class AxolotlRewardTrainer(
class AxolotlPRMTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
PRMTrainer,
):
"""
Extend the base trl.PRMTrainer for axolotl helpers

View File

@@ -82,18 +82,26 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
jagged_restart_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for jagged restarts"},
)
jagged_restart_warmup_steps: Optional[int] = field(
default=None,
metadata={
"help": "how many warmup steps to take after reset for jagged restarts"
},
)
jagged_restart_anneal_steps: Optional[int] = field(
default=None,
metadata={
"help": "how many anneal steps to take before reset for jagged restarts"
},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
```
## Usage
@@ -41,6 +41,8 @@ plugins:
- gemma3n_text
- glm
- glm4
- granite
- granitemoe
- llama
- llama4
- llama4_text
@@ -56,6 +58,8 @@ plugins:
- qwen2_5_vl
- qwen3
- qwen3_moe
- smollm3
- voxtral
## Citation

View File

@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
)

View File

@@ -21,6 +21,7 @@ from axolotl.core.trainers.base import AxolotlTrainer
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
# pylint: disable=too-many-ancestors
class AxolotlKDTrainer(AxolotlTrainer):
"""
Custom trainer subclass for Knowledge Distillation (KD)

View File

@@ -16,8 +16,6 @@
Module for handling LIGER input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
@@ -30,13 +28,13 @@ class LigerArgs(BaseModel):
Input args for LIGER.
"""
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None
liger_rope: bool | None = None
liger_rms_norm: bool | None = None
liger_layer_norm: bool | None = None
liger_swiglu: bool | None = None
liger_glu_activation: bool | None = None
liger_cross_entropy: bool | None = None
liger_fused_linear_cross_entropy: bool | None = None
@model_validator(mode="before")
@classmethod
@@ -66,3 +64,20 @@ class LigerArgs(BaseModel):
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
)
return data
@model_validator(mode="before")
@classmethod
def check_liger_rms_norm_tensor_parallel(cls, data):
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1:
raise ValueError(
"`liger_rms_norm` is incompatible with tensor parallelism, "
"see https://github.com/linkedin/Liger-Kernel/issues/826"
)
return data
@model_validator(mode="after")
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
# TODO @SalmanMohammadi this is a larger fix - investigate
if self.tensor_parallel_size > 1 and self.liger_fused_linear_cross_entropy:
raise ValueError("Tensor parallelism is not compatible with liger losses.")
return self

View File

@@ -1,11 +0,0 @@
"""
Axolotl custom modeling module
"""
from .args import AxolotlModelingArgs
from .plugin import AxolotlModelingPlugin
__all__ = [
"AxolotlModelingArgs",
"AxolotlModelingPlugin",
]

View File

@@ -1,13 +0,0 @@
"""
Args for using Axolotl custom modeling
"""
from pydantic import BaseModel
class AxolotlModelingArgs(BaseModel):
"""
Args for using Axolotl custom modeling
"""
use_liger_fused_rms_add: bool = False

View File

@@ -1,9 +0,0 @@
"""
Gemma3 modeling
"""
from .modeling_gemma3 import patch_gemma3
__all__ = [
"patch_gemma3",
]

View File

@@ -1,110 +0,0 @@
"""
Gemma3 custom decoder layer using liger fused add rms norm kernels
"""
import sys
import torch
from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm
from transformers import Cache, GradientCheckpointingLayer
from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import (
Gemma3Attention,
Gemma3MLP,
Gemma3RMSNorm,
)
class Gemma3AddRMSNorm(LigerFusedAddRMSNorm):
"""
Fused add rms norm
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__(hidden_size, eps, offset=1.0, casting_mode="gemma")
class Gemma3DecoderLayer(GradientCheckpointingLayer):
"""
Gemma3 decoder layer using liger fused add rms norm
"""
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
self.attention_type = config.layer_types[layer_idx]
self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
self.mlp = Gemma3MLP(config)
self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma3RMSNorm(
self.hidden_size, eps=config.rms_norm_eps
)
self.pre_feedforward_layernorm = Gemma3AddRMSNorm(
self.hidden_size, eps=config.rms_norm_eps
)
self.post_feedforward_layernorm = Gemma3RMSNorm(
self.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings_global: torch.Tensor,
position_embeddings_local: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool | None = False,
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
**kwargs,
) -> tuple[
torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor | None] | None
]:
# pylint: disable=duplicate-code
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# apply global RoPE to non-sliding layer only
if self.self_attn.is_sliding:
position_embeddings = position_embeddings_local
else:
position_embeddings = position_embeddings_global
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.pre_feedforward_layernorm(
hidden_states, residual
)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,) # type: ignore
return outputs # type: ignore
def patch_gemma3():
import transformers.models.gemma3.modeling_gemma3
transformers.models.gemma3.modeling_gemma3.Gemma3DecoderLayer = Gemma3DecoderLayer
sys.modules["transformers.models.gemma3.modeling_gemma3"].Gemma3DecoderLayer = (
Gemma3DecoderLayer
)

View File

@@ -1,9 +0,0 @@
"""
Llama modeling
"""
from .modeling_llama import patch_llama
__all__ = [
"patch_llama",
]

View File

@@ -1,86 +0,0 @@
"""
Custom modeling for Llama for fused rms add kernels
"""
import sys
import torch
from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm
from transformers import Cache, GradientCheckpointingLayer
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaMLP,
LlamaRMSNorm,
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs
class LlamaAddRMSNorm(LigerFusedAddRMSNorm):
"""
Fused add rms norm
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__(hidden_size, eps, casting_mode="llama")
class LlamaDecoderLayer(GradientCheckpointingLayer):
"""
Llama decoder layer using liger fused add rms norm
"""
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaAddRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: (
tuple[torch.Tensor, torch.Tensor] | None
) = None, # necessary, but kept here for BC
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
# pylint: disable=duplicate-code
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
def patch_llama():
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
sys.modules["transformers.models.llama.modeling_llama"].LlamaDecoderLayer = (
LlamaDecoderLayer
)

View File

@@ -1,22 +0,0 @@
"""
Axolotl custom modeling plugin
"""
from axolotl.integrations.base import BasePlugin
class AxolotlModelingPlugin(BasePlugin):
"""
Axolotl custom modeling plugin
"""
def get_input_args(self) -> str | None:
return "axolotl.integrations.modeling.AxolotlModelingArgs"
def register(self, cfg): # pylint: disable=unused-argument
if cfg.use_liger_fused_rms_add:
from .gemma3 import patch_gemma3
from .llama import patch_llama
patch_gemma3()
patch_llama()

View File

@@ -13,7 +13,8 @@ import peft
import torch
import transformers
import transformers.modeling_utils
from accelerate import init_empty_weights
from accelerate import PartialState, init_empty_weights
from accelerate.parallelism_config import ParallelismConfig
from peft import (
PeftConfig,
PeftMixedModel,
@@ -48,10 +49,7 @@ from axolotl.loaders.utils import (
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
get_device_count,
get_device_type,
)
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
@@ -87,6 +85,9 @@ class ModelLoader:
`AutoModelForCausalLM`).
"""
use_parallel_config: bool | None = False
parallelism_config: ParallelismConfig | None = None
def __init__(
self,
cfg: DictDefault,
@@ -183,6 +184,20 @@ class ModelLoader:
def _apply_pre_model_load_setup(self):
"""Apply patches and setup configurations before model loading."""
if self.use_parallel_config is not None:
self.use_parallel_config = (
self.cfg.fsdp_config
or (self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1)
or (
self.cfg.context_parallel_size
and self.cfg.context_parallel_size > 1
)
)
if self.cfg.fsdp_config and self.cfg.fsdp_version != 2:
self.use_parallel_config = False
if self.use_parallel_config:
self._set_parallel_config()
self._set_auto_model_loader()
self._set_device_map_config()
if self.cfg.revision_of_model:
@@ -390,6 +405,86 @@ class ModelLoader:
gc.collect()
torch.cuda.empty_cache()
@staticmethod
def _get_parallel_config_kwargs(
world_size: int,
tensor_parallel_size: int = 1,
context_parallel_size: int = 1,
dp_shard_size: int | None = None,
dp_replicate_size: int | None = None,
is_fsdp: bool = False,
):
pc_kwargs = {}
remaining_world_size = world_size
if tensor_parallel_size and tensor_parallel_size > 1:
pc_kwargs["tp_size"] = tensor_parallel_size
remaining_world_size = remaining_world_size // tensor_parallel_size
if context_parallel_size and context_parallel_size > 1:
pc_kwargs["cp_size"] = context_parallel_size
remaining_world_size = remaining_world_size // context_parallel_size
if dp_shard_size is None and dp_replicate_size in (None, 1):
if remaining_world_size > 1:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if dp_replicate_size and dp_replicate_size > 1:
pc_kwargs["dp_replicate_size"] = dp_replicate_size
remaining_world_size = remaining_world_size // dp_replicate_size
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
if not is_fsdp:
raise ValueError(
"dp_shard_size was configured without a corresponding fsdp_config! "
"Please ensure you have configured FSDP using fsdp_config."
)
pc_kwargs["dp_shard_size"] = dp_shard_size
remaining_world_size = remaining_world_size // dp_shard_size
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
pc_kwargs["dp_replicate_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
if "dp_shard_size" not in pc_kwargs and is_fsdp:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
raise ValueError(
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
f"{pc_kwargs}"
)
return pc_kwargs
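Worked example: on 8 GPUs, TP=2 and CP=2 leave a factor of 2 in the world size, which is assigned to FSDP sharding:

```python
ModelLoader._get_parallel_config_kwargs(
    world_size=8,
    tensor_parallel_size=2,
    context_parallel_size=2,
    is_fsdp=True,
)
# -> {"tp_size": 2, "cp_size": 2, "dp_shard_size": 2}
```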
def _set_parallel_config(self):
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
get_world_size(),
self.cfg.tensor_parallel_size,
self.cfg.context_parallel_size,
self.cfg.dp_shard_size,
self.cfg.dp_replicate_size,
bool(self.cfg.fsdp or self.cfg.fsdp_config),
)
if pc_kwargs:
self.parallelism_config = ParallelismConfig(
**pc_kwargs,
)
device_mesh = self.parallelism_config.build_device_mesh("cuda")
partial_state = PartialState()
# fmt: off
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
self.parallelism_config
)
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
device_mesh
)
# fmt: on
def _set_auto_model_loader(self):
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
@@ -622,6 +717,14 @@ class ModelLoader:
def _build_model(self) -> bool:
"""Load model, with load strategy depending on config."""
skip_move_to_device = False
if self.cfg.tensor_parallel_size > 1:
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
self.model_kwargs["tp_plan"] = "auto"
self.model_kwargs["device_mesh"] = PartialState().device_mesh
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True
@@ -734,6 +837,14 @@ class ModelLoader:
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
# pylint: disable=protected-access
if self.cfg.tensor_parallel_size > 1:
# workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
# TODO(wing): remove once 4.54.1 is released
if self.model._tp_size != self.cfg.tensor_parallel_size:
self.model._tp_size = self.cfg.tensor_parallel_size
self.model._device_mesh = self.model_kwargs["device_mesh"]
return skip_move_to_device
def _set_z3_leaf_modules(self):

View File

@@ -49,6 +49,7 @@ class PatchManager:
def apply_pre_model_load_patches(self):
"""Apply pre-model load patches based on config."""
self._apply_transformers_patches()
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
@@ -64,13 +65,19 @@ class PatchManager:
self._patch_llama_derived_model()
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_sequence_parallel_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_voxtral_patches()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
patch_prepare_from_posids,
)
patch_prepare_from_posids()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
@@ -253,17 +260,6 @@ class PatchManager:
has_remote_code=has_remote_code,
)
def _apply_sequence_parallel_patches(self):
"""Apply sequence parallelism patches."""
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.ring_attn.patch import (
patch_prepare_data_loader,
patch_prepare_device_mesh,
)
patch_prepare_data_loader()
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import (

View File

@@ -131,6 +131,17 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
f"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`."
)
if (
cfg.tensor_parallel_size
and cfg.tensor_parallel_size > 1
and hasattr(model_config, "tie_word_embeddings")
and model_config.tie_word_embeddings
):
raise ValueError(
"Tensor parallelism is incompatible with models configured with `tie_word_embeddings` enabled. "
"Please use a model without `tie_word_embeddings`, or disable tensor parallelism."
)
def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
"""Loads and configures a model configuration from HuggingFace or local sources.

View File

@@ -249,13 +249,19 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
)
mesh = getattr(accelerator.state, "device_mesh", None)
fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload,
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
"mesh": (
mesh[tuple(accelerator.state.parallelism_config.fsdp_dim_names)]
if mesh is not None
else None
),
}
model_has_params4bit = False
for _, param in model.named_parameters():
# this is a temporary fix whereby loading models with bnb params cannot be moved from

View File

@@ -36,6 +36,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"glm",
"glm4",
"smollm3",
"granite",
"granitemoe",
]

View File

@@ -6,7 +6,7 @@ import os.path
import shutil
from functools import partial
from pathlib import Path
from typing import Dict, List, Sequence, Union
from typing import Dict, List, Union
import bitsandbytes as bnb
import peft
@@ -14,8 +14,6 @@ import safetensors.torch as st
import torch
from huggingface_hub import snapshot_download
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
TrainerCallback,
TrainerControl,
@@ -84,7 +82,7 @@ class ReLoRACallback(TrainerCallback):
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
def __init__(self, cfg: DictDefault):
self.relora_steps = cfg.relora_steps
self.relora_steps = cfg.jagged_restart_steps
self.cpu_offload = cfg.relora_cpu_offload
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
self.last_full_model = cfg.base_model
@@ -255,51 +253,6 @@ class ReLoRACallback(TrainerCallback):
return control
class ReLoRAScheduler(LRScheduler):
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
def __init__(
self,
optimizer: Optimizer,
inner_schedule: LRScheduler,
relora_steps: int,
warmup_steps: int,
anneal_steps: int = 1,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.relora_steps = relora_steps
self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps - self.warmup_steps:
scale = 1
else:
per_relora_progress = step % self.relora_steps
if per_relora_progress < self.warmup_steps:
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
elif per_relora_progress > (self.relora_steps - self.anneal_steps):
cycle_t = min(
1.0,
(self.relora_steps - per_relora_progress) / self.anneal_steps,
)
else:
cycle_t = 1
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
return [lr * scale for lr in original]
return original * scale
def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
model_name = "model.safetensors"
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(

View File

@@ -5,18 +5,14 @@
from .patch import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
register_ring_attn_from_device_mesh,
set_ring_attn_group,
update_ring_attn_params,
)
__all__ = (
"get_ring_attn_group",
"patch_prepare_data_loader",
"patch_prepare_device_mesh",
"register_ring_attn",
"register_ring_attn_from_device_mesh",
"set_ring_attn_group",
"update_ring_attn_params",
)

View File

@@ -8,13 +8,12 @@ We also provide some patches for accelerate functions to prepare the dataloader
sequence parallelism training.
"""
import inspect
import os
from typing import Optional
import accelerate
import torch
import torch.distributed as dist
from torch.distributed import DeviceMesh
try:
from transformers.modeling_flash_attention_utils import _flash_supports_window
@@ -29,39 +28,13 @@ from axolotl.utils.schemas.enums import RingAttnFunc
LOG = get_logger(__name__)
RING_ATTN_GROUP = None
ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // submesh_tp_size"""
NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
submesh_cp_size = 1
if "cp" in torch_device_mesh.mesh_dim_names:
submesh_cp_size = torch_device_mesh["cp"].size()
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // (submesh_tp_size * submesh_cp_size)"""
def get_ring_attn_group() -> dist.ProcessGroup:
"""Getter for ring attention group on this rank."""
if RING_ATTN_GROUP is None:
raise RuntimeError("register_ring_attn() not yet called")
raise RuntimeError("register_ring_attn_from_device_mesh() not yet called")
return RING_ATTN_GROUP
@@ -161,15 +134,17 @@ def create_ring_flash_attention_forward(
]
def register_ring_attn(
sequence_parallel_degree: int,
def register_ring_attn_from_device_mesh(
device_mesh: "DeviceMesh",
context_parallel_dim: tuple[str, ...],
heads_k_stride: int | None,
ring_attn_func: RingAttnFunc | None,
):
"""Create ring attention group and substitute flash attn with ring flash attn.
"""Create ring attention group using DeviceMesh and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
device_mesh: DeviceMesh object containing the parallelism topology.
context_parallel_dim: Name of the sequence parallel dimension in the device mesh.
heads_k_stride: Sequence parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation.
ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
@@ -177,44 +152,39 @@ def register_ring_attn(
`batch` function.
"""
rank = dist.get_rank()
world_size = dist.get_world_size()
LOG.info(
f"Enabling ring attention sequence parallelism using DeviceMesh "
f"dimension '{context_parallel_dim}'",
main_process_only=True,
)
# Extract the sequence parallel submesh
try:
sequence_mesh = device_mesh[context_parallel_dim]
except (KeyError, IndexError) as e:
raise ValueError(
f"Dimension '{context_parallel_dim}' not found in device_mesh. "
f"Available dimensions: {device_mesh.mesh_dim_names}"
) from e
# Get the process group for context parallelism
sequence_pg = sequence_mesh.get_group()
context_parallel_size = sequence_mesh.size()
if rank == 0:
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
f"Sequence parallel degree: {context_parallel_size}, "
f"mesh shape: {sequence_mesh.mesh.shape}"
)
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
)
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"
)
# Log which ranks are in the current process group
if sequence_pg != dist.GroupMember.WORLD:
ranks_in_group = dist.get_process_group_ranks(sequence_pg)
LOG.info(f"Current sequence parallel group ranks: {ranks_in_group}")
# Assign ranks to sequence parallel groups
group_assignments = {}
for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list(
range(
i * sequence_parallel_degree,
(i + 1) * sequence_parallel_degree,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
# Track which GPUs are in which groups
for r in ring_attn_ranks:
group_assignments[r] = i
if rank in ring_attn_ranks:
set_ring_attn_group(group)
# Log the GPU group assignments
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
# Set the ring attention group
set_ring_attn_group(sequence_pg)
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
# fmt: off
@@ -257,92 +227,3 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
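For reference, a minimal sketch of the DeviceMesh pattern that register_ring_attn_from_device_mesh relies on (assumes a 4-rank torchrun launch; init_device_mesh is the stock PyTorch helper, not part of this patch):
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 4 ranks arranged as 2 data-parallel groups x 2 context-parallel ranks
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "cp"))
cp_submesh = mesh["cp"]            # submesh along the "cp" dimension
cp_group = cp_submesh.get_group()  # process group handed to set_ring_attn_group
print(cp_submesh.size(), dist.get_process_group_ranks(cp_group))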
def patch_prepare_data_loader():
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
Raises:
RuntimeError: If source code to patch does not exist.
"""
original_fn = accelerate.data_loader.prepare_data_loader
original_source = inspect.getsource(original_fn)
if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source:
raise RuntimeError(
"SP patch failed - target snippet not found. "
"Check accelerate's version or update the patch."
)
patched_source = original_source.replace(
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
)
items_to_import = []
for item in dir(accelerate.data_loader):
if item in patched_source:
items_to_import.append(item)
# Create a new function from the patched source
namespace = {}
exec( # pylint: disable=exec-used # nosec B102
f"from accelerate.data_loader import ({', '.join(items_to_import)})",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
patched_source, globals(), namespace
)
patched_function = namespace["prepare_data_loader"]
original_fn.__code__ = patched_function.__code__
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree.
Args:
sequence_parallel_degree: The degree of sequence parallelism to use.
fsdp: Whether to use FSDP.
"""
def _prepare_device_mesh(self):
"""Prepare the device mesh for distributed training. The dataloader will
determine how to load data based on the device mesh.
"""
if self.state.torch_tp_plugin:
return self.state.torch_tp_plugin.torch_device_mesh
if (
self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED
and hasattr(self.state, "ds_device_mesh")
):
return self.state.ds_device_mesh
# Create device mesh with sequence parallelism
world_size = dist.get_world_size()
mesh_shape = (
world_size // sequence_parallel_degree,
sequence_parallel_degree,
)
device_ids = list(range(world_size))
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming.
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
# only use "fsdp" and "cp" for the device mesh.
return dist.DeviceMesh(
"cuda",
torch.tensor(device_ids).reshape(mesh_shape),
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
)
# Replace the original method with our new method
# pylint: disable=protected-access
accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh
LOG.info(
"Successfully patched Accelerator._prepare_device_mesh "
f"with sequence_parallel_degree={sequence_parallel_degree}"
)

View File

@@ -0,0 +1,87 @@
"""
Monkey patch to fix transformers.modeling_flash_attention_utils.
See https://github.com/huggingface/transformers/pull/39653/files
"""
import sys
import torch
def _prepare_from_posids(query, key, value, position_ids):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
Cumulative lengths of each example in the batch will be extracted from position_ids.
NOTE: ideally cumulative lengths should be prepared at the data collator stage
Arguments:
query (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
position_ids (`torch.Tensor`):
Int tensor of shape (batch_size, sequence_length); position indices that restart at 0 at the start of each packed sample.
Return:
query (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
indices_q (`torch.Tensor`):
The indices of non-masked tokens from the flattened input target sequence.
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(
position_ids.size(), device=position_ids.device, dtype=torch.int32
),
)
)
# NOTE: With torch compile, this will cause a graph break if you don't set
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
# for some models (e.g. qwen2-vl).
max_length = cu_seq_lens.diff().max().item()
return (
query,
key,
value,
indices_q,
(cu_seq_lens, cu_seq_lens),
(max_length, max_length),
)
def patch_prepare_from_posids():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access
_prepare_from_posids
)
setattr(
sys.modules["transformers.modeling_flash_attention_utils"],
"_prepare_from_posids",
_prepare_from_posids,
)
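A small worked example of the cu_seq_lens construction above, assuming two packed samples of lengths 3 and 2:
import torch

position_ids = torch.tensor([[0, 1, 2, 0, 1]])
flat = position_ids.flatten()
indices_q = torch.arange(flat.size(0), dtype=torch.int32)
cu_seq_lens = torch.cat(
    (indices_q[flat == 0], torch.tensor(flat.size(), dtype=torch.int32))
)
assert cu_seq_lens.tolist() == [0, 3, 5]     # boundaries of the two samples
assert cu_seq_lens.diff().max().item() == 3  # max_length passed to flash attention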

View File

@@ -205,7 +205,7 @@ def execute_training(
)
)
if cfg.sequence_parallel_degree > 1:
if cfg.context_parallel_size > 1:
models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model)
@@ -213,7 +213,7 @@ def execute_training(
stack.enter_context(
SequenceParallelContextManager(
models=models,
sequence_parallel_degree=cfg.sequence_parallel_degree,
context_parallel_size=cfg.context_parallel_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
@@ -267,7 +267,7 @@ def save_trained_model(
"your model weights with `axolotl quantize`."
)
# Handle ReLoRA early return case
if cfg.relora_steps:
if cfg.relora:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:

View File

@@ -57,10 +57,10 @@ def gpu_memory_usage(device=0):
@check_cuda_device((0.0, 0.0, 0.0))
def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
smi = gpu_memory_usage_smi(device)
return usage, reserved - usage, max(0, smi - reserved)
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
return active, allocated, reserved
def mps_memory_usage_all():
@@ -92,27 +92,38 @@ def gpu_memory_usage_smi(device=0):
return 0.0
def log_gpu_memory_usage(
log: logging.Logger | logging.LoggerAdapter,
msg: str = "",
device: int | torch.device = 0,
):
def get_gpu_memory_usage(device: int | torch.device = 0):
cur_device_type = str(get_device_type())
if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all()
elif "npu" in cur_device_type and is_torch_npu_available():
usage, cache, misc = npu_memory_usage_all(device)
elif "gpu" in cur_device_type and torch.cuda.is_available():
elif "cuda" in cur_device_type and torch.cuda.is_available():
usage, cache, misc = gpu_memory_usage_all(device)
else:
return 0.0, 0.0, 0.0
return usage, cache, misc
def log_gpu_memory_usage(
log: logging.Logger | logging.LoggerAdapter,
msg: str = "",
device: int | torch.device = 0,
):
try:
active, allocated, reserved = get_gpu_memory_usage(device)
except ValueError:
# likely CPU, ignore
return
cur_device_type = str(get_device_type())
extras = []
if cache > 0:
extras.append(f"+{cache:.03f}GB cache")
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
msg = f"{cur_device_type} memory usage:" if not msg else msg
log.info(
f"{msg} {usage:.03f}GB ({', '.join(extras)})",
if allocated > 0:
extras.append(f"+{allocated:.03f}GB allocated")
if reserved > 0:
extras.append(f"+{reserved:.03f}GB reserved")
msg = f"{cur_device_type} memory active:" if not msg else msg
log.debug(
f"{msg} {active:.03f}GB ({', '.join(extras)})",
stacklevel=2,
)
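Usage sketch for the reworked helpers (module path as imported elsewhere in this diff; note the message now logs at DEBUG level):
import logging

from axolotl.utils.bench import get_gpu_memory_usage, log_gpu_memory_usage

LOG = logging.getLogger(__name__)
active, allocated, reserved = get_gpu_memory_usage(device=0)  # peak GB values
log_gpu_memory_usage(LOG, msg="after model load", device=0)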

View File

@@ -35,7 +35,6 @@ from transformers.trainer_utils import (
from trl.models import unwrap_model_for_generation
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.distributed import (
barrier,
@@ -93,28 +92,6 @@ class SaveBetterTransformerModelCallback(
return control
class GPUStatsCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods disable=unused-argument
"""Callback to track GPU utilization"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
def on_step_end(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> TrainerControl:
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
return control
class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""

View File

@@ -0,0 +1,62 @@
{# Alias tools -> available_tools #}
{%- if tools and not available_tools -%}
{%- set available_tools = tools -%}
{%- endif -%}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = "Knowledge Cutoff Date: April 2024.
Today's Date: " + strftime_now('%B %d, %Y') + ".
You are Granite, developed by IBM." %}
{%- if available_tools and documents %}
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.
Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif available_tools %}
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
{%- elif documents %}
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif thinking %}
{%- set system_message = system_message + " You are a helpful AI assistant.
Respond to every user query in a comprehensive and detailed way. You can write down your thoughts and reasoning process before responding. In the thought process, engage in a comprehensive cycle of analysis, summarization, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. In the response section, based on various attempts, explorations, and reflections from the thoughts section, systematically present the final solution that you deem correct. The response should summarize the thought process. Write your thoughts between <think></think> and write your response between <response></response> for each user query." %}
{%- else %}
{%- set system_message = system_message + " You are a helpful AI assistant." %}
{%- endif %}
{%- if 'citations' in controls and documents %}
{%- set system_message = system_message + '
Use the symbols <|start_of_cite|> and <|end_of_cite|> to indicate when a fact comes from a document in the search result, e.g <|start_of_cite|> {document_id: 1}my fact <|end_of_cite|> for a fact from document 1. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
{%- endif %}
{%- if 'hallucinations' in controls and documents %}
{%- set system_message = system_message + '
Finally, after the response is written, include a numbered list of sentences from the response with a corresponding risk value that are hallucinated and not based in the documents.' %}
{%- endif %}
{%- set loop_messages = messages %}
{%- endif %}
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
' }}
{%- if available_tools %}
{{- '<|start_of_role|>available_tools<|end_of_role|>' }}
{{- available_tools | tojson(indent=4) }}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- if documents %}
{%- for document in documents %}
{{- '<|start_of_role|>document {"document_id": "' + document['doc_id'] | string + '"}<|end_of_role|>
' }}
{{- document['text'] }}
{{- '<|end_of_text|>
' }}
{%- endfor %}
{%- endif %}
{%- for message in loop_messages %}
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- if loop.last and add_generation_prompt %}
{{- '<|start_of_role|>assistant' }}
{%- if controls %}
{{- ' ' + controls | tojson()}}
{%- endif %}
{{- '<|end_of_role|>' }}
{%- endif %}
{%- endfor %}
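A sketch of rendering a template like this through a transformers tokenizer (checkpoint name illustrative; any tokenizer carrying this chat template behaves the same way):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct")
messages = [{"role": "user", "content": "What is the capital of France?"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# Produces the <|start_of_role|>system<|end_of_role|>... preamble followed by
# the user turn and a trailing <|start_of_role|>assistant<|end_of_role|> prompt.
print(prompt)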

View File

@@ -0,0 +1,64 @@
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = "Knowledge Cutoff Date: April 2024.
Today's Date: " + strftime_now('%B %d, %Y') + ".
You are Granite, developed by IBM." %}
{%- if tools and documents %}
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.
Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif tools %}
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
{%- elif documents %}
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- else %}
{%- set system_message = system_message + " You are a helpful AI assistant." %}
{%- endif %}
{%- if 'citations' in controls and documents %}
{%- set system_message = system_message + '
In your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
{%- endif %}
{%- if 'hallucinations' in controls and documents %}
{%- set system_message = system_message + '
Finally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents.' %}
{%- endif %}
{%- set loop_messages = messages %}
{%- endif %}
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
' }}
{%- if tools %}
{{- '<|start_of_role|>tools<|end_of_role|>' }}
{{- tools | tojson(indent=4) }}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- if documents %}
{{- '<|start_of_role|>documents<|end_of_role|>' }}
{%- for document in documents %}
{{- 'Document ' + loop.index0 | string + '
' }}
{{- document['text'] }}
{%- if not loop.last %}
{{- '
'}}
{%- endif%}
{%- endfor %}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- for message in loop_messages %}
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- if loop.last and add_generation_prompt %}
{{- '<|start_of_role|>assistant' }}
{%- if controls %}
{{- ' ' + controls | tojson()}}
{%- endif %}
{{- '<|end_of_role|>' }}
{%- endif %}
{%- endfor %}

View File

@@ -5,6 +5,7 @@ import inspect
import torch
import torch.distributed as dist
from accelerate import PartialState
from torch import nn
from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -12,7 +13,7 @@ from transformers.utils import ModelOutput
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
register_ring_attn,
register_ring_attn_from_device_mesh,
update_ring_attn_params,
)
from axolotl.utils.schemas.enums import RingAttnFunc
@@ -150,9 +151,18 @@ def apply_sequence_parallelism(
if "num_items_in_batch" in batch:
# Approximation; this needed since num_items_in_batch may be counted across
# all samples in a gradient accumulated batch, not on a per-step basis.
local_valid_tokens = (batch["labels"] != -100).sum()
# All-reduce across sequence parallel ranks to get global token count
cp_group = get_ring_attn_group()
global_valid_tokens = local_valid_tokens.clone()
# Use AVG instead of SUM: SUM appears to over-count tokens across ranks, which scales down the loss
dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.AVG, group=cp_group)
global_valid_tokens = int(global_valid_tokens.item())
batch["num_items_in_batch"] = (
batch["labels"] != -100
).sum() * gradient_accumulation_steps
global_valid_tokens * gradient_accumulation_steps
)
return batch, original_seq_len, pad_len
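A minimal sketch of the AVG all-reduce used above (assumes an initialized process group with 2 context-parallel ranks holding 10 and 12 valid tokens; a float tensor is used because NCCL's AVG op divides by group size):
import torch
import torch.distributed as dist

local_valid_tokens = torch.tensor(10.0, device="cuda")  # 12.0 on the peer rank
dist.all_reduce(local_valid_tokens, op=dist.ReduceOp.AVG)
# Both ranks now hold 11.0; SUM would have produced 22.0 and over-counted
# tokens that the loss normalization already accounts for once per rank.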
@@ -167,7 +177,7 @@ class SequenceParallelContextManager:
Args:
models: List of models to apply sequence parallelism to pre- and post- forward
hooks.
sequence_parallel_degree: Number of processes to split sequences over.
context_parallel_size: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused.
heads_k_stride: Sequence parallelism K head stride size. Passed through to
@@ -179,14 +189,14 @@ class SequenceParallelContextManager:
def __init__(
self,
models: list[nn.Module],
sequence_parallel_degree: int,
context_parallel_size: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc,
heads_k_stride: int | None,
gather_outputs: bool,
):
self.models = models
self.sequence_parallel_degree = sequence_parallel_degree
self.context_parallel_size = context_parallel_size
self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride
@@ -230,8 +240,10 @@ class SequenceParallelContextManager:
def _register_ring_attn(self):
# Initialize ring attn for sequence parallelism
register_ring_attn(
sequence_parallel_degree=self.sequence_parallel_degree,
partial_state = PartialState()
register_ring_attn_from_device_mesh(
device_mesh=partial_state.device_mesh,
context_parallel_dim=("cp",),
heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func,
)

View File

@@ -430,10 +430,11 @@ def save_preprocessed_dataset(
num_shards=cfg.num_dataset_shards_to_save,
)
else:
min_rows_per_proc = 256
os.makedirs(prepared_ds_path, exist_ok=True)
dataset.save_to_disk(
str(prepared_ds_path),
num_proc=min(max(1, len(dataset) // 8), num_workers),
num_proc=min(max(1, len(dataset) // min_rows_per_proc), num_workers),
max_shard_size=None,
num_shards=cfg.num_dataset_shards_to_save,
)
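A worked example of the new num_proc heuristic (values illustrative):
min_rows_per_proc, num_workers = 256, 16
for n_rows in (100, 4_096, 1_000_000):
    num_proc = min(max(1, n_rows // min_rows_per_proc), num_workers)
    print(n_rows, num_proc)  # 100 -> 1, 4096 -> 16, 1000000 -> 16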

View File

@@ -2,12 +2,15 @@
utils to get GPU info for the current environment
"""
from importlib.metadata import version
from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
)
from accelerate.utils.environment import (
get_gpu_info,
)
from packaging.version import Version, parse
def check_cuda_p2p_ib_support():
@@ -26,3 +29,13 @@ def check_cuda_p2p_ib_support():
except Exception: # pylint: disable=broad-except # nosec
pass
return True
def get_package_version(package: str) -> Version:
version_str = version(package)
return parse(version_str)
def is_package_version_ge(package: str, version_: str) -> bool:
package_version = get_package_version(package)
return package_version >= parse(version_)
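Usage sketch (assumed import path; the package must be installed for version() to resolve):
from axolotl.utils.environment import is_package_version_ge

# e.g. gate the low-bit optimizer + lr-groups combination mentioned in the
# validation change below once a fixed ao release ships
if is_package_version_ge("torchao", "0.12.0"):
    ...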

View File

@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
import gc
import math
import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
from typing import Iterable, Iterator, Union
@@ -453,7 +454,10 @@ class MultipackBatchSampler(BatchSampler):
_sampled_lens = []
for _ in range(self.num_count_samples):
self._batches = None # Reset cached batches
# log timer for generating batches
start_time = time.time()
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
LOG.debug(f"generate_batches time: {time.time() - start_time}")
len_batches = min(_sampled_lens)
# Gather minimum across all ranks

View File

@@ -2,6 +2,7 @@
import math
from functools import partial
from typing import Sequence
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -292,3 +293,50 @@ def get_cosine_schedule_with_warmup_decay_constant(
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
class JaggedLRRestartScheduler(LRScheduler):
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
def __init__(
self,
optimizer: Optimizer,
inner_schedule: LRScheduler,
jagged_restart_steps: int,
jagged_restart_warmup_steps: int,
jagged_restart_anneal_steps: int = 1,
min_lr_scale: float = 0.001,
) -> None:
# pylint: disable=duplicate-code
self.inner_schedule = inner_schedule
self.restarts_steps = jagged_restart_steps
self.warmup_steps = jagged_restart_warmup_steps
self.anneal_steps = jagged_restart_anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch)
def get_lr(self) -> float | Sequence[float]:
self.inner_schedule.last_epoch = self.last_epoch
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.restarts_steps - self.anneal_steps:
scale = 1
else:
per_restart_progress = step % self.restarts_steps
if per_restart_progress < self.warmup_steps:
cycle_t = min(1.0, (per_restart_progress) / self.warmup_steps)
elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
cycle_t = min(
1.0,
(self.restarts_steps - per_restart_progress) / self.anneal_steps,
)
else:
cycle_t = 1
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
return [lr * scale for lr in original]
return original * scale
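A usage sketch for the new scheduler (assumed import path; hyperparameters illustrative):
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR

from axolotl.utils.schedulers import JaggedLRRestartScheduler  # assumed path

model = torch.nn.Linear(8, 8)
optimizer = AdamW(model.parameters(), lr=1e-3)
inner = LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=1000)
scheduler = JaggedLRRestartScheduler(
    optimizer,
    inner_schedule=inner,
    jagged_restart_steps=200,        # restart every 200 steps
    jagged_restart_warmup_steps=20,  # re-warm the lr after each restart
    jagged_restart_anneal_steps=10,  # anneal toward min_lr_scale before each restart
)
for _ in range(5):
    optimizer.step()
    scheduler.step()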

View File

@@ -43,7 +43,7 @@ from axolotl.utils.schemas.model import (
from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.training import HyperparametersConfig, JaggedLRConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.validation import ValidationMixin
from axolotl.utils.schemas.vllm import VllmConfig
@@ -57,6 +57,7 @@ class AxolotlInputConfig(
ModelOutputConfig,
LoraConfig,
ReLoRAConfig,
JaggedLRConfig,
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
@@ -650,7 +651,23 @@ class AxolotlInputConfig(
},
)
dp_shard_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of devices to shard across. If not set, will use all available devices."
},
)
dp_replicate_size: int | None = Field(
default=None,
json_schema_extra={"description": "Number of devices to replicate across."},
)
sequence_parallel_degree: int | None = Field(
default=None,
json_schema_extra={
"description": "Deprecated: use `context_parallel_size` instead"
},
)
context_parallel_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."

View File

@@ -67,6 +67,8 @@ class ChatTemplate(str, Enum):
command_a_tool_use = "command_a_tool_use"
command_a_rag = "command_a_rag"
aya = "aya"
granite = "granite"
granitemoe = "granitemoe"
class CustomSupportedOptimizers(str, Enum):

View File

@@ -187,18 +187,10 @@ class LoraConfig(BaseModel):
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
relora_steps: int | None = Field(
default=None,
json_schema_extra={"description": "Number of steps per ReLoRA restart"},
)
relora_warmup_steps: int | None = Field(
default=None,
json_schema_extra={"description": "Number of per-restart warmup steps"},
)
relora_anneal_steps: int | None = Field(
relora: bool | None = Field(
default=None,
json_schema_extra={
"description": "Number of anneal steps for each relora cycle"
"description": "Whether to use ReLoRA. Use with jagged_restart_*steps options."
},
)
relora_prune_ratio: float | None = Field(

View File

@@ -160,3 +160,24 @@ class HyperparametersConfig(BaseModel):
if learning_rate and isinstance(learning_rate, str):
learning_rate = float(learning_rate)
return learning_rate
class JaggedLRConfig(BaseModel):
"""JaggedLR configuration subset, can be used w/ ReLoRA training"""
jagged_restart_steps: int | None = Field(
default=None,
json_schema_extra={"description": "how often to reset for jagged restarts"},
)
jagged_restart_warmup_steps: int | None = Field(
default=None,
json_schema_extra={
"description": "how many warmup steps to take after reset for jagged restarts"
},
)
jagged_restart_anneal_steps: int | None = Field(
default=None,
json_schema_extra={
"description": "how many anneal steps to take before reset for jagged restarts"
},
)
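A sketch of instantiating the new config subset directly (import path as shown in the schema diff above; values illustrative):
from axolotl.utils.schemas.training import JaggedLRConfig

jagged_cfg = JaggedLRConfig(
    jagged_restart_steps=200,
    jagged_restart_warmup_steps=20,
    jagged_restart_anneal_steps=10,
)
# Note: the relora validator below requires jagged_restart_steps when relora is set.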

View File

@@ -80,6 +80,14 @@ class TRLConfig(BaseModel):
"description": "Number of completions to print when log_completions is True."
},
)
importance_sampling_level: Literal["sequence", "token"] | None = Field(
default=None,
json_schema_extra={
"description": "Controls whether importance sampling ratios are computed at the `'token'` or `'sequence'` level. "
"For GSPO, use `sequence`, default is None which corresponds to the original GRPO paper."
},
)
sync_ref_model: bool | None = Field(
default=False,
json_schema_extra={"description": "Whether to sync the reference model."},

View File

@@ -644,6 +644,19 @@ class LoRAValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_dropout_parameters(cls, data):
if (
data.get("lora_dropout", 0.0)
and data.get("lora_dropout") > 0.0
and data.get("lora_target_parameters")
):
# lora.ParamWrapper does not work with lora_dropout != 0
raise ValueError(
"`lora_dropout` does not work when using `lora_target_parameters`"
)
class RLValidationMixin:
"""Validation methods related to RL training configuration."""
@@ -673,7 +686,7 @@ class RLValidationMixin:
data.get("rl") == "grpo"
and data.get("trl", {})
and data.get("trl").get("use_liger_loss")
and data.get("sequence_parallel_degree", 1) > 1
and data.get("context_parallel_size", 1) > 1
):
raise ValueError("GRPO + SP + Liger not currently supported")
return data
@@ -881,17 +894,19 @@ class OptimizationValidationMixin:
return self
@model_validator(mode="after")
def check_fsdp_sharded_state_dict_w_safetensors(self):
def lr_groups_ao_optimizer(self):
if (
hasattr(self, "fsdp_config")
and self.fsdp_config
and hasattr(self, "save_safetensors")
and self.save_safetensors
and self.fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
and str(getattr(self, "fsdp_version", "1")) != "2"
):
self.loraplus_lr_ratio is not None
or self.embedding_lr_scale is not None
or self.embedding_lr is not None
or self.lr_groups is not None
) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]:
# TODO(wing): remove this once ao>0.12.0
# requires https://github.com/pytorch/ao/pull/2606 in an ao release
raise ValueError(
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
"lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not "
"supported with ao low-bit optimizers until ao>0.12.0. "
"Please refer to https://github.com/pytorch/ao/pull/2606."
)
return self
@@ -900,31 +915,30 @@ class OptimizationValidationMixin:
def check_tensor_parallel_size_update_ds_json(cls, data):
tensor_parallel_size = data.get("tensor_parallel_size")
if tensor_parallel_size is not None and tensor_parallel_size > 1:
if not data.get("deepspeed"):
raise ValueError(
"Tensor parallelism (TP) is only supported with DeepSpeed"
)
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
ds_config = json.load(ds_fin)
should_save = False
if "tensor_parallel" not in ds_config:
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
should_save = True
if (
"gather_16bit_weights_on_model_save"
not in ds_config["zero_optimization"]
):
ds_config["zero_optimization"][
if data.get("deepspeed"):
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
ds_config = json.load(ds_fin)
should_save = False
if "tensor_parallel" not in ds_config:
ds_config["tensor_parallel"] = {
"autotp_size": tensor_parallel_size
}
should_save = True
if (
"gather_16bit_weights_on_model_save"
] = True
should_save = True
if should_save:
temp_dir = tempfile.mkdtemp()
with open(
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
) as ds_fout:
json.dump(ds_config, ds_fout, indent=4)
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
not in ds_config["zero_optimization"]
):
ds_config["zero_optimization"][
"gather_16bit_weights_on_model_save"
] = True
should_save = True
if should_save:
temp_dir = tempfile.mkdtemp()
with open(
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
) as ds_fout:
json.dump(ds_config, ds_fout, indent=4)
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
return data
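A sketch of the DeepSpeed JSON mutation performed above (surrounding config illustrative; keys as written by the validator):
ds_config = {"zero_optimization": {"stage": 3}}
tensor_parallel_size = 2

ds_config.setdefault("tensor_parallel", {"autotp_size": tensor_parallel_size})
ds_config["zero_optimization"].setdefault("gather_16bit_weights_on_model_save", True)
# -> {"zero_optimization": {"stage": 3, "gather_16bit_weights_on_model_save": True},
#     "tensor_parallel": {"autotp_size": 2}}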
@@ -1164,7 +1178,9 @@ class ComplexValidationMixin:
@model_validator(mode="after")
def check_relora(self):
if self.relora_steps:
if self.relora:
if not self.jagged_restart_steps:
raise ValueError("jagged_restart_steps must be set to use ReLoRA")
if self.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
@@ -1203,13 +1219,18 @@ class ComplexValidationMixin:
return self
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
self.sequence_parallel_degree = 1
elif self.sequence_parallel_degree > 1:
def check_context_parallel_size(self):
if self.sequence_parallel_degree and not self.context_parallel_size:
LOG.warning(
"`sequence_parallel_degree` is deprecated, use `context_parallel_size`"
)
self.context_parallel_size = self.sequence_parallel_degree
if not self.context_parallel_size:
self.context_parallel_size = 1
elif self.context_parallel_size > 1:
if not self.flash_attention:
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"
"flash_attention: true must be set with context_parallel_size > 1"
)
if self.sample_packing and self.micro_batch_size > 1:
@@ -1219,17 +1240,23 @@ class ComplexValidationMixin:
)
try:
import transformers.modeling_flash_attention_utils
# pylint: disable=protected-access
transformers.modeling_flash_attention_utils._flash_supports_window_size = (
transformers.modeling_flash_attention_utils._flash_supports_window
)
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
raise ImportError(
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
"context_parallel_size > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
f"context_parallel_size={self.context_parallel_size}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
@@ -1240,7 +1267,7 @@ class ComplexValidationMixin:
@model_validator(mode="after")
def validate_ring_attn_func(self):
if getattr(self, "sequence_parallel_degree", 1) == 1:
if getattr(self, "context_parallel_size", 1) == 1:
return self
if self.ring_attn_func is not None:
@@ -1257,6 +1284,33 @@ class ComplexValidationMixin:
return self
class DistributedValidationMixin:
"""validation for distributed training."""
@model_validator(mode="after")
def check_tensor_parallel_optimizer(self):
if self.tensor_parallel_size > 1:
if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
raise ValueError(
"tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
)
return self
class GRPOVllmValidationMixin:
"""Validation mixin for vllm when using GRPO."""
@model_validator(mode="after")
def check_vllm_mode_set(self):
if self.trl and self.trl.use_vllm and not self.trl.vllm_mode:
LOG.warning(
"vllm_mode must be set to either `server` or `colocate` when using vllm, using default value `server`"
)
self.trl.vllm_mode = "server"
return self
# pylint: disable=too-many-ancestors
class ValidationMixin(
DatasetValidationMixin,
@@ -1270,5 +1324,6 @@ class ValidationMixin(
PretrainingValidationMixin,
ModelCompatibilityValidationMixin,
ComplexValidationMixin,
GRPOVllmValidationMixin,
):
"""Full validation mixin for Axolotl configuration."""

View File

@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1
)
* cfg.num_epochs
* cfg.sequence_parallel_degree
* cfg.context_parallel_size
* cfg.tensor_parallel_size
)
LOG.debug(
@@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.floor(
data_loader_len
* cfg.num_epochs
* cfg.sequence_parallel_degree
* cfg.context_parallel_size
* cfg.tensor_parallel_size
)
)
@@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.ceil(
len(train_dataset)
* cfg.num_epochs
* cfg.sequence_parallel_degree
* cfg.context_parallel_size
* cfg.tensor_parallel_size
/ cfg.batch_size
)

View File

@@ -17,16 +17,23 @@ class BaseCliTest:
command: Command to test (train/evaluate)
"""
# Test missing config file
result = cli_runner.invoke(cli, [command, "--no-accelerate"])
result = cli_runner.invoke(cli, [command, "--launcher", "python"])
assert result.exit_code != 0
# Test non-existent config file
result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
result = cli_runner.invoke(
cli, [command, "nonexistent.yml", "--launcher", "python"]
)
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def _test_basic_execution(
self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
self,
cli_runner,
tmp_path: Path,
valid_test_config: str,
command: str,
train: bool = True,
):
"""Test basic execution with accelerate.
@@ -35,6 +42,7 @@ class BaseCliTest:
tmp_path: Temporary path fixture
valid_test_config: Valid config fixture
command: Command to test (train/evaluate)
train: Whether to test training (default) or evaluation
"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
@@ -43,15 +51,21 @@ class BaseCliTest:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
expected = [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
"--debug=False",
"--debug-text-only=False",
"--debug-num-examples=0",
]
if train:
expected.append("--shard=False")
assert mock.call_args.args[0] == expected
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -1,5 +1,7 @@
"""Tests for evaluate CLI command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestEvaluateCommand(BaseCliTest):
def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")
self._test_basic_execution(
cli_runner, tmp_path, valid_test_config, "evaluate", train=False
)
def test_evaluate_basic_execution_no_accelerate(
self, cli_runner, tmp_path, valid_test_config
@@ -27,13 +31,15 @@ class TestEvaluateCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
# pylint: disable=duplicate-code
with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -55,7 +61,8 @@ class TestEvaluateCommand(BaseCliTest):
"2",
"--sequence-len",
"128",
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -65,3 +72,104 @@ class TestEvaluateCommand(BaseCliTest):
cfg = mock_evaluate.call_args[0][0]
assert cfg.micro_batch_size == 2
assert cfg.sequence_len == 128
def test_evaluate_with_launcher_args_torchrun(
self, cli_runner, tmp_path, valid_test_config
):
"""Test evaluate with torchrun launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.evaluate" in called_cmd
def test_evaluate_with_launcher_args_accelerate(
self, cli_runner, tmp_path, valid_test_config
):
"""Test evaluate with accelerate launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.evaluate" in called_cmd
def test_evaluate_backward_compatibility_no_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test that existing evaluate commands work without launcher args"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"accelerate",
"--micro-batch-size",
"2",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert (
len(launcher_section) == 0
) # No launcher args between 'launch' and '-m'

View File

@@ -1,5 +1,7 @@
"""pytest tests for axolotl CLI inference command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -10,7 +12,7 @@ def test_inference_basic(cli_runner, config_path):
with patch("axolotl.cli.inference.do_inference") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate"],
["inference", str(config_path), "--launcher", "python"],
catch_exceptions=False,
)
@@ -23,9 +25,124 @@ def test_inference_gradio(cli_runner, config_path):
with patch("axolotl.cli.inference.do_inference_gradio") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate", "--gradio"],
["inference", str(config_path), "--launcher", "python", "--gradio"],
catch_exceptions=False,
)
assert mock.called
assert result.exit_code == 0
def test_inference_with_launcher_args_torchrun(cli_runner, config_path):
"""Test inference with torchrun launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_with_launcher_args_accelerate(cli_runner, config_path):
"""Test inference with accelerate launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_gradio_with_launcher_args(cli_runner, config_path):
"""Test inference with gradio and launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
"--gradio",
"--",
"--num_processes=2",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify both gradio flag and launcher args are present
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--num_processes=2" in called_cmd
assert "--gradio" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_backward_compatibility_no_launcher_args(cli_runner, config_path):
"""Test that existing inference commands work without launcher args"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m'

View File

@@ -18,11 +18,10 @@ def test_build_command():
assert result == [
"accelerate",
"launch",
"--learning-rate",
"0.0001",
"--batch-size",
"8",
"--debug",
"--learning-rate=0.0001",
"--batch-size=8",
"--debug=True",
"--use-fp16=False",
]
@@ -38,7 +37,7 @@ def test_invalid_command_options(cli_runner):
],
)
assert result.exit_code != 0
assert "No such option" in result.output
assert "does not exist" in result.output
def test_required_config_argument(cli_runner):

View File

@@ -11,9 +11,101 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command without accelerate"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
cli,
["merge-sharded-fsdp-weights", str(config_path), "--launcher", "python"],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_launcher_args_torchrun(
cli_runner, config_path
):
"""Test merge-sharded-fsdp-weights with torchrun launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
def test_merge_sharded_fsdp_weights_with_launcher_args_accelerate(
cli_runner, config_path
):
"""Test merge-sharded-fsdp-weights with accelerate launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
def test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args(
cli_runner, config_path
):
"""Test that existing merge-sharded-fsdp-weights commands work without launcher args"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m'

View File

@@ -2,7 +2,7 @@
unit tests for generating sweep configurations
"""
from axolotl.cli.main import generate_sweep_configs
from axolotl.cli.utils import generate_sweep_configs
def test_generate_sweep_configs_no_pairs():

View File

@@ -1,5 +1,7 @@
"""Tests for train CLI command."""
# pylint: disable=duplicate-code
from unittest.mock import MagicMock, patch
from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestTrainCommand(BaseCliTest):
def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train")
self._test_basic_execution(
cli_runner, tmp_path, valid_test_config, "train", train=True
)
def test_train_basic_execution_no_accelerate(
self, cli_runner, tmp_path, valid_test_config
@@ -37,7 +41,8 @@ class TestTrainCommand(BaseCliTest):
[
"train",
str(config_path),
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -59,11 +64,10 @@ class TestTrainCommand(BaseCliTest):
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
"--learning-rate=1e-4",
"--micro-batch-size=2",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -73,3 +77,174 @@ class TestTrainCommand(BaseCliTest):
cfg = mock_train.call_args[1]["cfg"]
assert cfg["learning_rate"] == 1e-4
assert cfg["micro_batch_size"] == 2
def test_train_with_launcher_args_torchrun(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with torchrun launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.train" in called_cmd
def test_train_with_launcher_args_accelerate(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with accelerate launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.train" in called_cmd
def test_train_backward_compatibility_no_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test that existing train commands work without launcher args"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"accelerate",
"--learning-rate",
"1e-4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert (
len(launcher_section) == 0
) # No launcher args between 'launch' and '-m'
def test_train_mixed_args_with_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with both regular CLI args and launcher args"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"torchrun",
"--learning-rate",
"2e-4",
"--micro-batch-size",
"4",
"--",
"--nproc_per_node=8",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
called_cmd = mock_subprocess.call_args.args[0]
# Verify launcher args
assert "--nproc_per_node=8" in called_cmd
# Verify axolotl args are also present
assert "--learning-rate=2e-4" in called_cmd
assert "--micro-batch-size=4" in called_cmd
def test_train_cloud_with_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with cloud and launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
cloud_path = tmp_path / "cloud.yml"
cloud_path.write_text("provider: modal\ngpu: a100")
with patch("axolotl.cli.cloud.do_cli_train") as mock_cloud_train:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--cloud",
str(cloud_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=4",
"--nnodes=2",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_cloud_train.assert_called_once()
# Verify cloud training was called with launcher args
call_kwargs = mock_cloud_train.call_args.kwargs
assert call_kwargs["launcher"] == "torchrun"
assert call_kwargs["launcher_args"] == ["--nproc_per_node=4", "--nnodes=2"]

View File

@@ -72,3 +72,160 @@ def test_fetch_from_github_network_error():
with patch("requests.get", side_effect=requests.RequestException):
with pytest.raises(requests.RequestException):
fetch_from_github("examples/", None)
def assert_launcher_args_in_command(
mock_subprocess_call,
launcher: str,
expected_launcher_args: list[str],
command_module: str,
):
"""
Helper function to verify launcher arguments are properly passed in subprocess calls.
Args:
mock_subprocess_call: The mock subprocess.run call
launcher: Expected launcher ("accelerate", "torchrun", etc.)
expected_launcher_args: List of expected launcher arguments
command_module: Expected module name (e.g., "axolotl.cli.train")
"""
assert mock_subprocess_call.called, "subprocess.run should have been called"
called_cmd = mock_subprocess_call.call_args.args[0]
# Verify launcher
assert (
called_cmd[0] == launcher
), f"Expected launcher {launcher}, got {called_cmd[0]}"
# Verify launcher args are present
for arg in expected_launcher_args:
assert (
arg in called_cmd
), f"Expected launcher arg '{arg}' not found in command: {called_cmd}"
# Verify module is present
assert "-m" in called_cmd, "Expected -m flag for module execution"
assert (
command_module in called_cmd
), f"Expected module {command_module} not found in command: {called_cmd}"
def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
"""
Helper function to verify no unwanted launcher arguments are present.
Args:
mock_subprocess_call: The mock subprocess.run call
launcher: Expected launcher ("accelerate", "torchrun", etc.)
"""
assert mock_subprocess_call.called, "subprocess.run should have been called"
called_cmd = mock_subprocess_call.call_args.args[0]
if launcher == "accelerate":
# For accelerate, launcher args should be between 'launch' and '-m'
launch_idx = called_cmd.index("launch")
m_idx = called_cmd.index("-m")
launcher_section = called_cmd[launch_idx + 1 : m_idx]
assert (
len(launcher_section) == 0
), f"Unexpected launcher args found: {launcher_section}"
elif launcher == "torchrun":
# For torchrun, launcher args should be between 'torchrun' and '-m'
torchrun_idx = called_cmd.index("torchrun")
m_idx = called_cmd.index("-m")
launcher_section = called_cmd[torchrun_idx + 1 : m_idx]
assert (
len(launcher_section) == 0
), f"Unexpected launcher args found: {launcher_section}"
@pytest.fixture
def common_launcher_args():
"""Fixture providing common launcher argument combinations for testing."""
return {
"torchrun": ["--nproc_per_node=2", "--nnodes=1"],
"accelerate": ["--config_file=accelerate_config.yml", "--num_processes=4"],
}
def test_add_default_rdzv_args_with_endpoint():
"""Test that default RDZV args are added when rdzv_endpoint is present."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = ["--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"]
result = _add_default_rdzv_args(launcher_args)
# Should have added rdzv_backend
assert "--rdzv_backend" in result
assert "c10d" in result
# Original args should still be present
assert "--nnodes=2" in result
assert "--rdzv_endpoint=127.0.0.1:29400" in result
def test_add_default_rdzv_args_with_existing_backend():
"""Test that existing rdzv_backend is not overridden."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = [
"--nnodes=2",
"--rdzv_endpoint=127.0.0.1:29400",
"--rdzv_backend=static",
]
result = _add_default_rdzv_args(launcher_args)
# Should not add another rdzv_backend
backend_count = sum(1 for arg in result if "--rdzv_backend" in arg)
assert backend_count == 1
assert "--rdzv_backend=static" in result
def test_add_default_rdzv_args_with_existing_id():
"""Test that existing rdzv_id is not overridden."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = [
"--nnodes=2",
"--rdzv_endpoint=127.0.0.1:29400",
"--rdzv_id=my_job_123",
]
result = _add_default_rdzv_args(launcher_args)
# Should not add another rdzv_id
id_count = sum(1 for arg in result if "--rdzv_id" in arg)
assert id_count == 1
assert "--rdzv_id=my_job_123" in result
# Should still add rdzv_backend
assert "--rdzv_backend" in result
assert "c10d" in result
def test_add_default_rdzv_args_without_endpoint():
"""Test that no RDZV args are added when rdzv_endpoint is not present."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = ["--nnodes=2", "--nproc_per_node=4"]
result = _add_default_rdzv_args(launcher_args)
# Should not add any rdzv args
assert "--rdzv_backend" not in result
assert result == launcher_args
def test_add_default_rdzv_args_with_all_existing():
"""Test that no defaults are added when all RDZV args are present."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = [
"--nnodes=2",
"--rdzv_endpoint=127.0.0.1:29400",
"--rdzv_backend=static",
"--rdzv_id=existing_job",
]
result = _add_default_rdzv_args(launcher_args)
# Should not add any additional args
assert len(result) == len(launcher_args)
assert result == launcher_args
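One possible implementation consistent with the tests above (a sketch, not the actual code in axolotl.cli.utils.train):
def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]:
    """Default torchrun rendezvous args when an rdzv_endpoint is supplied."""
    if not any(arg.startswith("--rdzv_endpoint") for arg in launcher_args):
        return launcher_args
    result = list(launcher_args)
    if not any(arg.startswith("--rdzv_backend") for arg in result):
        result += ["--rdzv_backend", "c10d"]
    # rdzv_id is left untouched here; the tests only require that an existing
    # id is preserved, not that a default is injected.
    return result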

Some files were not shown because too many files have changed in this diff.