Compare commits: custom-mod...chat-templ
17 commits
| Author | SHA1 | Date |
|---|---|---|
| | b5198d8734 | |
| | 4ab6a1bd7e | |
| | 5639552064 | |
| | cda3c82351 | |
| | 7c3b428f23 | |
| | 01a6bd1a0e | |
| | 41709822a7 | |
| | 02a37199ee | |
| | 7026cd5e9e | |
| | eb0a8a7775 | |
| | 294c7fe7a6 | |
| | 7b68dfafd7 | |
| | 32a7890231 | |
| | 563f5eed7a | |
| | 6ec282094d | |
| | 09dda462ab | |
| | bb1cae1a20 | |
.github/workflows/preview-docs.yml (vendored, 4 changes)

@@ -53,7 +53,7 @@ jobs:

       - name: Netlify Publish
         uses: nwtgck/actions-netlify@v3.0
-        if: ${{ secrets.NETLIFY_AUTH_TOKEN != '' }}
+        if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
         id: netlify
         with:
           publish-dir: './_site'
@@ -68,7 +68,7 @@ jobs:
           NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}

       - name: Update PR with preview link
-        if: ${{ steps.netlify.outcome == 'success' && secrets.NETLIFY_AUTH_TOKEN != '' }}
+        if: ${{ steps.netlify.outcome == 'success' }}
         uses: marocchino/sticky-pull-request-comment@v2
         with:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
_quarto.yml (15 changes)

@@ -35,25 +35,30 @@ quartodoc:
       - cli.train
       - cli.evaluate
       - cli.args
       - cli.art
       - cli.checks
       - cli.config
       - cli.delinearize_llama4
       - cli.inference
       - cli.merge_lora
       - cli.merge_sharded_fsdp_weights
       - cli.preprocess
-      - cli.sweeps
-      - cli.utils
-      - cli.quantize
       - cli.vllm_serve
       - cli.cloud.base
       - cli.cloud.modal_
+      - cli.quantize
+      - cli.utils
+      - cli.utils.args
+      - cli.utils.fetch
+      - cli.utils.load
+      - cli.utils.sweeps
+      - cli.utils.train
   - title: Trainers
     desc: Training implementations
     contents:
       - core.trainers.base
       - core.trainers.trl
       - core.trainers.mamba
       - core.trainers.relora
       - core.trainers.dpo.trainer
       - core.trainers.grpo.trainer
       - core.trainers.grpo.sampler
@@ -269,7 +274,6 @@ website:
         - docs/dataset_preprocessing.qmd
         - docs/multipack.qmd
         - docs/mixed_precision.qmd
         - docs/gradient_accumulation.qmd

     - section: "Advanced Features"
       contents:
@@ -279,6 +283,7 @@ website:
         - docs/custom_integrations.qmd
         - docs/sequence_parallelism.qmd
         - docs/gradient_checkpointing.qmd
+        - docs/nd_parallelism.qmd

     - section: "Troubleshooting"
       contents:
@@ -2,7 +2,7 @@
 set -e

 # Only run two tests at a time to avoid OOM on GPU (with coverage collection)
-pytest -v -n2 \
+pytest -v --durations=10 -n2 \
   --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
   --ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
   /workspace/axolotl/tests/e2e/multigpu/ \
@@ -65,6 +65,9 @@ GPU_CONFIG = f"L40S:{N_GPUS}"
 def run_cmd(cmd: str, run_folder: str):
     import subprocess  # nosec

+    sp_env = os.environ.copy()
+    sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
+
     # Propagate errors from subprocess.
-    if exit_code := subprocess.call(cmd.split(), cwd=run_folder):  # nosec
+    if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env):  # nosec
         exit(exit_code)  # pylint: disable=consider-using-sys-exit
@@ -16,7 +16,10 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
 ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

 RUN apt-get update \
-    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
+    && apt-get install -y --no-install-recommends \
+    wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
+    ibverbs-providers ibverbs-utils infiniband-diags \
+    librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
    && rm -rf /var/cache/apt/archives \
    && rm -rf /var/lib/apt/lists/* \
    && wget \
@@ -15,7 +15,7 @@ COPY scripts/motd /etc/motd
 RUN pip install jupyterlab notebook ipywidgets && \
     jupyter lab clean
 RUN apt update && \
-    apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
+    apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
     rm -rf /var/cache/apt/archives && \
     rm -rf /var/lib/apt/lists/* && \
     mkdir -p ~/.ssh && \
docs/cli.qmd (26 changes)

@@ -23,6 +23,20 @@ axolotl <command> [config.yml] [options]

 The config file can be local or a URL to a raw YAML file.

+### Launcher Arguments
+
+For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator:
+
+```bash
+# Pass torchrun arguments
+axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
+
+# Pass accelerate arguments
+axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4
+```
+
+Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.).
+
 ## Command Reference

 ### fetch
@@ -80,7 +94,11 @@ axolotl train config.yml \
   --num-epochs 3

 # Training without accelerate
-axolotl train config.yml --no-accelerate
+axolotl train config.yml --launcher python
+
+# Pass launcher-specific arguments using -- separator
+axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
+axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml

 # Resume training from checkpoint
 axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
@@ -175,6 +193,9 @@ Evaluates a model's performance (loss etc) on the train and eval datasets.
 ```bash
 # Basic evaluation
 axolotl evaluate config.yml
+
+# Evaluation with launcher arguments
+axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2
 ```

 ### lm-eval
@@ -287,9 +308,6 @@ axolotl preprocess config.yml --cloud cloud_config.yml
 # Train on cloud
 axolotl train config.yml --cloud cloud_config.yml

-# Train without accelerate on cloud
-axolotl train config.yml --cloud cloud_config.yml --no-accelerate
-
 # Run lm-eval on cloud
 axolotl lm-eval config.yml --cloud cloud_config.yml
 ```
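The `--` passthrough documented above is implemented with Click's extra-args mechanism (visible in the `src/axolotl/cli/main.py` hunks later in this diff): the command is declared with `ignore_unknown_options` and `allow_extra_args`, so anything Click cannot parse lands in `ctx.args`. A minimal standalone sketch of that pattern; the `demo` command and its option are illustrative, not part of Axolotl:

```python
# Sketch of the `--` passthrough pattern used by the new CLI commands.
# Only the Click mechanics here mirror the PR; names are illustrative.
import click


@click.command(
    context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config")
@click.option(
    "--launcher",
    type=click.Choice(["accelerate", "torchrun", "python"]),
    default="accelerate",
)
@click.pass_context
def demo(ctx: click.Context, config: str, launcher: str) -> None:
    # Everything after the `--` separator ends up in ctx.args.
    launcher_args = ctx.args or []
    click.echo(f"config={config} launcher={launcher} launcher_args={launcher_args}")


if __name__ == "__main__":
    demo()  # e.g. `python demo.py cfg.yml --launcher torchrun -- --nproc_per_node=2`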
@@ -69,11 +69,19 @@ export NCCL_BUFFSIZE=2097152

 Run the following on each node:

+### Option 1: New Axolotl CLI with launcher args (Recommended)
+
+```bash
+axolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port"
+```
+
+### Option 2: Direct torchrun (Legacy)
+
 ```bash
 torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
 ```

-Please make sure to substitute the placeholder variables.
+Please make sure to substitute the placeholder variables:

 - `num_nodes`: Number of nodes (containing GPUs)
 - `gpu_per_node`: Number of gpus per node
@@ -81,8 +89,6 @@ Please make sure to substitute the placeholder variables.
 - `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
 - `rdzv_id`: A unique job ID that is used by the job across nodes.

-::: {.callout-note}
-You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
-:::
+The new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features.

 More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)
docs/nd_parallelism.qmd (new file, 102 lines)

@@ -0,0 +1,102 @@
+# N-D Parallelism
+
+Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
+
+- A model's weights are too large to fit on a single GPU's memory.
+- A model's activations, especially with very long contexts, are too large for a single GPU.
+- You want to accelerate training by using multiple GPUs or nodes.
+
+or combinations of the above!
+
+## Core Concepts
+
+Parallelism strategies can be combined. The key is understanding how each one divides the workload. PyTorch's `DeviceMesh` is the modern way to manage these combinations, creating a logical grid of your GPUs and assigning different parallel strategies to different dimensions of the grid.
+
+### Data Parallelism {#sec-dp}
+
+Data Parallelism focuses on splitting the global data batch across GPUs.
+
+- Distributed Data Parallel (DDP): The classic approach. The full model is replicated on every GPU. Each GPU processes a different slice of the data batch. Gradients are then averaged across all GPUs after the backward pass to keep the models synchronized. This can substantially improve data throughput compared to single-device training, but requires that each GPU is able to hold the entire model, its gradients, and optimizer states.
+
+- [Fully Sharded Data Parallel (FSDP)](multi-gpu.qmd#fully-sharded-data-parallel-(fsdp)): A highly memory-efficient form of data parallelism (inspired by DeepSpeed's ZeRO). Instead of replicating the model, FSDP shards the model's *parameters, gradients, and optimizer states* across the GPUs in the data-parallel group. During computation, each GPU receives the specific parameters it needs via an `all_gather` operation just before they are used, and they can be discarded immediately after (`reshard-after-forward`).
+  - FSDP maps to ZeRO stages:
+    - ZeRO-2 (`reshard_after_forward=False`): Shards gradients and optimizer states. Model weights are replicated on each GPU.
+    - ZeRO-3 (`reshard_after_forward=True`): Shards gradients, optimizer states, AND model parameters. This provides the most memory savings at the cost of more communication (re-gathering parameters for both forward and backward passes).
+
+### [Experimental] Tensor Parallelism (TP) {#sec-tp}
+
+Also known as "horizontal model parallelism," as described in the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Instead of splitting the batch, TP splits the model's layers themselves across GPUs.
+
+- How it works: For a linear layer `Y = XA`, the weight matrix `A` is split column-wise (`A = [A_1, A_2]`). The computation becomes `Y_1 = XA_1` and `Y_2 = XA_2`, which can happen in parallel on different GPUs. The final output `Y` is simply the concatenation of `Y_1` and `Y_2`. Check [this comment](https://github.com/huggingface/transformers/issues/10321#issuecomment-783543530) for more detailed info.
+- Requirement: TP involves frequent, small communications within a forward/backward pass. It requires a very fast interconnect between GPUs (e.g., NVLink) and is typically not recommended across different nodes.
+
+### Context Parallelism (CP) {#sec-cp}
+
+Context Parallelism, also called [Sequence Parallelism](sequence_parallelism.qmd), addresses the memory bottleneck from long sequences. The input sequence itself is split along the sequence length dimension and distributed across GPUs.
+
+- How it works: If you have a sequence of 8192 tokens and a `context_parallel_size` of 4, each GPU will only handle a chunk of 2048 tokens.
+- The Challenge: Attention is not local; every token needs to "attend to" every other token. Splitting the sequence breaks this.
+- The Solution (`ring-flash-attention`): An efficient communication protocol is used. To compute attention for its local sequence chunk, each GPU passes its Key-Value (KV) cache to its neighbor in a "ring." After `N-1` steps, every GPU has seen the KV-cache from all other GPUs, allowing it to compute the correct attention values for its chunk. This is implemented using the highly optimized `flash-attention` kernel at each step.
+
+### Hybrid Sharding Data Parallel (HSDP) {#sec-hsdp}
+
+HSDP is a 2D strategy that intelligently combines FSDP and DDP, typically for multi-node training.
+
+- Intra-Node (within a machine): Use FSDP. This is efficient because GPUs on the same node have fast interconnects (NVLink), making the `all_gather` operations for sharded parameters fast.
+- Inter-Node (across machines): Use DDP. The gradient synchronization between nodes is less frequent than FSDP's parameter gathering, making it a better fit for the slower node-to-node network (e.g., Ethernet/Infiniband).
+- Example: With 2 nodes of 8 GPUs each (16 total), you could have `dp_shard_size=8` (FSDP within each node) and `dp_replicate_size=2` (DDP across the two nodes).
+
+## Usage
+
+```yaml
+# FSDP config. See https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp
+fsdp_version: 2
+fsdp_config:
+  # ...
+
+# The number of GPUs to shard the model parameters across (FSDP dimension).
+dp_shard_size: 4
+
+# The number of times to replicate the sharded model (DDP dimension).
+dp_replicate_size: 2
+
+# Number of GPUs for Tensor Parallelism.
+tensor_parallel_size: 1 # (default is 1, no TP)
+
+# Number of GPUs for Context/Sequence Parallelism.
+context_parallel_size: 1 # (default is 1, no CP)
+```
+
+Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size`.
+
+## Examples
+
+1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
+   - You want FSDP within each node and DDP across nodes.
+   - Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
+
+2. FSDP + TP on a single 8-GPU node:
+   - You want to split the model across 4 GPUs using FSDP, and further split each layer across 2 GPUs with TP.
+   - Set `dp_shard_size: 4` and `tensor_parallel_size: 2`.
+
+3. FSDP + CP on a single 8-GPU node for long context:
+   - You want to shard the model across all 8 GPUs and also split the sequence length across all 8 GPUs.
+   - Set `dp_shard_size: 8` and `context_parallel_size: 8`. Note: this means the data parallel group and context parallel group are the same. A more common setup might be to shard across a smaller group.
+
+## Support Matrix
+
+This matrix describes how different parallelism methods can be combined in Axolotl.
+
+| Combination | `dp_replicate_size` | `dp_shard_size` | `tp_size` | `cp_size` | Status & Notes |
+| --- | :---: | :---: |:---:|:---:|---|
+| **FSDP** (ZeRO-3) | 1 | >1 | 1 | 1 | ✅ Fully supported. Shards model across all GPUs. |
+| **HSDP** | >1 | >1 | 1 | 1 | ✅ Fully supported. FSDP intra-node, DDP inter-node. |
+| **FSDP + TP** | 1 | >1 | >1 | 1 | ✅ **2D Parallelism**. Shards the model across a `dp_shard` group, and TP-splits layers within the `tp` group. |
+| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
+| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
+| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
+| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP/CP without FSDP is inefficient and complex. You should use FSDP instead (`dp_shard_size > 1`). |
+| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
+
+- `tp_size` refers to `tensor_parallel_size`
+- `cp_size` refers to `context_parallel_size`
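As a rough illustration of how these dimensions compose: the product of all four sizes must equal the world size, and PyTorch's `DeviceMesh` then carves the GPUs into the corresponding grid. A minimal sketch (not Axolotl's internal code), assuming torch >= 2.3 and an already-initialized process group; the mesh dim names are illustrative:

```python
# Sketch: compose HSDP + TP + CP dimensions into a single 4-D device mesh.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dp_replicate_size, dp_shard_size = 2, 4   # HSDP: DDP across nodes, FSDP within
tensor_parallel_size, context_parallel_size = 1, 1

world_size = dist.get_world_size()
assert (
    dp_replicate_size * dp_shard_size * tensor_parallel_size * context_parallel_size
    == world_size
), "parallelism dims must multiply to the number of GPUs"

mesh = init_device_mesh(
    "cuda",
    (dp_replicate_size, dp_shard_size, tensor_parallel_size, context_parallel_size),
    mesh_dim_names=("dp_replicate", "dp_shard", "tp", "cp"),
)
# Sub-meshes give each strategy its own process group, e.g. the FSDP group:
fsdp_mesh = mesh["dp_shard"]
```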
@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:

 ```yaml
 # Set to a divisor (> 1) of the number of GPUs available
-sequence_parallel_degree: 4 # Split sequences across 4 GPUs
+context_parallel_size: 4 # Split sequences across 4 GPUs
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 heads_k_stride: 1
 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,7 +30,7 @@ heads_k_stride: 1
 ring_attn_func:
 ```

-The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
+The `context_parallel_size` should be a divisor of the total number of GPUs. For example:

 - With 8 GPUs, valid values would be 2, 4, or 8
 - With 4 GPUs, valid values would be 2 or 4
@@ -66,7 +66,7 @@ sequence_len: 8192

 ...

-sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
+context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 heads_k_stride: 1
 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.

 ## Effect on Batch Size

-When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
+When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:

-- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
+- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
 - The number of batches processed per step decreases

 For example:
 - With 8 GPUs and no sequence parallelism: 8 different batches processed per step
-- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
+- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
 - If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
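The `ring_attn_func` options above implement the ring schedule described in the new N-D parallelism doc: each rank holds one KV block and rotates it around the ring so that after N-1 steps every rank has attended to every block. A toy single-process simulation of just that schedule (no tensors or real attention, purely illustrative):

```python
# Toy simulation of the ring KV-passing schedule behind ring attention.
# We only track which KV blocks each rank has seen, not actual math.
N = 4  # context_parallel_size

kv = list(range(N))              # kv[r]: KV block currently held by rank r
seen = [{r} for r in range(N)]   # blocks each rank has attended to so far

for step in range(N - 1):
    # Each rank sends its block to the next rank in the ring.
    kv = [kv[(r - 1) % N] for r in range(N)]
    for r in range(N):
        seen[r].add(kv[r])

assert all(s == set(range(N)) for s in seen)  # every rank saw every KV block
```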
@@ -20,7 +20,7 @@ min_sample_len: 200_000
 sample_packing: true

 tiled_mlp: true
-sequence_parallel_degree: 8
+context_parallel_size: 8
 plugins:
   - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
@@ -40,7 +40,7 @@
     "%%capture\n",
     "# This step can take ~5-10 minutes to install dependencies\n",
     "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88\""
+    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
    ]
   },
   {
@@ -25,9 +25,12 @@ lora_alpha: 16
 lora_dropout: 0.05
 lora_target_linear: true

-relora_steps: 150
-relora_warmup_ratio: 0.1
+relora: true
+relora_prune_ratio: 0.9
+relora_cpu_offload: false
+jagged_restart_steps: 150
+jagged_restart_warmup_steps: 10
+jagged_restart_anneal_steps: false

 wandb_project:
 wandb_entity:
@@ -6,19 +6,19 @@ triton>=3.0.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
 autoawq==0.2.7.post3
-liger-kernel==0.6.0
+liger-kernel==0.6.1
 # END section

 packaging==23.2

 huggingface_hub>=0.33.0
 peft==0.16.0
-transformers==4.54.0
+transformers==4.54.1
 tokenizers>=0.21.1
-accelerate==1.9.0
+accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
 datasets==4.0.0
 deepspeed>=0.17.0
-trl==0.19.1
+trl==0.20.0
 hf_xet==1.1.5

 optimum==1.16.2
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
 )
setup.py (3 changes)

@@ -72,12 +72,13 @@ def parse_requirements(extras_require_map):
             extras_require_map.pop("vllm")
         else:
             _install_requires.append("xformers==0.0.31")
+            extras_require_map["vllm"] = ["vllm>=0.10.0"]
     elif (major, minor) >= (2, 6):
         _install_requires.pop(_install_requires.index(xformers_version))
         _install_requires.append("xformers==0.0.29.post3")
         # since we only support 2.6.0+cu126
         _dependency_links.append("https://download.pytorch.org/whl/cu126")
-        extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
+        extras_require_map.pop("vllm")
     elif (major, minor) >= (2, 5):
         _install_requires.pop(_install_requires.index(xformers_version))
         if patch == 0:
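For context, `parse_requirements` keys these branches off the installed torch version. A simplified sketch of that gating pattern; the thresholds and pins here are illustrative (the hunk above is the source of truth):

```python
# Sketch: setup.py-style dependency gating on the installed torch version.
# Thresholds/pins are illustrative, not the authoritative values.
import torch

version = torch.__version__.split("+", maxsplit=1)[0]  # "2.6.0+cu126" -> "2.6.0"
major, minor, patch = (int(part) for part in version.split(".")[:3])

extras: dict[str, list[str]] = {"vllm": []}
if (major, minor) >= (2, 7):  # assumed newest branch
    extras["vllm"] = ["vllm>=0.10.0"]
elif (major, minor) >= (2, 6):
    # After this change there is no compatible vllm pin for 2.6, so drop the extra.
    extras.pop("vllm", None)
print(extras)
```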
@@ -30,8 +30,6 @@ class TrainerCliArgs:
     debug_num_examples: int = field(default=0)
     prompter: Optional[str] = field(default=None)
     shard: bool = field(default=False)
-    main_process_port: Optional[int] = field(default=None)
-    num_processes: Optional[int] = field(default=None)


 @dataclass
@@ -3,7 +3,7 @@ launch axolotl in supported cloud platforms
 """

 from pathlib import Path
-from typing import Union
+from typing import Literal

 import yaml
@@ -11,7 +11,7 @@ from axolotl.cli.cloud.modal_ import ModalCloud
 from axolotl.utils.dict import DictDefault


-def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
+def load_cloud_cfg(cloud_config: Path | str) -> DictDefault:
     """Load and validate cloud configuration."""
     # Load cloud configuration.
     with open(cloud_config, encoding="utf-8") as file:
@@ -20,8 +20,8 @@ def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:


 def do_cli_preprocess(
-    cloud_config: Union[Path, str],
-    config: Union[Path, str],
+    cloud_config: Path | str,
+    config: Path | str,
 ) -> None:
     cloud_cfg = load_cloud_cfg(cloud_config)
     cloud = ModalCloud(cloud_cfg)
@@ -31,9 +31,10 @@ def do_cli_preprocess(


 def do_cli_train(
-    cloud_config: Union[Path, str],
-    config: Union[Path, str],
-    accelerate: bool = True,
+    cloud_config: Path | str,
+    config: Path | str,
+    launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
+    launcher_args: list[str] | None = None,
     cwd=None,
     **kwargs,
 ) -> None:
@@ -44,12 +45,18 @@ def do_cli_train(
     local_dirs = {}
     if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
         local_dirs = {"/workspace/mounts": cwd}
-    cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
+    cloud.train(
+        config_yaml,
+        launcher=launcher,
+        launcher_args=launcher_args,
+        local_dirs=local_dirs,
+        **kwargs,
+    )


 def do_cli_lm_eval(
-    cloud_config: Union[Path, str],
-    config: Union[Path, str],
+    cloud_config: Path | str,
+    config: Path | str,
 ) -> None:
     cloud_cfg = load_cloud_cfg(cloud_config)
     cloud = ModalCloud(cloud_cfg)
@@ -3,6 +3,7 @@ base class for cloud platforms from cli
 """

 from abc import ABC, abstractmethod
+from typing import Literal


 class Cloud(ABC):
@@ -15,5 +16,12 @@ class Cloud(ABC):
         pass

     @abstractmethod
-    def train(self, config_yaml: str, accelerate: bool = True) -> str:
+    def train(
+        self,
+        config_yaml: str,
+        launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
+        launcher_args: list[str] | None = None,
+        local_dirs: dict[str, str] | None = None,
+        **kwargs,
+    ):
         pass
@@ -8,7 +8,7 @@ import os
 import subprocess  # nosec B404
 from pathlib import Path
 from random import randint
-from typing import Optional
+from typing import Literal

 import modal
@@ -230,8 +230,9 @@ class ModalCloud(Cloud):
     def train(
         self,
         config_yaml: str,
-        accelerate: bool = True,
-        local_dirs: Optional[dict[str, str]] = None,
+        launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
+        launcher_args: list[str] | None = None,
+        local_dirs: dict[str, str] | None = None,
         **kwargs,
     ):
         modal_fn = self.get_train_env(local_dirs)(_train)
@@ -239,7 +240,8 @@ class ModalCloud(Cloud):
         with self.app.run(detach=True):
             modal_fn.remote(
                 config_yaml,
-                accelerate=accelerate,
+                launcher=launcher,
+                launcher_args=launcher_args,
                 volumes={k: v[0] for k, v in self.volumes.items()},
                 **kwargs,
             )
@@ -270,20 +272,35 @@ def _preprocess(config_yaml: str, volumes=None):
     )


-def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
+def _train(
+    config_yaml: str,
+    launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
+    launcher_args: list[str] | None = None,
+    volumes=None,
+    **kwargs,  # pylint: disable=unused-argument
+):
     Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
     with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
         f_out.write(config_yaml)
     run_folder = "/workspace/mounts"
-    if accelerate:
-        accelerate_args = "--accelerate"
+
+    launcher_args = launcher_args or []
+
+    # Build the base command
+    if launcher == "accelerate":
+        launcher_arg = "--launcher accelerate"
+    elif launcher == "torchrun":
+        launcher_arg = "--launcher torchrun"
     else:
-        accelerate_args = "--no-accelerate"
-        num_processes_args = ""
-        if num_processes := kwargs.pop("num_processes", None):
-            num_processes_args = f"--num-processes {num_processes}"
+        launcher_arg = "--launcher python"
+
+    # Build launcher args string
+    launcher_args_str = ""
+    if launcher_args:
+        launcher_args_str = "-- " + " ".join(launcher_args)
+
     run_cmd(
-        f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
+        f"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}".strip(),
         run_folder,
         volumes,
     )
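The net effect of the `_train` changes is that the remote Modal job now re-invokes the unified CLI with a `--launcher` flag and optional passthrough args. Reconstructing the string-building above with sample values (illustrative, not a public API):

```python
# Sketch: the command string _train builds, reconstructed from the hunk above.
def build_train_cmd(launcher: str, launcher_args: list[str] | None) -> str:
    launcher_arg = f"--launcher {launcher}"
    launcher_args_str = "-- " + " ".join(launcher_args) if launcher_args else ""
    return (
        f"axolotl train {launcher_arg} /workspace/mounts/config.yaml "
        f"{launcher_args_str}"
    ).strip()


print(build_train_cmd("accelerate", None))
# axolotl train --launcher accelerate /workspace/mounts/config.yaml
print(build_train_cmd("torchrun", ["--nproc_per_node=2"]))
# axolotl train --launcher torchrun /workspace/mounts/config.yaml -- --nproc_per_node=2
```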
@@ -200,14 +200,13 @@ def load_cfg(
     # If there are any options passed in the cli, if it is something that seems valid
     # from the yaml, then overwrite the value
     cfg_keys = cfg.keys()
-    for k, _ in kwargs.items():
-        # if not strict, allow writing to cfg even if it's not in the yml already
-        if k in cfg_keys or not cfg.strict:
-            # handle booleans
-            if isinstance(cfg[k], bool):
-                cfg[k] = bool(kwargs[k])
+    for key, value in kwargs.items():
+        # If not strict, allow writing to cfg even if it's not in the yml already
+        if key in cfg_keys or not cfg.strict:
+            if isinstance(cfg[key], bool):
+                cfg[key] = bool(value)
             else:
-                cfg[k] = kwargs[k]
+                cfg[key] = value

     try:
         device_props = torch.cuda.get_device_properties("cuda")
@@ -9,7 +9,6 @@ from typing import Generator, Union
 import fire
 import torch
 from accelerate import init_empty_weights
-from dotenv import load_dotenv
 from transformers import AutoProcessor


@@ -152,5 +151,4 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:


 if __name__ == "__main__":
-    load_dotenv()
     fire.Fire(do_cli)
@@ -5,7 +5,6 @@ from pathlib import Path
 from typing import Union

 import fire
-from dotenv import load_dotenv
 from transformers.hf_argparser import HfArgumentParser

 from axolotl.cli.args import TrainerCliArgs
@@ -13,7 +12,6 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
 from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.evaluate import evaluate
-from axolotl.utils import patch_optimized_env
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger

@@ -30,9 +28,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
         cfg: Dictionary mapping `axolotl` config keys to values.
         cli_args: CLI arguments.
     """
-    # Enable expandable segments for cuda allocation to improve VRAM usage
-    patch_optimized_env()
-
     # pylint: disable=duplicate-code
     check_accelerate_default_config()
     if int(os.getenv("LOCAL_RANK", "0")) == 0:
@@ -64,5 +59,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:


 if __name__ == "__main__":
-    load_dotenv()
     fire.Fire(do_cli)
@@ -9,7 +9,6 @@ from typing import Union
 import fire
 import torch
 import transformers
-from dotenv import load_dotenv
 from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer

 from axolotl.cli.args import InferenceCliArgs
@@ -268,5 +267,4 @@ def do_cli(


 if __name__ == "__main__":
-    load_dotenv()
     fire.Fire(do_cli)
@@ -4,12 +4,9 @@

 import os
 import subprocess  # nosec B404
-import tempfile
 from pathlib import Path
-from typing import Optional
+from typing import Literal, Optional

 import click
-import yaml
 from dotenv import load_dotenv

 import axolotl
@@ -21,13 +18,14 @@ from axolotl.cli.args import (
     VllmServeCliArgs,
 )
 from axolotl.cli.art import print_axolotl_text_art
-from axolotl.cli.sweeps import generate_sweep_configs
 from axolotl.cli.utils import (
     add_options_from_config,
     add_options_from_dataclass,
     build_command,
     fetch_from_github,
     filter_none_kwargs,
+    generate_config_files,
+    launch_training,
 )
 from axolotl.integrations.lm_eval.cli import lm_eval
 from axolotl.utils import patch_optimized_env
@@ -36,12 +34,19 @@ from axolotl.utils.schemas.config import AxolotlInputConfig

 LOG = get_logger(__name__)

+LAUNCHER_COMMAND_MAPPING = {
+    "accelerate": ["accelerate", "launch"],
+    "torchrun": ["torchrun"],
+}
+

 @click.group()
 @click.version_option(version=axolotl.__version__, prog_name="axolotl")
 def cli():
     """Axolotl CLI - Train and fine-tune large language models"""
     print_axolotl_text_art()
+    load_dotenv()
+    patch_optimized_env()


 @cli.command()
@@ -50,7 +55,7 @@ def cli():
 @add_options_from_dataclass(PreprocessCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 @filter_none_kwargs
-def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
+def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
     """
     Preprocess datasets before training.

@@ -60,7 +65,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
     kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
         config options.
     """
-    patch_optimized_env()

     if cloud:
         from axolotl.cli.cloud import do_cli_preprocess
@@ -72,12 +76,15 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
     do_cli(config=config, **kwargs)


-@cli.command()
+@cli.command(
+    context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
+)
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @click.option(
-    "--accelerate/--no-accelerate",
-    default=True,
-    help="Use accelerate launch for multi-GPU training",
+    "--launcher",
+    type=click.Choice(["accelerate", "torchrun", "python"]),
+    default="accelerate",
+    help="Launcher to use for multi-GPU training",
 )
 @click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
 @click.option(
@@ -88,126 +95,81 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
 @add_options_from_dataclass(TrainerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 @filter_none_kwargs
+@click.pass_context
 def train(
+    ctx: click.Context,
     config: str,
-    accelerate: bool,
-    cloud: Optional[str] = None,
-    sweep: Optional[str] = None,
+    launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
+    cloud: str | None = None,
+    sweep: str | None = None,
     **kwargs,
-) -> None:
+):
     """
     Train or fine-tune a model.

     Args:
+        ctx: Click context for extra args.
         config: Path to `axolotl` config YAML file.
-        accelerate: Whether to use `accelerate` launcher.
+        launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python").
         cloud: Path to a cloud accelerator configuration file
         sweep: Path to YAML config for sweeping hyperparameters.
         kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
             config options.
     """
-    # Enable expandable segments for cuda allocation to improve VRAM usage
-    patch_optimized_env()
+    # Extract launcher args from extra args (after --)
+    launcher_args = ctx.args if ctx.args else []

-    if "use_ray" in kwargs and kwargs["use_ray"]:
-        accelerate = False
-    if sweep:
-        # load the sweep configuration yaml file
-        with open(sweep, "r", encoding="utf-8") as fin:
-            sweep_config: dict[str, list] = yaml.safe_load(fin)
-        with open(config, "r", encoding="utf-8") as fin:
-            base_config: dict[str, list] = yaml.safe_load(fin)
+    # Handle Ray launcher override
+    _launcher = None if kwargs.get("use_ray") else launcher

-        # generate all possible configurations
-        permutations = generate_sweep_configs(base_config, sweep_config)
-
-        def iter_configs():
-            for perm in permutations:
-                # open temp directory for temporary configurations
-                with tempfile.TemporaryDirectory() as temp_dir:
-                    with open(
-                        Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
-                    ) as fout:
-                        yaml.dump(perm, fout)
-                    yield str(Path(temp_dir) / "config.yaml")
-
-    else:
-
-        def iter_configs():
-            yield config
-
-    for cfg_file in iter_configs():
-        # handle errors from subprocess so we can continue rest of sweeps
+    # Process each configuration
+    for cfg_file in generate_config_files(config, sweep):
         try:
-            if accelerate:
-                if cloud:
-                    from axolotl.cli.cloud import do_cli_train
-
-                    cwd = os.getcwd()
-                    do_cli_train(
-                        cloud_config=cloud,
-                        config=config,
-                        accelerate=True,
-                        cwd=cwd,
-                        **kwargs,
-                    )
-                else:
-                    accelerate_args = []
-                    if "main_process_port" in kwargs:
-                        main_process_port = kwargs.pop("main_process_port", None)
-                        accelerate_args.append("--main_process_port")
-                        accelerate_args.append(str(main_process_port))
-                    if "num_processes" in kwargs:
-                        num_processes = kwargs.pop("num_processes", None)
-                        accelerate_args.append("--num_processes")
-                        accelerate_args.append(str(num_processes))
-
-                    base_cmd = ["accelerate", "launch"]
-                    base_cmd.extend(accelerate_args)
-                    base_cmd.extend(["-m", "axolotl.cli.train"])
-                    if cfg_file:
-                        base_cmd.append(cfg_file)
-                    cmd = build_command(base_cmd, kwargs)
-                    subprocess.run(cmd, check=True)  # nosec B603
-            else:
-                if cloud:
-                    from axolotl.cli.cloud import do_cli_train
-
-                    do_cli_train(
-                        cloud_config=cloud, config=config, accelerate=False, **kwargs
-                    )
-                else:
-                    from axolotl.cli.train import do_cli
-
-                    do_cli(config=cfg_file, **kwargs)
+            launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
         except subprocess.CalledProcessError as exc:
             LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
             if not sweep:
                 raise exc
+        finally:
+            # Only delete temp files, not the original config
+            if cfg_file != config:
+                os.unlink(cfg_file)


-@cli.command()
+@cli.command(
+    context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
+)
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @click.option(
-    "--accelerate/--no-accelerate",
-    default=True,
-    help="Use accelerate launch for multi-GPU training",
+    "--launcher",
+    type=click.Choice(["accelerate", "torchrun", "python"]),
+    default="accelerate",
+    help="Launcher to use for multi-GPU evaluation",
 )
 @add_options_from_dataclass(EvaluateCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 @filter_none_kwargs
-def evaluate(config: str, accelerate: bool, **kwargs) -> None:
+@click.pass_context
+def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
     """
     Evaluate a model.

     Args:
+        ctx: Click context for extra args.
         config: Path to `axolotl` config YAML file.
-        accelerate: Whether to use `accelerate` launcher.
+        launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python").
         kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
             config options.
     """
-    if accelerate:
-        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
+    # Extract launcher args from extra args (after --)
+    launcher_args = ctx.args if ctx.args else []
+
+    if launcher in LAUNCHER_COMMAND_MAPPING:
+        base_cmd = (
+            LAUNCHER_COMMAND_MAPPING[launcher]
+            + launcher_args
+            + ["-m", "axolotl.cli.evaluate"]
+        )
         if config:
             base_cmd.append(config)
         cmd = build_command(base_cmd, kwargs)
@@ -218,30 +180,42 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
         do_cli(config=config, **kwargs)


-@cli.command()
+@cli.command(
+    context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
+)
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @click.option(
-    "--accelerate/--no-accelerate",
-    default=False,
-    help="Use accelerate launch for multi-GPU inference",
+    "--launcher",
+    type=click.Choice(["accelerate", "torchrun", "python"]),
+    default="accelerate",
+    help="Launcher to use for multi-GPU inference",
 )
 @click.option("--gradio", is_flag=True, help="Launch Gradio interface")
 @add_options_from_dataclass(TrainerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 @filter_none_kwargs
-def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
+@click.pass_context
+def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):
     """
     Run inference with a trained model.

     Args:
+        ctx: Click context for extra args.
         config: Path to `axolotl` config YAML file.
-        accelerate: Whether to use `accelerate` launcher.
+        launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python").
         gradio: Whether to use Gradio browser interface or command line for inference.
         kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
             config options.
     """
-    if accelerate:
-        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
+    # Extract launcher args from extra args (after --)
+    launcher_args = ctx.args if ctx.args else []
+
+    if launcher in LAUNCHER_COMMAND_MAPPING:
+        base_cmd = (
+            LAUNCHER_COMMAND_MAPPING[launcher]
+            + launcher_args
+            + ["-m", "axolotl.cli.inference"]
+        )
         if config:
             base_cmd.append(config)
         if gradio:
@@ -254,33 +228,42 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
         do_cli(config=config, gradio=gradio, **kwargs)


-@cli.command()
+@cli.command(
+    context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
+)
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @click.option(
-    "--accelerate/--no-accelerate",
-    default=True,
-    help="Use accelerate launch for weight merging",
+    "--launcher",
+    type=click.Choice(["accelerate", "torchrun", "python"]),
+    default="accelerate",
+    help="Launcher to use for weight merging",
 )
 @add_options_from_dataclass(TrainerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 @filter_none_kwargs
-def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
+@click.pass_context
+def merge_sharded_fsdp_weights(
+    ctx: click.Context, config: str, launcher: str, **kwargs
+):
     """
     Merge sharded FSDP model weights.

     Args:
+        ctx: Click context for extra args.
         config: Path to `axolotl` config YAML file.
-        accelerate: Whether to use `accelerate` launcher.
+        launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python").
         kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
             config options.
     """
-    if accelerate:
-        base_cmd = [
-            "accelerate",
-            "launch",
-            "-m",
-            "axolotl.cli.merge_sharded_fsdp_weights",
-        ]
+    # Extract launcher args from extra args (after --)
+    launcher_args = ctx.args if ctx.args else []
+
+    if launcher in LAUNCHER_COMMAND_MAPPING:
+        base_cmd = (
+            LAUNCHER_COMMAND_MAPPING[launcher]
+            + launcher_args
+            + ["-m", "axolotl.cli.merge_sharded_fsdp_weights"]
+        )
         if config:
             base_cmd.append(config)
         cmd = build_command(base_cmd, kwargs)
@@ -296,7 +279,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
 @add_options_from_dataclass(TrainerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 @filter_none_kwargs
-def merge_lora(config: str, **kwargs) -> None:
+def merge_lora(config: str, **kwargs):
     """
     Merge trained LoRA adapters into a base model.

@@ -313,7 +296,7 @@ def merge_lora(config: str, **kwargs) -> None:
 @cli.command()
 @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
 @click.option("--dest", help="Destination directory")
-def fetch(directory: str, dest: Optional[str]) -> None:
+def fetch(directory: str, dest: Optional[str]):
     """
     Fetch example configs or other resources.

@@ -351,7 +334,7 @@ def quantize(config: str, **cli_args: QuantizeCliArgs):
 @cli.command()
 @click.argument("model", type=click.Path(exists=True, path_type=str))
 @click.argument("output", type=click.Path(exists=False, path_type=str))
-def delinearize_llama4(model: str, output: str) -> None:
+def delinearize_llama4(model: str, output: str):
     from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4

     do_delinearize_llama4(model, output)
@@ -365,5 +348,4 @@ def main():


 if __name__ == "__main__":
-    load_dotenv()
     main()
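The sweep plumbing that used to live inline in `train` (the removed `iter_configs` above) now sits behind `generate_config_files`. A sketch of equivalent behavior, reconstructed from the removed code; the real helper lives in `axolotl.cli.utils` and may differ in detail:

```python
# Sketch of generate_config_files semantics, pieced together from the removed
# iter_configs logic: yield the original config, or one temp file per sweep
# permutation (delete=False because the caller unlinks each temp file itself).
import tempfile
from typing import Generator

import yaml

from axolotl.cli.utils import generate_sweep_configs


def generate_config_files_sketch(
    config: str, sweep: str | None
) -> Generator[str, None, None]:
    if not sweep:
        yield config
        return

    with open(sweep, "r", encoding="utf-8") as fin:
        sweep_config: dict[str, list] = yaml.safe_load(fin)
    with open(config, "r", encoding="utf-8") as fin:
        base_config: dict = yaml.safe_load(fin)

    for perm in generate_sweep_configs(base_config, sweep_config):
        tmp = tempfile.NamedTemporaryFile(
            "w", suffix=".yaml", delete=False, encoding="utf-8"
        )
        with tmp:
            yaml.dump(perm, tmp)
        yield tmp.name
```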
@@ -4,7 +4,6 @@ from pathlib import Path
 from typing import Union

 import fire
-from dotenv import load_dotenv

 from axolotl.cli.config import load_cfg
 from axolotl.cli.utils import load_model_and_tokenizer
@@ -70,7 +69,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
         load_in_8bit=False,
         load_in_4bit=False,
         flash_attention=False,
-        sequence_parallel_degree=None,
+        context_parallel_size=None,
         deepspeed=None,
         fsdp=None,
         fsdp_config=None,
@@ -88,5 +87,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:


 if __name__ == "__main__":
-    load_dotenv()
     fire.Fire(do_cli)
@@ -17,7 +17,6 @@ from accelerate.utils import (
     WEIGHTS_NAME,
     is_torch_version,
 )
-from dotenv import load_dotenv
 from huggingface_hub import split_torch_state_dict_into_shards
 from safetensors.torch import save_file as safe_save_file
 from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
@@ -204,5 +203,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):


 if __name__ == "__main__":
-    load_dotenv()
     fire.Fire(do_cli)
@@ -9,7 +9,6 @@ import fire
 import transformers
 from accelerate import init_empty_weights
 from colorama import Fore
-from dotenv import load_dotenv
 from transformers import AutoModelForCausalLM

 from axolotl.cli.args import PreprocessCliArgs
@@ -109,5 +108,4 @@ def do_cli(


 if __name__ == "__main__":
-    load_dotenv()
     fire.Fire(do_cli)
@@ -7,7 +7,6 @@ from typing import Union

 import fire
 from accelerate import Accelerator
-from dotenv import load_dotenv
 from transformers.hf_argparser import HfArgumentParser

 from axolotl.cli.args import TrainerCliArgs
@@ -16,7 +15,6 @@ from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.integrations.base import PluginManager
 from axolotl.train import train
-from axolotl.utils import patch_optimized_env
 from axolotl.utils.config import normalize_config, resolve_dtype
 from axolotl.utils.dict import DictDefault

@@ -31,9 +29,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
         cfg: Dictionary mapping `axolotl` config keys to values.
         cli_args: Training-specific CLI arguments.
     """
-    # Enable expandable segments for cuda allocation to improve VRAM usage
-    patch_optimized_env()
-
     check_accelerate_default_config()
     if int(os.getenv("LOCAL_RANK", "0")) == 0:
         check_user_token()
@@ -122,5 +117,4 @@ def ray_train_func(kwargs: dict):


 if __name__ == "__main__":
-    load_dotenv()
     fire.Fire(do_cli)
@@ -1,330 +0,0 @@
|
||||
"""Utility methods for axolotl CLI."""
|
||||
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def strip_optional_type(field_type: type | str | None):
|
||||
"""
|
||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||
|
||||
Args:
|
||||
field_type: Type of field for Axolotl CLI command.
|
||||
|
||||
Returns:
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
|
||||
return field_type
|
||||
|
||||
|
||||
def filter_none_kwargs(func: Callable) -> Callable:
|
||||
"""
|
||||
Wraps function to remove `None`-valued `kwargs`.
|
||||
|
||||
Args:
|
||||
func: Function to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Callable:
|
||||
"""Filters out `None`-valued `kwargs`."""
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
return func(*args, **filtered_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
config_class: Dataclass with fields to parse from the CLI.
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process dataclass fields in reverse order for correct option ordering
|
||||
for field in reversed(dataclasses.fields(config_class)):
|
||||
field_type = strip_optional_type(field.type)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = field.name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{field.name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process model fields in reverse order for correct option ordering
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = strip_optional_type(field.annotation)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
    """
    Build command list from base command and options.

    Args:
        base_cmd: Command without options.
        options: Options to parse and append to base command.

    Returns:
        List of strings giving shell command.
    """
    cmd = base_cmd.copy()

    for key, value in options.items():
        if value is None:
            continue

        key = key.replace("_", "-")

        if isinstance(value, bool):
            if value:
                cmd.append(f"--{key}")
        else:
            cmd.extend([f"--{key}", str(value)])

    return cmd


def download_file(
    file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> tuple[str, str]:
    """
    Download a single file and return its processing status.

    Args:
        file_info: Tuple of (file_path, remote_sha).
        raw_base_url: Base URL for raw GitHub content.
        dest_path: Local destination directory.
        dir_prefix: Directory prefix to filter files.

    Returns:
        Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
    """
    file_path, remote_sha = file_info
    raw_url = f"{raw_base_url}/{file_path}"
    dest_file = dest_path / file_path.split(dir_prefix)[-1]

    # Check if file exists and needs updating
    if dest_file.exists():
        with open(dest_file, "rb") as file:
            content = file.read()
            # Calculate git blob SHA
            blob = b"blob " + str(len(content)).encode() + b"\0" + content
            local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()

        if local_sha == remote_sha:
            print(f"Skipping {file_path} (unchanged)")
            return file_path, "unchanged"

        print(f"Updating {file_path}")
        status = "new"
    else:
        print(f"Downloading {file_path}")
        status = "new"

    # Create directories if needed
    dest_file.parent.mkdir(parents=True, exist_ok=True)

    # Download and save file
    try:
        response = requests.get(raw_url, timeout=30)
        response.raise_for_status()

        with open(dest_file, "wb") as file:
            file.write(response.content)

        return file_path, status
    except (requests.RequestException, IOError) as request_error:
        print(f"Error downloading {file_path}: {str(request_error)}")
        return file_path, "error"


def fetch_from_github(
    dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None:
    """
    Sync files from a specific directory in the GitHub repository.
    Only downloads files that don't exist locally or have changed.

    Args:
        dir_prefix: Directory prefix to filter files (e.g., 'examples/',
            'deepspeed_configs/').
        dest_dir: Local destination directory.
        max_workers: Maximum number of concurrent downloads.
    """
    api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
    raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"

    # Get repository tree with timeout
    response = requests.get(api_url, timeout=30)
    response.raise_for_status()
    tree = json.loads(response.text)

    # Filter for files and get their SHA
    files = {
        item["path"]: item["sha"]
        for item in tree["tree"]
        if item["type"] == "blob" and item["path"].startswith(dir_prefix)
    }

    if not files:
        raise click.ClickException(f"No files found in {dir_prefix}")

    # Default destination directory is the last part of dir_prefix
    default_dest = Path(dir_prefix.rstrip("/"))
    dest_path = Path(dest_dir) if dest_dir else default_dest

    # Keep track of processed files for summary
    files_processed: dict[str, list[str]] = {
        "new": [],
        "updated": [],
        "unchanged": [],
        "error": [],
    }

    # Process files in parallel using ThreadPoolExecutor
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_file = {
            executor.submit(
                download_file,
                (file_path, remote_sha),
                raw_base_url,
                dest_path,
                dir_prefix,
            ): file_path
            for file_path, remote_sha in files.items()
        }

        # Process completed tasks as they finish
        for future in concurrent.futures.as_completed(future_to_file):
            file_path = future_to_file[future]
            try:
                file_path, status = future.result()
                files_processed[status].append(file_path)
            except (requests.RequestException, IOError) as request_error:
                print(f"Error processing {file_path}: {str(request_error)}")
                files_processed["error"].append(file_path)

    # Log summary
    LOG.info("\nSync Summary:")
    LOG.info(f"New files: {len(files_processed['new'])}")
    LOG.info(f"Updated files: {len(files_processed['updated'])}")
    LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
    if files_processed["error"]:
        LOG.info(f"Failed files: {len(files_processed['error'])}")


def load_model_and_tokenizer(
    *,
    cfg: DictDefault,
    inference: bool = False,
) -> tuple[
    PreTrainedModel,
    PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
    ProcessorMixin | None,
]:
    """
    Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
    config.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        inference: Boolean denoting inference mode.

    Returns:
        Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
    """
    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
    tokenizer = load_tokenizer(cfg)

    LOG.info("loading model...")
    model_loader = ModelLoader(cfg, tokenizer, inference=inference)
    model, _ = model_loader.load()

    processor = None
    if cfg.is_multimodal:
        LOG.info("loading processor...")
        processor = load_processor(cfg, tokenizer)

    return model, tokenizer, processor
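For reference, here is how the `build_command` above renders options: booleans become bare flags, other values become `--key value` pairs, and `None` values are dropped (a quick illustrative sketch, not part of the diff):

```python
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
options = {"learning_rate": 2e-5, "debug": True, "resume_from_checkpoint": None}

# resume_from_checkpoint is skipped (None); debug=True becomes a bare flag
print(build_command(base_cmd, options))
# -> ['accelerate', 'launch', '-m', 'axolotl.cli.train',
#     '--learning-rate', '2e-05', '--debug']
```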
23 src/axolotl/cli/utils/__init__.py Normal file
@@ -0,0 +1,23 @@
"""Init for axolotl.cli.utils module."""

from .args import (
    add_options_from_config,
    add_options_from_dataclass,
    filter_none_kwargs,
)
from .fetch import fetch_from_github
from .load import load_model_and_tokenizer
from .sweeps import generate_sweep_configs
from .train import build_command, generate_config_files, launch_training

__all__ = [
    "filter_none_kwargs",
    "add_options_from_dataclass",
    "add_options_from_config",
    "build_command",
    "generate_config_files",
    "generate_sweep_configs",
    "load_model_and_tokenizer",
    "launch_training",
    "fetch_from_github",
]
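The package init re-exports the split-out helpers, so call sites can keep importing from the package root; for example (illustrative):

```python
from axolotl.cli.utils import (
    build_command,             # defined in .train
    fetch_from_github,         # defined in .fetch
    load_model_and_tokenizer,  # defined in .load
)
```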
120 src/axolotl/cli/utils/args.py Normal file
@@ -0,0 +1,120 @@
"""Utilities for axolotl CLI args."""

import dataclasses
from functools import wraps
from types import NoneType
from typing import Any, Callable, Type, Union, get_args, get_origin

import click
from pydantic import BaseModel


def _strip_optional_type(field_type: type | str | None):
    """
    Extracts the non-`None` type from an `Optional` / `Union` type.

    Args:
        field_type: Type of field for Axolotl CLI command.

    Returns:
        If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
        returns the input type unchanged.
    """
    if get_origin(field_type) is Union and type(None) in get_args(field_type):
        field_type = next(
            t for t in get_args(field_type) if not isinstance(t, NoneType)
        )

    return field_type


def filter_none_kwargs(func: Callable) -> Callable:
    """
    Wraps function to remove `None`-valued `kwargs`.

    Args:
        func: Function to wrap.

    Returns:
        Wrapped function.
    """

    @wraps(func)
    def wrapper(*args, **kwargs) -> Callable:
        """Filters out `None`-valued `kwargs`."""
        filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}

        return func(*args, **filtered_kwargs)

    return wrapper


def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
    """
    Create Click options from the fields of a dataclass.

    Args:
        config_class: Dataclass with fields to parse from the CLI.

    Returns:
        Function decorator for Axolotl CLI command.
    """

    def decorator(function: Callable) -> Callable:
        # Process dataclass fields in reverse order for correct option ordering
        for field in reversed(dataclasses.fields(config_class)):
            field_type = _strip_optional_type(field.type)

            if field_type == bool:
                field_name = field.name.replace("_", "-")
                option_name = f"--{field_name}/--no-{field_name}"
                function = click.option(
                    option_name,
                    default=field.default,
                    help=field.metadata.get("description"),
                )(function)
            else:
                option_name = f"--{field.name.replace('_', '-')}"
                function = click.option(
                    option_name,
                    type=field_type,
                    default=field.default,
                    help=field.metadata.get("description"),
                )(function)

        return function

    return decorator


def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
    """
    Create Click options from the fields of a Pydantic model.

    Args:
        config_class: Pydantic model with fields to parse from the CLI.

    Returns:
        Function decorator for Axolotl CLI command.
    """

    def decorator(function: Callable) -> Callable:
        # Process model fields in reverse order for correct option ordering
        for name, field in reversed(config_class.model_fields.items()):
            field_type = _strip_optional_type(field.annotation)

            if field_type == bool:
                field_name = name.replace("_", "-")
                option_name = f"--{field_name}/--no-{field_name}"
                function = click.option(
                    option_name, default=None, help=field.description
                )(function)
            else:
                option_name = f"--{name.replace('_', '-')}"
                function = click.option(
                    option_name, default=None, help=field.description
                )(function)

        return function

    return decorator
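A minimal sketch of how these decorators compose on a Click command (the `DemoArgs` dataclass and `demo` command are hypothetical, for illustration only):

```python
import dataclasses
from typing import Optional

import click

from axolotl.cli.utils.args import add_options_from_dataclass, filter_none_kwargs


@dataclasses.dataclass
class DemoArgs:
    """Hypothetical CLI args for illustration."""

    learning_rate: Optional[float] = dataclasses.field(
        default=None, metadata={"description": "Peak learning rate"}
    )
    debug: bool = dataclasses.field(
        default=False, metadata={"description": "Enable debug output"}
    )


@click.command()
@add_options_from_dataclass(DemoArgs)
@filter_none_kwargs
def demo(**kwargs):
    # `demo --learning-rate 3e-4 --debug` -> {'learning_rate': 0.0003, 'debug': True}
    # When --learning-rate is not passed, filter_none_kwargs drops it entirely.
    click.echo(kwargs)
```

Note the bool case produces a paired `--debug/--no-debug` switch, while other types get a single typed option.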
142 src/axolotl/cli/utils/fetch.py Normal file
@@ -0,0 +1,142 @@
"""Utilities for axolotl fetch CLI command."""

import concurrent.futures
import hashlib
import json
from pathlib import Path

import click
import requests

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def _download_file(
    file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> tuple[str, str]:
    """
    Download a single file and return its processing status.

    Args:
        file_info: Tuple of (file_path, remote_sha).
        raw_base_url: Base URL for raw GitHub content.
        dest_path: Local destination directory.
        dir_prefix: Directory prefix to filter files.

    Returns:
        Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
    """
    file_path, remote_sha = file_info
    raw_url = f"{raw_base_url}/{file_path}"
    dest_file = dest_path / file_path.split(dir_prefix)[-1]

    # Check if file exists and needs updating
    if dest_file.exists():
        with open(dest_file, "rb") as file:
            content = file.read()
            # Calculate git blob SHA
            blob = b"blob " + str(len(content)).encode() + b"\0" + content
            local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()

        if local_sha == remote_sha:
            print(f"Skipping {file_path} (unchanged)")
            return file_path, "unchanged"

        print(f"Updating {file_path}")
        status = "updated"
    else:
        print(f"Downloading {file_path}")
        status = "new"

    # Create directories if needed
    dest_file.parent.mkdir(parents=True, exist_ok=True)

    # Download and save file
    try:
        response = requests.get(raw_url, timeout=30)
        response.raise_for_status()

        with open(dest_file, "wb") as file:
            file.write(response.content)

        return file_path, status
    except (requests.RequestException, IOError) as request_error:
        print(f"Error downloading {file_path}: {str(request_error)}")
        return file_path, "error"


def fetch_from_github(
    dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None:
    """
    Sync files from a specific directory in the GitHub repository.
    Only downloads files that don't exist locally or have changed.

    Args:
        dir_prefix: Directory prefix to filter files (e.g., 'examples/',
            'deepspeed_configs/').
        dest_dir: Local destination directory.
        max_workers: Maximum number of concurrent downloads.
    """
    api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
    raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"

    # Get repository tree with timeout
    response = requests.get(api_url, timeout=30)
    response.raise_for_status()
    tree = json.loads(response.text)

    # Filter for files and get their SHA
    files = {
        item["path"]: item["sha"]
        for item in tree["tree"]
        if item["type"] == "blob" and item["path"].startswith(dir_prefix)
    }

    if not files:
        raise click.ClickException(f"No files found in {dir_prefix}")

    # Default destination directory is the last part of dir_prefix
    default_dest = Path(dir_prefix.rstrip("/"))
    dest_path = Path(dest_dir) if dest_dir else default_dest

    # Keep track of processed files for summary
    files_processed: dict[str, list[str]] = {
        "new": [],
        "updated": [],
        "unchanged": [],
        "error": [],
    }

    # Process files in parallel using ThreadPoolExecutor
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_file = {
            executor.submit(
                _download_file,
                (file_path, remote_sha),
                raw_base_url,
                dest_path,
                dir_prefix,
            ): file_path
            for file_path, remote_sha in files.items()
        }

        # Process completed tasks as they finish
        for future in concurrent.futures.as_completed(future_to_file):
            file_path = future_to_file[future]
            try:
                file_path, status = future.result()
                files_processed[status].append(file_path)
            except (requests.RequestException, IOError) as request_error:
                print(f"Error processing {file_path}: {str(request_error)}")
                files_processed["error"].append(file_path)

    # Log summary
    LOG.info("\nSync Summary:")
    LOG.info(f"New files: {len(files_processed['new'])}")
    LOG.info(f"Updated files: {len(files_processed['updated'])}")
    LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
    if files_processed["error"]:
        LOG.info(f"Failed files: {len(files_processed['error'])}")
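The local-change check in `_download_file` reproduces git's blob-object hash, so it can be compared directly against the SHA the GitHub tree API reports. A quick way to sanity-check the formula (illustrative):

```python
import hashlib

content = b"hello\n"
# git hashes blobs as: "blob <len>\0<content>"
blob = b"blob " + str(len(content)).encode() + b"\0" + content
print(hashlib.sha1(blob).hexdigest())
# ce013625030ba8dba906f756967f9e9ca394464a -- the same value
# `echo "hello" | git hash-object --stdin` prints, and the same "sha"
# GitHub returns for that blob in its tree API.
```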
52 src/axolotl/cli/utils/load.py Normal file
@@ -0,0 +1,52 @@
"""Utilities for model, tokenizer, etc. loading."""

from typing import Any

from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    ProcessorMixin,
)

from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def load_model_and_tokenizer(
    *,
    cfg: DictDefault,
    inference: bool = False,
) -> tuple[
    PreTrainedModel,
    PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
    ProcessorMixin | None,
]:
    """
    Helper function for loading a model, tokenizer, and processor specified in the
    given `axolotl` config.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        inference: Boolean denoting inference mode.

    Returns:
        Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
    """
    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
    tokenizer = load_tokenizer(cfg)

    LOG.info("loading model...")
    model_loader = ModelLoader(cfg, tokenizer, inference=inference)
    model, _ = model_loader.load()

    processor = None
    if cfg.is_multimodal:
        LOG.info("loading processor...")
        processor = load_processor(cfg, tokenizer)

    return model, tokenizer, processor
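Typical usage pairs this helper with a parsed config. A rough sketch (the config keys shown are assumptions for illustration; real runs go through axolotl's config loading and validation, which fills in many more fields):

```python
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",  # hypothetical example model
        "base_model_config": "HuggingFaceTB/SmolLM2-135M",
        "tokenizer_config": None,
        "is_multimodal": False,
    }
)
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg, inference=True)
assert processor is None  # only set for multimodal configs
```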
188 src/axolotl/cli/utils/train.py Normal file
@@ -0,0 +1,188 @@
"""Utilities for axolotl train CLI command."""

import os
import subprocess  # nosec
import tempfile
from typing import Any, Iterator, Literal

import yaml

from axolotl.cli.utils.sweeps import generate_sweep_configs


def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]:
    """
    Add default RDZV arguments if rdzv_endpoint is set but rdzv_backend/rdzv_id are missing.

    Args:
        launcher_args: List of launcher arguments.

    Returns:
        Updated launcher args with defaults added if needed.
    """
    args = launcher_args.copy()

    # Check if rdzv_endpoint is present
    has_rdzv_endpoint = any("--rdzv_endpoint" in arg for arg in args)

    if has_rdzv_endpoint:
        # Check if rdzv_backend is already provided
        has_rdzv_backend = any("--rdzv_backend" in arg for arg in args)
        if not has_rdzv_backend:
            args.extend(["--rdzv_backend", "c10d"])

        # Check if rdzv_id is already provided
        has_rdzv_id = any("--rdzv_id" in arg for arg in args)
        if not has_rdzv_id:
            import uuid

            args.extend(["--rdzv_id", str(uuid.uuid4())[:8]])

    return args


def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
    """
    Build command list from base command and options.

    Args:
        base_cmd: Command without options.
        options: Options to parse and append to base command.

    Returns:
        List of strings giving shell command.
    """
    cmd = base_cmd.copy()

    for key, value in options.items():
        if value is None:
            continue

        key = key.replace("_", "-")
        cmd.append(f"--{key}={value}")

    return cmd


def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
    """Generate list of configuration files to process."""
    if not sweep:
        yield config
        return

    # Load sweep and base configurations
    with open(sweep, "r", encoding="utf-8") as fin:
        sweep_config: dict[str, list] = yaml.safe_load(fin)
    with open(config, "r", encoding="utf-8") as fin:
        base_config: dict[str, list] = yaml.safe_load(fin)

    # Generate all possible configurations
    permutations = generate_sweep_configs(base_config, sweep_config)
    for permutation in permutations:
        # pylint: disable=consider-using-with
        temp_file = tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".yaml",
            delete=False,
            encoding="utf-8",
        )
        yaml.dump(permutation, temp_file)
        temp_file.close()
        yield temp_file.name


def launch_training(
    cfg_file: str,
    launcher: Literal["accelerate", "torchrun", "python"] | None,
    cloud: str | None,
    kwargs: dict,
    launcher_args: list[str] | None = None,
) -> None:
    """Execute training with the given configuration."""
    launcher_args = launcher_args or []

    if cloud:
        _launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
    elif launcher:
        if launcher == "accelerate":
            _launch_accelerate_training(cfg_file, kwargs, launcher_args)
        elif launcher == "torchrun":
            _launch_torchrun_training(cfg_file, kwargs, launcher_args)
        elif launcher == "python":
            _launch_python_training(cfg_file, kwargs)


def _launch_cloud_training(
    cloud: str,
    cfg_file: str,
    launcher: Literal["accelerate", "torchrun", "python"] | None,
    kwargs: dict,
    launcher_args: list[str] | None = None,
) -> None:
    """Execute training via cloud launcher."""
    from axolotl.cli.cloud import do_cli_train

    launcher_args = launcher_args or []
    cwd = os.getcwd() if launcher else None

    do_cli_train(
        cloud_config=cloud,
        config=cfg_file,
        launcher=launcher or "accelerate",
        launcher_args=launcher_args,
        cwd=cwd,
        **kwargs,
    )


def _launch_accelerate_training(
    cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
) -> None:
    """Execute training via accelerate launcher."""
    launcher_args = launcher_args or []
    internal_launcher_args = []

    # Extract launcher-specific arguments from kwargs (legacy support)
    if "main_process_port" in kwargs:
        main_process_port = kwargs.pop("main_process_port")
        internal_launcher_args.extend(["--main_process_port", str(main_process_port)])

    if "num_processes" in kwargs:
        num_processes = kwargs.pop("num_processes")
        internal_launcher_args.extend(["--num_processes", str(num_processes)])

    # Combine internal args with user-provided launcher args
    all_launcher_args = internal_launcher_args + launcher_args

    base_cmd = (
        ["accelerate", "launch"] + all_launcher_args + ["-m", "axolotl.cli.train"]
    )
    if cfg_file:
        base_cmd.append(cfg_file)

    cmd = build_command(base_cmd, kwargs)
    subprocess.run(cmd, check=True)  # nosec B603


def _launch_torchrun_training(
    cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
) -> None:
    """Execute training via torchrun launcher."""
    launcher_args = launcher_args or []

    # Add default RDZV arguments if rdzv_endpoint is set
    launcher_args = _add_default_rdzv_args(launcher_args)

    base_cmd = ["torchrun"] + launcher_args + ["-m", "axolotl.cli.train"]
    if cfg_file:
        base_cmd.append(cfg_file)

    cmd = build_command(base_cmd, kwargs)
    subprocess.run(cmd, check=True)  # nosec B603


def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
    """Execute training via python launcher."""
    from axolotl.cli.train import do_cli

    do_cli(config=cfg_file, **kwargs)
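To see the torchrun defaulting in action, note what `_add_default_rdzv_args` leaves alone versus fills in (illustrative; the endpoint is a placeholder):

```python
args = _add_default_rdzv_args(["--rdzv_endpoint", "head-node:29400"])
# -> ['--rdzv_endpoint', 'head-node:29400', '--rdzv_backend', 'c10d',
#     '--rdzv_id', '<8-char uuid prefix>']

args = _add_default_rdzv_args(["--nproc_per_node", "8"])
# unchanged: no --rdzv_endpoint present, so no defaults are injected
```

Note also that this `build_command` always emits `--key=value` tokens (booleans included), unlike the removed variant earlier in the diff, which emitted bare `--flag` for `True` and separate `--key value` pairs otherwise.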
@@ -24,9 +24,11 @@ from pathlib import Path
from typing import Any

import torch
from accelerate import PartialState
from transformers import (
    TrainerCallback,
)
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.training_args import OptimizerNames

from axolotl.integrations.base import PluginManager
@@ -34,7 +36,6 @@ from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
    GCCallback,
    GPUStatsCallback,
    SaveAxolotlConfigtoWandBCallback,
    SaveModelOnFirstStepCallback,
)
@@ -139,8 +140,6 @@ class TrainerBuilderBase(abc.ABC):
        if self.cfg.save_first_step:
            callbacks.append(SaveModelOnFirstStepCallback())

        callbacks.append(GPUStatsCallback(cfg=self.cfg))

        if self.cfg.profiler_steps:
            callbacks.append(
                PytorchProfilerCallback(
@@ -434,8 +433,30 @@ class TrainerBuilderBase(abc.ABC):
            training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode

    def _configure_accelerator_config(self, training_args_kwargs: dict):
        partial_state = PartialState()
        has_pc_attr = (
            hasattr(partial_state, "parallelism_config")
            and partial_state.parallelism_config
        )
        has_pc_key = (
            "parallelism_config"
            in partial_state._shared_state  # pylint: disable=protected-access
            and partial_state._shared_state[  # pylint: disable=protected-access
                "parallelism_config"
            ]
        )
        use_configured_state = has_pc_attr or has_pc_key
        if self.cfg.accelerator_config:
            training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
            use_configured_state = self.cfg.accelerator_config.pop(
                "use_configured_state", use_configured_state
            )
            training_args_kwargs["accelerator_config"] = AcceleratorConfig(
                use_configured_state=use_configured_state, **self.cfg.accelerator_config
            )
        else:
            training_args_kwargs["accelerator_config"] = AcceleratorConfig(
                use_configured_state=use_configured_state,
            )

    def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
        if self.cfg.activation_offloading is True:
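The new `_configure_accelerator_config` boils down to: if a parallelism config was already attached to the shared `PartialState` (for example by a launcher that set up N-D parallelism), tell the `Trainer` to reuse that state rather than create a fresh one. A condensed sketch of the decision, assuming the same `accelerate`/`transformers` imports as in the hunk:

```python
from accelerate import PartialState
from transformers.trainer_pt_utils import AcceleratorConfig

state = PartialState()
# hedged check: parallelism_config may not exist on older accelerate versions
preconfigured = bool(getattr(state, "parallelism_config", None))

# use_configured_state=True makes the HF Trainer adopt the existing
# PartialState instead of instantiating its own accelerator state.
accel_cfg = AcceleratorConfig(use_configured_state=preconfigured)
```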
@@ -19,7 +19,6 @@ from axolotl.core.trainers import (
    AxolotlPRMTrainer,
    AxolotlRewardTrainer,
    AxolotlTrainer,
    ReLoRATrainer,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
@@ -58,7 +57,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
    def get_callbacks(self):
        callbacks = super().get_callbacks()

        if self.cfg.relora_steps:
        if self.cfg.relora:
            callbacks.append(ReLoRACallback(self.cfg))

        if (
@@ -131,8 +130,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
        if trainer_cls:
            return trainer_cls
        if self.cfg.relora_steps:
            return ReLoRATrainer
        if self.cfg.model_config_type == "mamba":
            return AxolotlMambaTrainer
        if self.cfg.reward_model:
@@ -271,20 +268,25 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                self.cfg.sample_packing_eff_est
            )

        if self.cfg.relora_steps:
            training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
            training_arguments_kwargs["relora_warmup_steps"] = (
                self.cfg.relora_warmup_steps
            )
            if self.cfg.relora_anneal_steps:
                training_arguments_kwargs["relora_anneal_steps"] = (
                    self.cfg.relora_anneal_steps
                )
        if self.cfg.relora and self.cfg.jagged_restart_steps:
            if self.cfg.relora_prune_ratio:
                training_arguments_kwargs["relora_prune_ratio"] = (
                    self.cfg.relora_prune_ratio
                )

        if self.cfg.jagged_restart_steps:
            training_arguments_kwargs["jagged_restart_steps"] = (
                self.cfg.jagged_restart_steps
            )
            if self.cfg.jagged_restart_warmup_steps:
                training_arguments_kwargs["jagged_restart_warmup_steps"] = (
                    self.cfg.jagged_restart_warmup_steps
                )
            if self.cfg.jagged_restart_anneal_steps:
                training_arguments_kwargs["jagged_restart_anneal_steps"] = (
                    self.cfg.jagged_restart_anneal_steps
                )

        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
            training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
            training_arguments_kwargs["lisa_step_interval"] = (
@@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):

        if self.cfg.rl is RLType.GRPO:
            trainer_cls = GRPOStrategy.get_trainer_class(
                sequence_parallel=self.cfg.sequence_parallel_degree > 1
                sequence_parallel=self.cfg.context_parallel_size > 1
            )
            trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))

@@ -7,7 +7,6 @@ from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
    AxolotlCPOTrainer,
    AxolotlKTOTrainer,

@@ -27,6 +27,7 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import (
    ActivationOffloadingMixin,
    CheckpointSaveMixin,
    DistributedParallelMixin,
    OptimizerMixin,
    PackingMixin,
    RngLoaderMixin,
@@ -37,6 +38,8 @@ from axolotl.core.trainers.utils import (
    sanitize_kwargs_for_tagging,
)
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

@@ -50,6 +53,7 @@ class AxolotlTrainer(
    RngLoaderMixin,
    CheckpointSaveMixin,
    ActivationOffloadingMixin,
    DistributedParallelMixin,
    Trainer,
):
    """Extend the base Trainer for axolotl helpers"""
@@ -558,6 +562,17 @@ class AxolotlTrainer(
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()

        if is_main_process():
            # Add memory usage
            try:
                active, allocated, reserved = get_gpu_memory_usage()
                logs["memory/max_memory_active"] = active
                logs["memory/max_memory_allocated"] = allocated
                logs["memory/device_memory_reserved"] = reserved
            except (ValueError, FileNotFoundError):
                pass

        del self._stored_metrics[train_eval]

        return super().log(logs, start_time)

@@ -8,7 +8,11 @@ import torch
from torch import nn
from trl import DPOTrainer

from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import (
    DistributedParallelMixin,
    RngLoaderMixin,
    SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.utils import (
    sanitize_kwargs_for_ds_tagging,
@@ -17,7 +21,12 @@ from axolotl.core.trainers.utils import (


class AxolotlDPOTrainer(
    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
    RngLoaderMixin,
    SchedulerMixin,
    OptimizerMixin,
    OptimizerInitMixin,
    DPOTrainer,
    DistributedParallelMixin,
):
    """Extend the base DPOTrainer for axolotl helpers."""

@@ -49,7 +49,8 @@ class GRPOStrategy:

        if trl.use_vllm:
            grpo_args_kwargs["use_vllm"] = trl.use_vllm
            grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
            if trl.vllm_mode:
                grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
            if trl.vllm_mode == "colocate":
                grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
                    vllm_cfg.gpu_memory_utilization
@@ -82,8 +83,13 @@ class GRPOStrategy:
        grpo_args_kwargs["log_completions"] = trl.log_completions
        grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print

        if cfg.sequence_parallel_degree > 1:
            grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
        if cfg.context_parallel_size > 1:
            grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size

        if trl.importance_sampling_level is not None:
            grpo_args_kwargs["importance_sampling_level"] = (
                trl.importance_sampling_level
            )

        if trl.reward_weights:
            grpo_args_kwargs["reward_weights"] = trl.reward_weights

@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
    """Axolotl GRPO Config for GRPO training"""

    sequence_parallel_degree: int | None = None
    context_parallel_size: int | None = None

@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
    - Data is properly distributed across SP groups.

    In the table below, the values represent dataset indices. Each SP group has
    `sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
    `context_parallel_size = 2` GPUs working together on the same data. There are 2
    SP groups (SP0 and SP1), with `world_size = 4` total GPUs.

    Sequence Parallel Groups
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
        rank: Rank of current process.
        batch_size: Number of samples per batch.
        repeat_count: How many times to repeat the full sampling process.
        sequence_parallel_degree: Number of ranks in a sequence parallel group.
        context_parallel_size: Number of ranks in a sequence parallel group.
        shuffle: Whether to shuffle the dataset.
        seed: Random seed for shuffling.
        drop_last: Whether to drop the last incomplete batch.
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
        rank: int,
        batch_size: int = 1,
        repeat_count: int = 1,
        sequence_parallel_degree: int = 1,
        context_parallel_size: int = 1,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
        self.rank = rank

        # Sequence parallelism parameters
        self.sequence_parallel_degree = sequence_parallel_degree
        self.num_sp_groups = world_size // sequence_parallel_degree
        self.sp_group_id = rank // sequence_parallel_degree
        self.context_parallel_size = context_parallel_size
        self.num_sp_groups = world_size // context_parallel_size
        self.sp_group_id = rank // context_parallel_size

        # Adjust dataset size for distributed sampling
        self.num_samples = len(self.dataset)
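The rank bookkeeping behind the rename is easiest to see with concrete numbers; a standalone sketch of the same arithmetic:

```python
world_size = 8
context_parallel_size = 2  # ranks per sequence-parallel group

num_sp_groups = world_size // context_parallel_size  # 4 groups
for rank in range(world_size):
    sp_group_id = rank // context_parallel_size
    # ranks (0,1) -> group 0, (2,3) -> group 1,
    # ranks (4,5) -> group 2, (6,7) -> group 3
    print(rank, sp_group_id)
```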
@@ -43,7 +43,11 @@ from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad

from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import (
    DistributedParallelMixin,
    RngLoaderMixin,
    SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.monkeypatch.ring_attn import get_ring_attn_group

@@ -53,7 +57,12 @@ if is_peft_available():


class AxolotlGRPOTrainer(
    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
    RngLoaderMixin,
    SchedulerMixin,
    OptimizerMixin,
    OptimizerInitMixin,
    DistributedParallelMixin,
    GRPOTrainer,
):
    """Extend the base GRPOTrainer for axolotl helpers"""

@@ -100,7 +109,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):

        # Get number of SP groups (number of processes divided by SP degree)
        num_processes = self.accelerator.num_processes
        num_sp_groups = num_processes // self.args.sequence_parallel_degree
        num_sp_groups = num_processes // self.args.context_parallel_size

        # Calculate batch size per SP group (not per process)
        sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
@@ -130,7 +139,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):

        if self.num_generations not in possible_values:
            raise ValueError(
                f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
                f"With sequence parallelism (degree {self.args.context_parallel_size}), "
                f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
                f"must be evenly divisible by the number of generations per prompt "
                f"({self.num_generations}). Given the current eval batch size, "
@@ -167,9 +176,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
            rank=self.rank,
            batch_size=effective_batch_size
            // self.num_generations
            // self.args.sequence_parallel_degree,
            // self.args.context_parallel_size,
            repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
            sequence_parallel_degree=self.args.sequence_parallel_degree,
            context_parallel_size=self.args.context_parallel_size,
            shuffle=True,
            seed=self.args.seed,
            drop_last=True,
@@ -235,7 +244,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
        # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
        # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
        # slice each batch along the sequence dimension).
        if self.args.sequence_parallel_degree > 1:
        if self.args.context_parallel_size > 1:
            return dataloader

        # Otherwise prepare with accelerator
@@ -308,18 +317,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
        # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
        all_prompts_text = gather_object(prompts_text)
        if self.accelerator.is_main_process:
            if self.args.sequence_parallel_degree > 1:
            if self.args.context_parallel_size > 1:
                # Calculate sequence parallel group information
                world_size = self.accelerator.num_processes
                sequence_parallel_degree = self.args.sequence_parallel_degree
                num_sp_groups = world_size // sequence_parallel_degree
                context_parallel_size = self.args.context_parallel_size
                num_sp_groups = world_size // context_parallel_size

                # Since processes in the same SP group have the same prompts, we need to ensure
                # we only take one copy of each prompt from each SP group
                ordered_set_of_prompts = []
                for sp_group_id in range(num_sp_groups):
                    # Get the first process from each SP group (typically the group leader)
                    group_leader_rank = sp_group_id * sequence_parallel_degree
                    group_leader_rank = sp_group_id * context_parallel_size

                    # Extract prompts from this SP group, accounting for num_generations duplicates
                    # We only need prompts from one rank in each SP group
@@ -335,7 +344,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
                # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                # prompt individually.
                ordered_set_of_prompts = all_prompts_text[
                    :: self.num_generations * self.args.sequence_parallel_degree
                    :: self.num_generations * self.args.context_parallel_size
                ]

            with profiling_context(self, "vLLM.generate"):
@@ -352,14 +361,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
            )
        else:
            completion_ids = [None] * (
                len(all_prompts_text) // self.args.sequence_parallel_degree
                len(all_prompts_text) // self.args.context_parallel_size
            )

        # Broadcast the completions from the main process to all processes
        completion_ids = broadcast_object_list(completion_ids, from_process=0)

        # Determine the appropriate slice based on sequence parallelism
        if self.args.sequence_parallel_degree > 1:
        if self.args.context_parallel_size > 1:
            # Calculate SP group ID (which group of ranks this rank belongs to)
            sp_group_id = self.accelerator.process_index // self.local_world_size

@@ -583,7 +592,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
        advantages = advantages / (std_grouped_rewards + 1e-4)

        # Slice to keep only the local part of the data
        if self.args.sequence_parallel_degree > 1:
        if self.args.context_parallel_size > 1:
            # Calculate SP group ID (which group of ranks this rank belongs to)
            sp_group_id = self.accelerator.process_index // self.local_world_size

@@ -5,6 +5,7 @@ import torch
from axolotl.core.trainers.base import AxolotlTrainer


# pylint: disable=too-many-ancestors
class AxolotlMambaTrainer(AxolotlTrainer):
    """Mamba specific trainer to handle loss calculation"""

@@ -5,6 +5,7 @@

from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .distributed_parallel import DistributedParallelMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin
from .rng_state_loader import RngLoaderMixin

@@ -13,9 +13,11 @@ class CheckpointSaveMixin(Trainer):
    def _save_optimizer_and_scheduler(self, output_dir):
        try:
            super()._save_optimizer_and_scheduler(output_dir)
        except NotImplementedError as exc:
            LOG.warning(
        except (NotImplementedError, KeyError) as exc:
            # TODO: fix fsdp2 optimizer saving
            LOG.warning_once(
                f"Trainer does not support saving optimizer and scheduler: {exc}\n"
                "Optimizer and scheduler states were not saved - resuming from checkpoints "
                "for this training run will not be possible."
                "for this training run will not be possible.",
                main_process_only=True,
            )

20 src/axolotl/core/trainers/mixins/distributed_parallel.py Normal file
@@ -0,0 +1,20 @@
"""
Mixin for correctly saving fsdp
"""

from transformers import Trainer


class DistributedParallelMixin(Trainer):
    """
    Mixin for correctly saving fsdp
    """

    def _save(self, output_dir: str | None = None, state_dict=None):
        if (
            state_dict is None
            and self.accelerator.parallelism_config
            and self.accelerator.parallelism_config.dp_shard_enabled
        ):
            state_dict = self.accelerator.get_state_dict(self.model)
        super()._save(output_dir, state_dict=state_dict)
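`DistributedParallelMixin` is meant to sit ahead of the concrete HF/TRL trainer in the MRO so its `_save` runs first, gathers a full state dict when dp-shard (FSDP-style) parallelism is active, then defers to the base `_save`. Schematically (hypothetical subclass, for illustration):

```python
from transformers import Trainer

from axolotl.core.trainers.mixins import DistributedParallelMixin


class MyTrainer(DistributedParallelMixin, Trainer):
    """Hypothetical trainer: inherits the FSDP-aware _save() behavior."""
```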
@@ -7,6 +7,7 @@ from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.logging import get_logger
from axolotl.utils.schedulers import (
    JaggedLRRestartScheduler,
    RexLR,
    get_cosine_schedule_with_min_lr,
    get_cosine_schedule_with_quadratic_warmup,
@@ -113,7 +114,7 @@ class SchedulerMixin(Trainer):
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                )
            else:
                return super().create_scheduler(num_training_steps, optimizer=optimizer)
                super().create_scheduler(num_training_steps, optimizer=optimizer)
        else:
            if use_cosine_quadratic:
                LOG.warning(
@@ -123,4 +124,22 @@ class SchedulerMixin(Trainer):
                LOG.warning(
                    "axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

        if self.args.jagged_restart_steps:
            warmup_steps = (
                self.args.jagged_restart_warmup_steps or 10
            )
            anneal_steps = (
                self.args.jagged_restart_anneal_steps or 1
            )
            if not self.lr_scheduler:
                super().create_scheduler(num_training_steps, optimizer)
            self.lr_scheduler = JaggedLRRestartScheduler(  # pylint: disable=attribute-defined-outside-init
                optimizer,
                self.lr_scheduler,
                self.args.jagged_restart_steps,
                warmup_steps,
                anneal_steps,
                min_lr_scale=self.args.cosine_min_lr_ratio or 0.001,
            )

        return self.lr_scheduler  # type: ignore
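The new branch wraps whatever base scheduler was created in a `JaggedLRRestartScheduler`. In isolation, the wiring looks like this (a sketch mirroring the call in the hunk, with placeholder step counts; it assumes `optimizer` and `base_scheduler` already exist):

```python
from axolotl.utils.schedulers import JaggedLRRestartScheduler

scheduler = JaggedLRRestartScheduler(
    optimizer,
    base_scheduler,   # the scheduler created by the usual code path
    500,              # jagged_restart_steps: reset cadence (placeholder)
    10,               # warmup steps after each reset (the default when unset)
    1,                # anneal steps before each reset (the default when unset)
    min_lr_scale=0.001,
)
```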
@@ -1,46 +0,0 @@
"""Module for ReLoRA trainer"""

import torch
from torch.optim.lr_scheduler import LRScheduler

from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler


class ReLoRATrainer(AxolotlTrainer):
    """Trainer subclass that uses the `OneCycleLR` scheduler"""

    tag_names = ["axolotl", "relora"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr_scheduler = None

    def create_scheduler(
        self,
        num_training_steps: int,
        optimizer: torch.optim.Optimizer | None = None,
    ) -> LRScheduler:
        optimizer = self.optimizer if optimizer is None else optimizer
        lr_scheduler: LRScheduler = super().create_scheduler(
            num_training_steps, optimizer
        )

        if self.args.relora_steps:
            warmup_steps = (
                self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
            )
            anneal_steps = (
                self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
            )
            self.lr_scheduler = ReLoRAScheduler(  # type: ignore
                optimizer,
                lr_scheduler,
                self.args.relora_steps,
                anneal_steps,
                warmup_steps,
            )
        else:
            self.lr_scheduler = lr_scheduler  # type: ignore

        return self.lr_scheduler  # type: ignore
@@ -8,13 +8,18 @@ from trl import (
    RewardTrainer,
)

from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin


class AxolotlORPOTrainer(
    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
    RngLoaderMixin,
    SchedulerMixin,
    OptimizerMixin,
    OptimizerInitMixin,
    DistributedParallelMixin,
    ORPOTrainer,
):
    """
    Extend the base ORPOTrainer for axolotl helpers
@@ -24,7 +29,12 @@ class AxolotlORPOTrainer(


class AxolotlKTOTrainer(
    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer
    RngLoaderMixin,
    SchedulerMixin,
    OptimizerMixin,
    OptimizerInitMixin,
    DistributedParallelMixin,
    KTOTrainer,
):
    """
    Extend the base KTOTrainer for axolotl helpers
@@ -34,7 +44,12 @@ class AxolotlKTOTrainer(


class AxolotlCPOTrainer(
    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer
    RngLoaderMixin,
    SchedulerMixin,
    OptimizerMixin,
    OptimizerInitMixin,
    DistributedParallelMixin,
    CPOTrainer,
):
    """
    Extend the base CPOTrainer for axolotl helpers
@@ -44,7 +59,12 @@ class AxolotlCPOTrainer(


class AxolotlRewardTrainer(
    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer
    RngLoaderMixin,
    SchedulerMixin,
    OptimizerMixin,
    OptimizerInitMixin,
    DistributedParallelMixin,
    RewardTrainer,
):
    """
    Extend the base RewardTrainer for axolotl helpers
@@ -54,7 +74,12 @@ class AxolotlRewardTrainer(


class AxolotlPRMTrainer(
    RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer
    RngLoaderMixin,
    SchedulerMixin,
    OptimizerMixin,
    OptimizerInitMixin,
    DistributedParallelMixin,
    PRMTrainer,
):
    """
    Extend the base trl.PRMTrainer for axolotl helpers

@@ -82,18 +82,26 @@ class AxolotlTrainingMixins:
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_prune_ratio: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||
)
|
||||
jagged_restart_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for jagged restarts"},
|
||||
)
|
||||
jagged_restart_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "how many warmup steps to take after reset for jagged restarts"
|
||||
},
|
||||
)
|
||||
jagged_restart_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "how many anneal steps to take before reset for jagged restarts"
|
||||
},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -41,6 +41,8 @@ plugins:
|
||||
- gemma3n_text
|
||||
- glm
|
||||
- glm4
|
||||
- granite
|
||||
- granitemoe
|
||||
- llama
|
||||
- llama4
|
||||
- llama4_text
|
||||
@@ -56,6 +58,8 @@ plugins:
|
||||
- qwen2_5_vl
|
||||
- qwen3
|
||||
- qwen3_moe
|
||||
- smollm3
|
||||
- voxtral
|
||||
|
||||
## Citation
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
||||
|
||||
|
||||
# pylint: disable=too-many-ancestors
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Custom trainer subclass for Knowledge Distillation (KD)
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
Module for handling LIGER input arguments.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -30,13 +28,13 @@ class LigerArgs(BaseModel):
|
||||
Input args for LIGER.
|
||||
"""
|
||||
|
||||
liger_rope: Optional[bool] = None
|
||||
liger_rms_norm: Optional[bool] = None
|
||||
liger_layer_norm: Optional[bool] = None
|
||||
liger_swiglu: Optional[bool] = None
|
||||
liger_glu_activation: Optional[bool] = None
|
||||
liger_cross_entropy: Optional[bool] = None
|
||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||
liger_rope: bool | None = None
|
||||
liger_rms_norm: bool | None = None
|
||||
liger_layer_norm: bool | None = None
|
||||
liger_swiglu: bool | None = None
|
||||
liger_glu_activation: bool | None = None
|
||||
liger_cross_entropy: bool | None = None
|
||||
liger_fused_linear_cross_entropy: bool | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -66,3 +64,20 @@ class LigerArgs(BaseModel):
|
||||
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_liger_rms_norm_tensor_parallel(cls, data):
|
||||
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1:
|
||||
raise ValueError(
|
||||
"`liger_rms_norm` is incompatible with tensor parallelism, "
|
||||
"see https://github.com/linkedin/Liger-Kernel/issues/826"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
|
||||
# TODO @SalmanMohammadi this is a larger fix - investigate
|
||||
if self.tensor_parallel_size > 1 and self.liger_fused_linear_cross_entropy:
|
||||
raise ValueError("Tensor parallelism is not compatible with liger losses.")
|
||||
return self
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
"""
|
||||
Axolotl custom modeling module
|
||||
"""
|
||||
|
||||
from .args import AxolotlModelingArgs
|
||||
from .plugin import AxolotlModelingPlugin
|
||||
|
||||
__all__ = [
|
||||
"AxolotlModelingArgs",
|
||||
"AxolotlModelingPlugin",
|
||||
]
|
||||
@@ -1,13 +0,0 @@
|
||||
"""
|
||||
Args for using Axolotl custom modeling
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AxolotlModelingArgs(BaseModel):
|
||||
"""
|
||||
Args for using Axolotl custom modeling
|
||||
"""
|
||||
|
||||
use_liger_fused_rms_add: bool = False
|
||||
@@ -1,9 +0,0 @@
|
||||
"""
|
||||
Gemma3 modeling
|
||||
"""
|
||||
|
||||
from .modeling_gemma3 import patch_gemma3
|
||||
|
||||
__all__ = [
|
||||
"patch_gemma3",
|
||||
]
|
||||
@@ -1,110 +0,0 @@
|
||||
"""
|
||||
Gemma3 custom decoder layer using liger fused add rms norm kernels
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm
|
||||
from transformers import Cache, GradientCheckpointingLayer
|
||||
from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
Gemma3Attention,
|
||||
Gemma3MLP,
|
||||
Gemma3RMSNorm,
|
||||
)
|
||||
|
||||
|
||||
class Gemma3AddRMSNorm(LigerFusedAddRMSNorm):
|
||||
"""
|
||||
Fused add rms norm
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
super().__init__(hidden_size, eps, offset=1.0, casting_mode="gemma")
|
||||
|
||||
|
||||
class Gemma3DecoderLayer(GradientCheckpointingLayer):
|
||||
"""
|
||||
Gemma3 decoder layer using liger fused add rms norm
|
||||
"""
|
||||
|
||||
def __init__(self, config: Gemma3TextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = Gemma3MLP(config)
|
||||
self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Gemma3RMSNorm(
|
||||
self.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
self.pre_feedforward_layernorm = Gemma3AddRMSNorm(
|
||||
self.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_feedforward_layernorm = Gemma3RMSNorm(
|
||||
self.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings_global: torch.Tensor,
|
||||
position_embeddings_local: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_value: Cache | None = None,
|
||||
output_attentions: bool | None = False,
|
||||
use_cache: bool | None = False,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[
|
||||
torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor | None] | None
|
||||
]:
|
||||
# pylint: disable=duplicate-code
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# apply global RoPE to non-sliding layer only
|
||||
if self.self_attn.is_sliding:
|
||||
position_embeddings = position_embeddings_local
|
||||
else:
|
||||
position_embeddings = position_embeddings_global
|
||||
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states, residual = self.pre_feedforward_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,) # type: ignore
|
||||
|
||||
return outputs # type: ignore
|
||||
|
||||
|
||||
def patch_gemma3():
|
||||
import transformers.models.gemma3.modeling_gemma3
|
||||
|
||||
transformers.models.gemma3.modeling_gemma3.Gemma3DecoderLayer = Gemma3DecoderLayer
|
||||
sys.modules["transformers.models.gemma3.modeling_gemma3"].Gemma3DecoderLayer = (
|
||||
Gemma3DecoderLayer
|
||||
)
|
||||
@@ -1,9 +0,0 @@
|
||||
"""
|
||||
Llama modeling
|
||||
"""
|
||||
|
||||
from modeling_llama import patch_llama
|
||||
|
||||
__all__ = [
|
||||
"patch_llama",
|
||||
]
|
||||
@@ -1,86 +0,0 @@
"""
Custom modeling for Llama for fused rms add kernels
"""

import sys

import torch
from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm
from transformers import Cache, GradientCheckpointingLayer
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaMLP,
    LlamaRMSNorm,
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs


class LlamaAddRMSNorm(LigerFusedAddRMSNorm):
    """
    Fused add rms norm
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__(hidden_size, eps, casting_mode="llama")


class LlamaDecoderLayer(GradientCheckpointingLayer):
    """
    Llama decoder layer using liger fused add rms norm
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaAddRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_value: Cache | None = None,
        use_cache: bool | None = False,
        cache_position: torch.LongTensor | None = None,
        position_embeddings: (
            tuple[torch.Tensor, torch.Tensor] | None
        ) = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
        # pylint: disable=duplicate-code
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


def patch_llama():
    import transformers.models.llama.modeling_llama

    transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
    sys.modules["transformers.models.llama.modeling_llama"].LlamaDecoderLayer = (
        LlamaDecoderLayer
    )
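
For intuition: `LlamaAddRMSNorm` folds the residual addition into the RMSNorm kernel, which is why `post_attention_layernorm` above takes and returns `(hidden_states, residual)`. A rough eager-mode equivalent of the math (a sketch for intuition only, not the fused Liger Triton kernel):

import torch

def add_rms_norm_reference(hidden_states, residual, weight, eps=1e-6):
    # Fold the residual add into the norm: first add, then RMS-normalize.
    # Returns both the normed activations and the pre-norm sum, which becomes
    # the residual for the next add, matching the (hidden, residual) return
    # convention used by the decoder layers above.
    added = hidden_states + residual
    variance = added.pow(2).mean(-1, keepdim=True)
    normed = added * torch.rsqrt(variance + eps)
    return normed * weight, added
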
@@ -1,22 +0,0 @@
"""
Axolotl custom modeling plugin
"""

from axolotl.integrations.base import BasePlugin


class AxolotlModelingPlugin(BasePlugin):
    """
    Axolotl custom modeling plugin
    """

    def get_input_args(self) -> str | None:
        return "axolotl.integrations.modeling.AxolotlModelingArgs"

    def register(self, cfg):  # pylint: disable=unused-argument
        if cfg.use_liger_fused_rms_add:
            from .gemma3 import patch_gemma3
            from .llama import patch_llama

            patch_gemma3()
            patch_llama()
@@ -13,7 +13,8 @@ import peft
import torch
import transformers
import transformers.modeling_utils
from accelerate import init_empty_weights
from accelerate import PartialState, init_empty_weights
from accelerate.parallelism_config import ParallelismConfig
from peft import (
    PeftConfig,
    PeftMixedModel,
@@ -48,10 +49,7 @@ from axolotl.loaders.utils import (
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
    get_device_count,
    get_device_type,
)
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
@@ -87,6 +85,9 @@ class ModelLoader:
        `AutoModelForCausalLM`).
    """

    use_parallel_config: bool | None = False
    parallelism_config: ParallelismConfig | None = None

    def __init__(
        self,
        cfg: DictDefault,
@@ -183,6 +184,20 @@ class ModelLoader:

    def _apply_pre_model_load_setup(self):
        """Apply patches and setup configurations before model loading."""
        if self.use_parallel_config is not None:
            self.use_parallel_config = (
                self.cfg.fsdp_config
                or (self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1)
                or (
                    self.cfg.context_parallel_size
                    and self.cfg.context_parallel_size > 1
                )
            )
            if self.cfg.fsdp_config and self.cfg.fsdp_version != 2:
                self.use_parallel_config = False

            if self.use_parallel_config:
                self._set_parallel_config()
        self._set_auto_model_loader()
        self._set_device_map_config()
        if self.cfg.revision_of_model:
@@ -390,6 +405,86 @@ class ModelLoader:
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _get_parallel_config_kwargs(
        world_size: int,
        tensor_parallel_size: int = 1,
        context_parallel_size: int = 1,
        dp_shard_size: int | None = None,
        dp_replicate_size: int | None = None,
        is_fsdp: bool = False,
    ):
        pc_kwargs = {}
        remaining_world_size = world_size

        if tensor_parallel_size and tensor_parallel_size > 1:
            pc_kwargs["tp_size"] = tensor_parallel_size
            remaining_world_size = remaining_world_size // tensor_parallel_size

        if context_parallel_size and context_parallel_size > 1:
            pc_kwargs["cp_size"] = context_parallel_size
            remaining_world_size = remaining_world_size // context_parallel_size

        if dp_shard_size is None and dp_replicate_size in (None, 1):
            if remaining_world_size > 1:
                pc_kwargs["dp_shard_size"] = remaining_world_size
                remaining_world_size = 1

        if dp_replicate_size and dp_replicate_size > 1:
            pc_kwargs["dp_replicate_size"] = dp_replicate_size
            remaining_world_size = remaining_world_size // dp_replicate_size

        if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
            if not is_fsdp:
                raise ValueError(
                    "dp_shard_size was configured without a corresponding fsdp_config! "
                    "Please ensure you have configured FSDP using fsdp_config."
                )
            pc_kwargs["dp_shard_size"] = dp_shard_size
            remaining_world_size = remaining_world_size // dp_shard_size
            if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
                pc_kwargs["dp_replicate_size"] = remaining_world_size
                remaining_world_size = 1

        if remaining_world_size > 1:
            if "dp_shard_size" not in pc_kwargs and is_fsdp:
                pc_kwargs["dp_shard_size"] = remaining_world_size
                remaining_world_size = 1

        if remaining_world_size > 1:
            raise ValueError(
                f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
                f"{pc_kwargs}"
            )

        return pc_kwargs

    def _set_parallel_config(self):
        """Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
        pc_kwargs = ModelLoader._get_parallel_config_kwargs(
            get_world_size(),
            self.cfg.tensor_parallel_size,
            self.cfg.context_parallel_size,
            self.cfg.dp_shard_size,
            self.cfg.dp_replicate_size,
            bool(self.cfg.fsdp or self.cfg.fsdp_config),
        )

        if pc_kwargs:
            self.parallelism_config = ParallelismConfig(
                **pc_kwargs,
            )
            device_mesh = self.parallelism_config.build_device_mesh("cuda")
            partial_state = PartialState()
            # fmt: off
            partial_state._shared_state["parallelism_config"] = (  # pylint: disable=protected-access
                self.parallelism_config
            )
            partial_state._shared_state["device_mesh"] = (  # pylint: disable=protected-access
                device_mesh
            )
            # fmt: on

    def _set_auto_model_loader(self):
        """Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
        (set at `__init__`). When using a multimodal model, `self.auto_model_loader`
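
A quick worked example of the decomposition in `_get_parallel_config_kwargs` (values illustrative): with 8 GPUs, `tensor_parallel_size=2` and `context_parallel_size=2`, the leftover factor of 2 becomes the FSDP shard dimension.

# Sketch: trace the kwargs produced for world_size=8, tp=2, cp=2.
kwargs = ModelLoader._get_parallel_config_kwargs(
    world_size=8,
    tensor_parallel_size=2,
    context_parallel_size=2,
    is_fsdp=True,
)
# tp consumes 8 -> 4, cp consumes 4 -> 2, and the remainder is sharded.
assert kwargs == {"tp_size": 2, "cp_size": 2, "dp_shard_size": 2}
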
@@ -622,6 +717,14 @@ class ModelLoader:
    def _build_model(self) -> bool:
        """Load model, with load strategy depending on config."""
        skip_move_to_device = False

        if self.cfg.tensor_parallel_size > 1:
            self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
            self.model_kwargs["tp_plan"] = "auto"
            self.model_kwargs["device_mesh"] = PartialState().device_mesh
            if "device_map" in self.model_kwargs:
                del self.model_kwargs["device_map"]  # not compatible with `tp_plan`

        if self.is_fsdp_enabled:
            if self.cfg.fsdp_config.cpu_ram_efficient_loading:
                skip_move_to_device = True
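
For context, the kwargs set here feed transformers' native tensor-parallel loading path. A minimal sketch of that upstream entry point (model name illustrative; assumes a transformers release with `tp_plan` support, run under torchrun with one process per TP rank):

from transformers import AutoModelForCausalLM

# tp_plan="auto" asks transformers to shard supported modules across the
# tensor-parallel process group instead of placing them via a device_map.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    tp_plan="auto",
)
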
@@ -734,6 +837,14 @@ class ModelLoader:
        if is_deepspeed_zero3_enabled():
            skip_move_to_device = True

        # pylint: disable=protected-access
        if self.cfg.tensor_parallel_size > 1:
            # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh
            # TODO(wing): remove once 4.54.1 is released
            if self.model._tp_size != self.cfg.tensor_parallel_size:
                self.model._tp_size = self.cfg.tensor_parallel_size
                self.model._device_mesh = self.model_kwargs["device_mesh"]

        return skip_move_to_device

    def _set_z3_leaf_modules(self):

@@ -49,6 +49,7 @@ class PatchManager:

    def apply_pre_model_load_patches(self):
        """Apply pre-model load patches based on config."""
        self._apply_transformers_patches()
        # self._apply_flex_attention_patches()
        self._apply_flash_attention_patches()
        self._apply_chunked_cross_entropy_patch()
@@ -64,13 +65,19 @@ class PatchManager:
        self._patch_llama_derived_model()
        self._apply_mistral_cross_entropy_patch()
        self._apply_self_attention_lora_patch()
        self._apply_sequence_parallel_patches()

    def apply_post_plugin_pre_model_load_patches(self):
        """Apply post-plugin pre_model_load patches based on config."""
        self._apply_tiled_mlp(self.cfg.model_config_type)
        self._apply_voxtral_patches()

    def _apply_transformers_patches(self):
        from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
            patch_prepare_from_posids,
        )

        patch_prepare_from_posids()

    def apply_post_model_load_patches(self, model: PreTrainedModel):
        """Apply patches that require the model instance."""
        self._apply_llama_flash_attn_patches(model)
@@ -253,17 +260,6 @@ class PatchManager:
            has_remote_code=has_remote_code,
        )

    def _apply_sequence_parallel_patches(self):
        """Apply sequence parallelism patches."""
        if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
            from axolotl.monkeypatch.ring_attn.patch import (
                patch_prepare_data_loader,
                patch_prepare_device_mesh,
            )

            patch_prepare_data_loader()
            patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)

    def _apply_tiled_mlp(self, model_type: str):
        if self.cfg.tiled_mlp:
            from axolotl.monkeypatch.tiled_mlp import (
@@ -131,6 +131,17 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
            f"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`."
        )

    if (
        cfg.tensor_parallel_size
        and cfg.tensor_parallel_size > 1
        and hasattr(model_config, "tie_word_embeddings")
        and model_config.tie_word_embeddings
    ):
        raise ValueError(
            "Tensor parallelism is incompatible with models configured with `tie_word_embeddings` enabled. "
            "Please use a model without `tie_word_embeddings`, or disable tensor parallelism."
        )


def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
    """Loads and configures a model configuration from HuggingFace or local sources.
@@ -249,13 +249,19 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
        auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
    )

    mesh = getattr(accelerator.state, "device_mesh", None)

    fsdp2_kwargs = {
        "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
        "offload_policy": fsdp2_plugin.cpu_offload,
        # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
        "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
        "mesh": (
            mesh[tuple(accelerator.state.parallelism_config.fsdp_dim_names)]
            if mesh is not None
            else None
        ),
    }

    model_has_params4bit = False
    for _, param in model.named_parameters():
        # this is a temporary fix whereby loading models with bnb params cannot be moved from
@@ -36,6 +36,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "glm",
    "glm4",
    "smollm3",
    "granite",
    "granitemoe",
]

@@ -6,7 +6,7 @@ import os.path
import shutil
from functools import partial
from pathlib import Path
from typing import Dict, List, Sequence, Union
from typing import Dict, List, Union

import bitsandbytes as bnb
import peft
@@ -14,8 +14,6 @@ import safetensors.torch as st
import torch
from huggingface_hub import snapshot_download
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
    TrainerCallback,
    TrainerControl,
@@ -84,7 +82,7 @@ class ReLoRACallback(TrainerCallback):
    """Callback to merge LoRA weights into the base model and save full-weight checkpoints"""

    def __init__(self, cfg: DictDefault):
        self.relora_steps = cfg.relora_steps
        self.relora_steps = cfg.jagged_restart_steps
        self.cpu_offload = cfg.relora_cpu_offload
        self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
        self.last_full_model = cfg.base_model
@@ -255,51 +253,6 @@ class ReLoRACallback(TrainerCallback):
        return control


class ReLoRAScheduler(LRScheduler):
    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""

    def __init__(
        self,
        optimizer: Optimizer,
        inner_schedule: LRScheduler,
        relora_steps: int,
        warmup_steps: int,
        anneal_steps: int = 1,
        min_lr_scale: float = 0.001,
    ) -> None:
        self.inner_schedule = inner_schedule
        self.relora_steps = relora_steps
        self.warmup_steps = warmup_steps
        self.anneal_steps = anneal_steps
        self.min_lr_scale = min_lr_scale
        super().__init__(optimizer, inner_schedule.last_epoch)

    def get_lr(self) -> float:
        self.inner_schedule.last_epoch = self.last_epoch

        original = self.inner_schedule.get_lr()
        step = self.last_epoch

        if step < self.relora_steps - self.warmup_steps:
            scale = 1
        else:
            per_relora_progress = step % self.relora_steps
            if per_relora_progress < self.warmup_steps:
                cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
            elif per_relora_progress > (self.relora_steps - self.anneal_steps):
                cycle_t = min(
                    1.0,
                    (self.relora_steps - per_relora_progress) / self.anneal_steps,
                )
            else:
                cycle_t = 1
            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale

        if isinstance(original, Sequence):
            return [lr * scale for lr in original]
        return original * scale


def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
    model_name = "model.safetensors"
    if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
@@ -5,18 +5,14 @@

from .patch import (
    get_ring_attn_group,
    patch_prepare_data_loader,
    patch_prepare_device_mesh,
    register_ring_attn,
    register_ring_attn_from_device_mesh,
    set_ring_attn_group,
    update_ring_attn_params,
)

__all__ = (
    "get_ring_attn_group",
    "patch_prepare_data_loader",
    "patch_prepare_device_mesh",
    "register_ring_attn",
    "register_ring_attn_from_device_mesh",
    "set_ring_attn_group",
    "update_ring_attn_params",
)
@@ -8,13 +8,12 @@ We also provide some patches for accelerate functions to prepare the dataloader
sequence parallelism training.
"""

import inspect
import os
from typing import Optional

import accelerate
import torch
import torch.distributed as dist
from torch.distributed import DeviceMesh

try:
    from transformers.modeling_flash_attention_utils import _flash_supports_window
@@ -29,39 +28,13 @@ from axolotl.utils.schemas.enums import RingAttnFunc

LOG = get_logger(__name__)


RING_ATTN_GROUP = None

ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
    submesh_dp_size = 1
    submesh_tp_size = 1
    if "tp" in torch_device_mesh.mesh_dim_names:
        submesh_tp_size = torch_device_mesh["tp"].size()
    if "dp" in torch_device_mesh.mesh_dim_names:
        submesh_dp_size = torch_device_mesh["dp"].size()
    if "fsdp" in torch_device_mesh.mesh_dim_names:
        submesh_fsdp_size = torch_device_mesh["fsdp"].size()
    process_index = process_index // submesh_tp_size"""

NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
    submesh_dp_size = 1
    submesh_tp_size = 1
    submesh_cp_size = 1
    if "cp" in torch_device_mesh.mesh_dim_names:
        submesh_cp_size = torch_device_mesh["cp"].size()
    if "tp" in torch_device_mesh.mesh_dim_names:
        submesh_tp_size = torch_device_mesh["tp"].size()
    if "dp" in torch_device_mesh.mesh_dim_names:
        submesh_dp_size = torch_device_mesh["dp"].size()
    if "fsdp" in torch_device_mesh.mesh_dim_names:
        submesh_fsdp_size = torch_device_mesh["fsdp"].size()
    process_index = process_index // (submesh_tp_size * submesh_cp_size)"""


def get_ring_attn_group() -> dist.ProcessGroup:
    """Getter for ring attention group on this rank."""
    if RING_ATTN_GROUP is None:
        raise RuntimeError("register_ring_attn() not yet called")
        raise RuntimeError("register_ring_attn_from_device_mesh() not yet called")
    return RING_ATTN_GROUP
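
The only behavioral change in the patched snippet is the divisor: ranks that share a TP and/or CP group must read the same dataloader shard, so the dataloader process index collapses each group to one index. Illustrative arithmetic (no distributed setup needed):

# With world_size=8, tp=1, cp=2: rank pairs {0,1}, {2,3}, {4,5}, {6,7} each
# hold halves of the same sequences, so each pair maps to one dataloader shard.
submesh_tp_size, submesh_cp_size = 1, 2
shards = [rank // (submesh_tp_size * submesh_cp_size) for rank in range(8)]
print(shards)  # [0, 0, 1, 1, 2, 2, 3, 3]
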
@@ -161,15 +134,17 @@ def create_ring_flash_attention_forward(
    ]


def register_ring_attn(
    sequence_parallel_degree: int,
def register_ring_attn_from_device_mesh(
    device_mesh: "DeviceMesh",
    context_parallel_dim: tuple[str, ...],
    heads_k_stride: int | None,
    ring_attn_func: RingAttnFunc | None,
):
    """Create ring attention group and substitute flash attn with ring flash attn.
    """Create ring attention group using DeviceMesh and substitute flash attn with ring flash attn.

    Args:
        sequence_parallel_degree: Sequence parallelism factor.
        device_mesh: DeviceMesh object containing the parallelism topology.
        context_parallel_dim: Name of the sequence parallel dimension in the device mesh.
        heads_k_stride: Sequence parallelism K head stride size. Passed through to
            `varlen_llama3` `ring_flash_attn` implementation.
        ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
@@ -177,44 +152,39 @@ def register_ring_attn(
            `batch` function.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    LOG.info(
        f"Enabling ring attention sequence parallelism using DeviceMesh "
        f"dimension '{context_parallel_dim}'",
        main_process_only=True,
    )

    # Extract the sequence parallel submesh
    try:
        sequence_mesh = device_mesh[context_parallel_dim]
    except (KeyError, IndexError) as e:
        raise ValueError(
            f"Dimension '{context_parallel_dim}' not found in device_mesh. "
            f"Available dimensions: {device_mesh.mesh_dim_names}"
        ) from e

    # Get the process group for context parallelism
    sequence_pg = sequence_mesh.get_group()
    context_parallel_size = sequence_mesh.size()

    if rank == 0:
        LOG.info(
            "Enabling ring attention sequence parallelism: "
            f"each sequence will be processed across {sequence_parallel_degree} GPUs"
            f"Sequence parallel degree: {context_parallel_size}, "
            f"mesh shape: {sequence_mesh.mesh.shape}"
        )

    assert sequence_parallel_degree <= world_size, (
        f"sequence_parallel_degree ({sequence_parallel_degree}) "
        f"must be less than or equal to world_size ({world_size})"
    )
    assert world_size % sequence_parallel_degree == 0, (
        f"sequence_parallel_degree ({sequence_parallel_degree}) "
        f"must evenly divide world_size ({world_size})"
    )
    # Log which ranks are in the current process group
    if sequence_pg != dist.GroupMember.WORLD:
        ranks_in_group = dist.get_process_group_ranks(sequence_pg)
        LOG.info(f"Current sequence parallel group ranks: {ranks_in_group}")

    # Assign ranks to sequence parallel groups
    group_assignments = {}
    for i in range(world_size // sequence_parallel_degree):
        ring_attn_ranks = list(
            range(
                i * sequence_parallel_degree,
                (i + 1) * sequence_parallel_degree,
            )
        )
        group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")

        # Track which GPUs are in which groups
        for r in ring_attn_ranks:
            group_assignments[r] = i

        if rank in ring_attn_ranks:
            set_ring_attn_group(group)

    # Log the GPU group assignments
    if rank == 0:
        LOG.info(f"Sequence parallel group assignments: {group_assignments}")
    # Set the ring attention group
    set_ring_attn_group(sequence_pg)

    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
        # fmt: off
@@ -257,92 +227,3 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
    cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
    cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
    update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())


def patch_prepare_data_loader():
    """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.

    Raises:
        RuntimeError: If source code to patch does not exist.
    """
    original_fn = accelerate.data_loader.prepare_data_loader
    original_source = inspect.getsource(original_fn)

    if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source:
        raise RuntimeError(
            "SP patch failed - target snippet not found. "
            "Check accelerate's version or update the patch."
        )

    patched_source = original_source.replace(
        ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
    )

    items_to_import = []
    for item in dir(accelerate.data_loader):
        if item in patched_source:
            items_to_import.append(item)

    # Create a new function from the patched source
    namespace = {}
    exec(  # pylint: disable=exec-used  # nosec B102
        f"from accelerate.data_loader import ({', '.join(items_to_import)})",
        globals(),
    )
    exec(  # pylint: disable=exec-used  # nosec B102
        patched_source, globals(), namespace
    )

    patched_function = namespace["prepare_data_loader"]
    original_fn.__code__ = patched_function.__code__

    LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")


def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
    """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
    that includes sequence parallelism with the specified degree.

    Args:
        sequence_parallel_degree: The degree of sequence parallelism to use.
        fsdp: Whether to use FSDP.
    """

    def _prepare_device_mesh(self):
        """Prepare the device mesh for distributed training. The dataloader will
        determine how to load data based on the device mesh.
        """
        if self.state.torch_tp_plugin:
            return self.state.torch_tp_plugin.torch_device_mesh
        if (
            self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED
            and hasattr(self.state, "ds_device_mesh")
        ):
            return self.state.ds_device_mesh

        # Create device mesh with sequence parallelism
        world_size = dist.get_world_size()
        mesh_shape = (
            world_size // sequence_parallel_degree,
            sequence_parallel_degree,
        )
        device_ids = list(range(world_size))

        # NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
        # parallelism" implementation naming.
        # NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
        # only use "fsdp" and "cp" for the device mesh.
        return dist.DeviceMesh(
            "cuda",
            torch.tensor(device_ids).reshape(mesh_shape),
            mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
        )

    # Replace the original method with our new method
    # pylint: disable=protected-access
    accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh

    LOG.info(
        "Successfully patched Accelerator._prepare_device_mesh "
        f"with sequence_parallel_degree={sequence_parallel_degree}"
    )
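
For reference, the rank layout the removed `_prepare_device_mesh` produced can be visualized without any process group (a sketch; sizes illustrative):

import torch

world_size, sequence_parallel_degree = 8, 2
mesh = torch.arange(world_size).reshape(
    world_size // sequence_parallel_degree, sequence_parallel_degree
)
print(mesh)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7]])
# rows are the "dp" (or "fsdp") dimension, columns the "cp" dimension
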
0
src/axolotl/monkeypatch/transformers/__init__.py
Normal file
@@ -0,0 +1,87 @@
"""
Monkey patch to fix transformers.modeling_flash_attention_utils.

see https://github.com/huggingface/transformers/pull/39653/files
"""

import sys

import torch


def _prepare_from_posids(query, key, value, position_ids):
    """
    This function returns necessary arguments to call `flash_attn_varlen_func`.
    All three query, key, value states will be flattened.
    Cumulative lengths of each example in the batch will be extracted from position_ids.
    NOTE: ideally cumulative lengths should be prepared at the data collator stage
    Arguments:
        query (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        position_ids (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
    Return:
        query (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
    value = value.contiguous().view(-1, value.size(-2), value.size(-1))

    position_ids = position_ids.flatten()
    indices_q = torch.arange(
        position_ids.size(0), device=position_ids.device, dtype=torch.int32
    )

    cu_seq_lens = torch.cat(
        (
            indices_q[position_ids == 0],
            torch.tensor(
                position_ids.size(), device=position_ids.device, dtype=torch.int32
            ),
        )
    )
    # NOTE: With torch compile, this will cause a graph break if you don't set
    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
    # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
    # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
    # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
    # for some models (e.g. qwen2-vl).
    max_length = cu_seq_lens.diff().max().item()
    return (
        query,
        key,
        value,
        indices_q,
        (cu_seq_lens, cu_seq_lens),
        (max_length, max_length),
    )


def patch_prepare_from_posids():
    import transformers.modeling_flash_attention_utils

    transformers.modeling_flash_attention_utils._prepare_from_posids = (  # pylint: disable=protected-access
        _prepare_from_posids
    )
    setattr(
        sys.modules["transformers.modeling_flash_attention_utils"],
        "_prepare_from_posids",
        _prepare_from_posids,
    )
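
The core trick in `_prepare_from_posids` is that packed-sequence boundaries are exactly the positions where `position_ids` resets to 0. A self-contained sketch of that step:

import torch

# Packed batch of two sequences with lengths 3 and 2.
position_ids = torch.tensor([[0, 1, 2, 0, 1]]).flatten()
indices = torch.arange(position_ids.size(0), dtype=torch.int32)
cu_seq_lens = torch.cat(
    (indices[position_ids == 0], torch.tensor(position_ids.size(), dtype=torch.int32))
)
print(cu_seq_lens)               # tensor([0, 3, 5], dtype=torch.int32)
print(cu_seq_lens.diff().max())  # tensor(3): the max seqlen in the packed batch
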
@@ -205,7 +205,7 @@ def execute_training(
        )
    )

    if cfg.sequence_parallel_degree > 1:
    if cfg.context_parallel_size > 1:
        models = [trainer.model]
        if hasattr(trainer, "ref_model") and trainer.ref_model:
            models.append(trainer.ref_model)
@@ -213,7 +213,7 @@ def execute_training(
        stack.enter_context(
            SequenceParallelContextManager(
                models=models,
                sequence_parallel_degree=cfg.sequence_parallel_degree,
                context_parallel_size=cfg.context_parallel_size,
                gradient_accumulation_steps=cfg.gradient_accumulation_steps,
                ring_attn_func=cfg.ring_attn_func,
                heads_k_stride=cfg.heads_k_stride,
@@ -267,7 +267,7 @@ def save_trained_model(
        "your model weights with `axolotl quantize`."
    )
    # Handle ReLoRA early return case
    if cfg.relora_steps:
    if cfg.relora:
        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
            model = model.merge_and_unload()
        else:
@@ -57,10 +57,10 @@ def gpu_memory_usage(device=0):

@check_cuda_device((0.0, 0.0, 0.0))
def gpu_memory_usage_all(device=0):
    usage = torch.cuda.memory_allocated(device) / 1024.0**3
    reserved = torch.cuda.memory_reserved(device) / 1024.0**3
    smi = gpu_memory_usage_smi(device)
    return usage, reserved - usage, max(0, smi - reserved)
    active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
    allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
    reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
    return active, allocated, reserved


def mps_memory_usage_all():
@@ -92,27 +92,38 @@ def gpu_memory_usage_smi(device=0):
    return 0.0


def log_gpu_memory_usage(
    log: logging.Logger | logging.LoggerAdapter,
    msg: str = "",
    device: int | torch.device = 0,
):
def get_gpu_memory_usage(device: int | torch.device = 0):
    cur_device_type = str(get_device_type())
    if torch.backends.mps.is_available():
        usage, cache, misc = mps_memory_usage_all()
    elif "npu" in cur_device_type and is_torch_npu_available():
        usage, cache, misc = npu_memory_usage_all(device)
    elif "gpu" in cur_device_type and torch.cuda.is_available():
    elif "cuda" in cur_device_type and torch.cuda.is_available():
        usage, cache, misc = gpu_memory_usage_all(device)
    else:
        return 0.0, 0.0, 0.0

    return usage, cache, misc


def log_gpu_memory_usage(
    log: logging.Logger | logging.LoggerAdapter,
    msg: str = "",
    device: int | torch.device = 0,
):
    try:
        active, allocated, reserved = get_gpu_memory_usage(device)
    except ValueError:
        # likely CPU, ignore
        return
    cur_device_type = str(get_device_type())
    extras = []
    if cache > 0:
        extras.append(f"+{cache:.03f}GB cache")
    if misc > 0:
        extras.append(f"+{misc:.03f}GB misc")
    msg = f"{cur_device_type} memory usage:" if not msg else msg
    log.info(
        f"{msg} {usage:.03f}GB ({', '.join(extras)})",
    if allocated > 0:
        extras.append(f"+{allocated:.03f}GB allocated")
    if reserved > 0:
        extras.append(f"+{reserved:.03f}GB reserved")
    msg = f"{cur_device_type} memory active:" if not msg else msg
    log.debug(
        f"{msg} {active:.03f}GB ({', '.join(extras)})",
        stacklevel=2,
    )
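
Note the semantics change in `gpu_memory_usage_all`: the old version reported instantaneous allocated/cache/untracked memory, while the new one reports peak statistics, for which active <= allocated <= reserved should hold. A small sketch of reading those three peaks (CUDA only; tensor size illustrative):

import torch

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    x = torch.empty(1024, 1024, device="cuda")  # ~4 MiB of fp32
    active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
    allocated = torch.cuda.max_memory_allocated() / 1024.0**3
    reserved = torch.cuda.max_memory_reserved() / 1024.0**3
    print(f"{active:.4f} <= {allocated:.4f} <= {reserved:.4f} GiB")
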
@@ -35,7 +35,6 @@ from transformers.trainer_utils import (
from trl.models import unwrap_model_for_generation

from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.distributed import (
    barrier,
@@ -93,28 +92,6 @@ class SaveBetterTransformerModelCallback(
        return control


class GPUStatsCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods disable=unused-argument
    """Callback to track GPU utilization"""

    def __init__(self, cfg):
        self.cfg = cfg
        self.logged = False

    def on_step_end(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        if not self.logged and state.global_step > 1:
            log_gpu_memory_usage(LOG, "while training", self.cfg.device)
            self.logged = True
        return control


class LossWatchDogCallback(TrainerCallback):
    """Callback to track loss and stop training if loss is too high"""
62
src/axolotl/utils/chat_templates/templates/granite.jinja
Normal file
@@ -0,0 +1,62 @@
{# Alias tools -> available_tools #}
{%- if tools and not available_tools -%}
{%- set available_tools = tools -%}
{%- endif -%}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = "Knowledge Cutoff Date: April 2024.
Today's Date: " + strftime_now('%B %d, %Y') + ".
You are Granite, developed by IBM." %}
{%- if available_tools and documents %}
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.
Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif available_tools %}
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
{%- elif documents %}
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif thinking %}
{%- set system_message = system_message + " You are a helpful AI assistant.
Respond to every user query in a comprehensive and detailed way. You can write down your thoughts and reasoning process before responding. In the thought process, engage in a comprehensive cycle of analysis, summarization, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. In the response section, based on various attempts, explorations, and reflections from the thoughts section, systematically present the final solution that you deem correct. The response should summarize the thought process. Write your thoughts between <think></think> and write your response between <response></response> for each user query." %}
{%- else %}
{%- set system_message = system_message + " You are a helpful AI assistant." %}
{%- endif %}
{%- if 'citations' in controls and documents %}
{%- set system_message = system_message + '
Use the symbols <|start_of_cite|> and <|end_of_cite|> to indicate when a fact comes from a document in the search result, e.g <|start_of_cite|> {document_id: 1}my fact <|end_of_cite|> for a fact from document 1. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
{%- endif %}
{%- if 'hallucinations' in controls and documents %}
{%- set system_message = system_message + '
Finally, after the response is written, include a numbered list of sentences from the response with a corresponding risk value that are hallucinated and not based in the documents.' %}
{%- endif %}
{%- set loop_messages = messages %}
{%- endif %}
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
' }}
{%- if available_tools %}
{{- '<|start_of_role|>available_tools<|end_of_role|>' }}
{{- available_tools | tojson(indent=4) }}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- if documents %}
{%- for document in documents %}
{{- '<|start_of_role|>document {"document_id": "' + document['doc_id'] | string + '"}<|end_of_role|>
' }}
{{- document['text'] }}
{{- '<|end_of_text|>
' }}
{%- endfor %}
{%- endif %}
{%- for message in loop_messages %}
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- if loop.last and add_generation_prompt %}
{{- '<|start_of_role|>assistant' }}
{%- if controls %}
{{- ' ' + controls | tojson()}}
{%- endif %}
{{- '<|end_of_role|>' }}
{%- endif %}
{%- endfor %}
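
A quick way to exercise the new template outside of training (a sketch; the tokenizer repo and file path are illustrative, and rendering relies on the transformers chat-template environment exposing `strftime_now`):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-instruct")
with open(
    "src/axolotl/utils/chat_templates/templates/granite.jinja", encoding="utf-8"
) as fin:
    tokenizer.chat_template = fin.read()

messages = [{"role": "user", "content": "What is the capital of France?"}]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
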
64
src/axolotl/utils/chat_templates/templates/granitemoe.jinja
Normal file
@@ -0,0 +1,64 @@
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = "Knowledge Cutoff Date: April 2024.
Today's Date: " + strftime_now('%B %d, %Y') + ".
You are Granite, developed by IBM." %}
{%- if tools and documents %}
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.

Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- elif tools %}
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
{%- elif documents %}
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
{%- else %}
{%- set system_message = system_message + " You are a helpful AI assistant." %}
{%- endif %}
{%- if 'citations' in controls and documents %}
{%- set system_message = system_message + '

In your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
{%- endif %}
{%- if 'hallucinations' in controls and documents %}
{%- set system_message = system_message + '

Finally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents.' %}
{%- endif %}
{%- set loop_messages = messages %}
{%- endif %}
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
' }}
{%- if tools %}
{{- '<|start_of_role|>tools<|end_of_role|>' }}
{{- tools | tojson(indent=4) }}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- if documents %}
{{- '<|start_of_role|>documents<|end_of_role|>' }}
{%- for document in documents %}
{{- 'Document ' + loop.index0 | string + '
' }}
{{- document['text'] }}
{%- if not loop.last %}
{{- '

'}}
{%- endif%}
{%- endfor %}
{{- '<|end_of_text|>
' }}
{%- endif %}
{%- for message in loop_messages %}
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
' }}
{%- if loop.last and add_generation_prompt %}
{{- '<|start_of_role|>assistant' }}
{%- if controls %}
{{- ' ' + controls | tojson()}}
{%- endif %}
{{- '<|end_of_role|>' }}
{%- endif %}
{%- endfor %}
@@ -5,6 +5,7 @@ import inspect

import torch
import torch.distributed as dist
from accelerate import PartialState
from torch import nn
from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -12,7 +13,7 @@ from transformers.utils import ModelOutput

from axolotl.monkeypatch.ring_attn import (
    get_ring_attn_group,
    register_ring_attn,
    register_ring_attn_from_device_mesh,
    update_ring_attn_params,
)
from axolotl.utils.schemas.enums import RingAttnFunc
@@ -150,9 +151,18 @@ def apply_sequence_parallelism(
    if "num_items_in_batch" in batch:
        # Approximation; this is needed since num_items_in_batch may be counted across
        # all samples in a gradient accumulated batch, not on a per-step basis.
        local_valid_tokens = (batch["labels"] != -100).sum()

        # All-reduce across sequence parallel ranks to get global token count
        cp_group = get_ring_attn_group()
        global_valid_tokens = local_valid_tokens.clone()
        # we use AVG instead of SUM as using sum seems to scale down the loss by over-accounting the number of tokens
        dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.AVG, group=cp_group)
        global_valid_tokens = int(global_valid_tokens.item())

        batch["num_items_in_batch"] = (
            batch["labels"] != -100
        ).sum() * gradient_accumulation_steps
            global_valid_tokens * gradient_accumulation_steps
        )

    return batch, original_seq_len, pad_len
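
Why AVG and not SUM: the reduced count is multiplied by `gradient_accumulation_steps` and used as the loss-normalization denominator on every CP rank, so summing would inflate that denominator by the CP degree and shrink the loss, as the inline comment notes. Illustrative numbers (no process group needed):

# Two CP ranks hold slices of the same sequences with 100 and 102 supervised
# tokens. ReduceOp.AVG yields 101 on each rank; SUM would yield 202 and
# over-account tokens in the denominator.
local_counts = [100, 102]
avg = sum(local_counts) / len(local_counts)  # 101.0, what ReduceOp.AVG computes
num_items_in_batch = int(avg) * 4            # with gradient_accumulation_steps=4
print(num_items_in_batch)                    # 404
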
@@ -167,7 +177,7 @@ class SequenceParallelContextManager:
    Args:
        models: List of models to apply sequence parallelism to pre- and post- forward
            hooks.
        sequence_parallel_degree: Number of processes to split sequences over.
        context_parallel_size: Number of processes to split sequences over.
        gradient_accumulation_steps: Number of steps to accumulate gradients over.
        ring_attn_func: Which ring attention function to use. Currently unused.
        heads_k_stride: Sequence parallelism K head stride size. Passed through to
@@ -179,14 +189,14 @@ class SequenceParallelContextManager:
    def __init__(
        self,
        models: list[nn.Module],
        sequence_parallel_degree: int,
        context_parallel_size: int,
        gradient_accumulation_steps: int,
        ring_attn_func: RingAttnFunc,
        heads_k_stride: int | None,
        gather_outputs: bool,
    ):
        self.models = models
        self.sequence_parallel_degree = sequence_parallel_degree
        self.context_parallel_size = context_parallel_size
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.ring_attn_func = ring_attn_func
        self.heads_k_stride = heads_k_stride
@@ -230,8 +240,10 @@ class SequenceParallelContextManager:

    def _register_ring_attn(self):
        # Initialize ring attn for sequence parallelism
        register_ring_attn(
            sequence_parallel_degree=self.sequence_parallel_degree,
        partial_state = PartialState()
        register_ring_attn_from_device_mesh(
            device_mesh=partial_state.device_mesh,
            context_parallel_dim=("cp",),
            heads_k_stride=self.heads_k_stride,
            ring_attn_func=self.ring_attn_func,
        )
@@ -430,10 +430,11 @@ def save_preprocessed_dataset(
            num_shards=cfg.num_dataset_shards_to_save,
        )
    else:
        min_rows_per_proc = 256
        os.makedirs(prepared_ds_path, exist_ok=True)
        dataset.save_to_disk(
            str(prepared_ds_path),
            num_proc=min(max(1, len(dataset) // 8), num_workers),
            num_proc=min(max(1, len(dataset) // min_rows_per_proc), num_workers),
            max_shard_size=None,
            num_shards=cfg.num_dataset_shards_to_save,
        )
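
The effect of raising the floor from 8 to 256 rows per save worker, worked through (numbers illustrative): a 1,000-row dataset previously saturated all workers (1000 // 8 = 125), whereas now it spawns only enough processes to keep at least 256 rows each.

min_rows_per_proc = 256
num_workers = 32
for n_rows in (1_000, 10_000, 100_000):
    num_proc = min(max(1, n_rows // min_rows_per_proc), num_workers)
    print(n_rows, "->", num_proc)  # 1000 -> 3, 10000 -> 32, 100000 -> 32
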
@@ -2,12 +2,15 @@
utils to get GPU info for the current environment
"""

from importlib.metadata import version

from accelerate.utils.environment import (
    check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
)
from accelerate.utils.environment import (
    get_gpu_info,
)
from packaging.version import Version, parse


def check_cuda_p2p_ib_support():
@@ -26,3 +29,13 @@ def check_cuda_p2p_ib_support():
    except Exception:  # pylint: disable=broad-except  # nosec
        pass
    return True


def get_package_version(package: str) -> Version:
    version_str = version(package)
    return parse(version_str)


def is_package_version_ge(package: str, version_: str) -> bool:
    package_version = get_package_version(package)
    return package_version >= parse(version_)
@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.

import gc
import math
import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
from typing import Iterable, Iterator, Union
@@ -453,7 +454,10 @@ class MultipackBatchSampler(BatchSampler):
        _sampled_lens = []
        for _ in range(self.num_count_samples):
            self._batches = None  # Reset cached batches
            # log timer for generating batches
            start_time = time.time()
            _sampled_lens.append(len(self.generate_batches(set_stats=False)))
            LOG.debug(f"generate_batches time: {time.time() - start_time}")
        len_batches = min(_sampled_lens)

        # Gather minimum across all ranks
@@ -2,6 +2,7 @@

import math
from functools import partial
from typing import Sequence

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -292,3 +293,50 @@ def get_cosine_schedule_with_warmup_decay_constant(
        num_cycles=num_cycles,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)


class JaggedLRRestartScheduler(LRScheduler):
    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""

    def __init__(
        self,
        optimizer: Optimizer,
        inner_schedule: LRScheduler,
        jagged_restart_steps: int,
        jagged_restart_warmup_steps: int,
        jagged_restart_anneal_steps: int = 1,
        min_lr_scale: float = 0.001,
    ) -> None:
        # pylint: disable=duplicate-code
        self.inner_schedule = inner_schedule
        self.restarts_steps = jagged_restart_steps
        self.warmup_steps = jagged_restart_warmup_steps
        self.anneal_steps = jagged_restart_anneal_steps
        self.min_lr_scale = min_lr_scale
        super().__init__(optimizer, inner_schedule.last_epoch)

    def get_lr(self) -> float | Sequence[float]:
        self.inner_schedule.last_epoch = self.last_epoch

        original = self.inner_schedule.get_lr()
        step = self.last_epoch

        if step < self.restarts_steps - self.anneal_steps:
            scale = 1
        else:
            per_restart_progress = step % self.restarts_steps
            if per_restart_progress < self.warmup_steps:
                cycle_t = min(1.0, (per_restart_progress) / self.warmup_steps)
            elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
                cycle_t = min(
                    1.0,
                    (self.restarts_steps - per_restart_progress) / self.anneal_steps,
                )
            else:
                cycle_t = 1
            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale

        if isinstance(original, Sequence):
            return [lr * scale for lr in original]

        return original * scale
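
A usage sketch for the new scheduler (all values illustrative): wrapping a constant inner schedule makes the jagged pattern visible, with the LR dropping to roughly `min_lr_scale` at each restart boundary, warming back up over the warmup steps, then annealing down at the end of each cycle.

import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=1.0)
inner = LambdaLR(opt, lambda step: 1.0)  # constant base schedule
sched = JaggedLRRestartScheduler(
    opt,
    inner,
    jagged_restart_steps=10,
    jagged_restart_warmup_steps=2,
    jagged_restart_anneal_steps=2,
)
for _ in range(20):
    opt.step()
    sched.step()
    print(sched.get_last_lr())  # dips toward ~0.001 at steps 10 and 20
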
@@ -43,7 +43,7 @@ from axolotl.utils.schemas.model import (
from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.training import HyperparametersConfig, JaggedLRConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.validation import ValidationMixin
from axolotl.utils.schemas.vllm import VllmConfig
@@ -57,6 +57,7 @@ class AxolotlInputConfig(
    ModelOutputConfig,
    LoraConfig,
    ReLoRAConfig,
    JaggedLRConfig,
    HyperparametersConfig,
    WandbConfig,
    MLFlowConfig,
@@ -650,7 +651,23 @@
        },
    )

    dp_shard_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of devices to shard across. If not set, will use all available devices."
        },
    )
    dp_replicate_size: int | None = Field(
        default=None,
        json_schema_extra={"description": "Number of devices to replicate across."},
    )
    sequence_parallel_degree: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Deprecated: use `context_parallel_size` instead"
        },
    )
    context_parallel_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."
@@ -67,6 +67,8 @@ class ChatTemplate(str, Enum):
    command_a_tool_use = "command_a_tool_use"
    command_a_rag = "command_a_rag"
    aya = "aya"
    granite = "granite"
    granitemoe = "granitemoe"


class CustomSupportedOptimizers(str, Enum):
@@ -187,18 +187,10 @@ class LoraConfig(BaseModel):
class ReLoRAConfig(BaseModel):
    """ReLoRA configuration subset"""

    relora_steps: int | None = Field(
        default=None,
        json_schema_extra={"description": "Number of steps per ReLoRA restart"},
    )
    relora_warmup_steps: int | None = Field(
        default=None,
        json_schema_extra={"description": "Number of per-restart warmup steps"},
    )
    relora_anneal_steps: int | None = Field(
    relora: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of anneal steps for each relora cycle"
            "description": "Whether to use ReLoRA. Use with jagged_restart_*steps options."
        },
    )
    relora_prune_ratio: float | None = Field(
@@ -160,3 +160,24 @@ class HyperparametersConfig(BaseModel):
        if learning_rate and isinstance(learning_rate, str):
            learning_rate = float(learning_rate)
        return learning_rate


class JaggedLRConfig(BaseModel):
    """JaggedLR configuration subset, can be used w/ ReLoRA training"""

    jagged_restart_steps: int | None = Field(
        default=None,
        json_schema_extra={"description": "how often to reset for jagged restarts"},
    )
    jagged_restart_warmup_steps: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "how many warmup steps to take after reset for jagged restarts"
        },
    )
    jagged_restart_anneal_steps: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "how many anneal steps to take before reset for jagged restarts"
        },
    )
@@ -80,6 +80,14 @@ class TRLConfig(BaseModel):
             "description": "Number of completions to print when log_completions is True."
         },
     )
+    importance_sampling_level: Literal["sequence", "token"] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Controls whether importance sampling ratios are computed at the `'token'` or `'sequence'` level. "
+            "For GSPO, use `sequence`, default is None which corresponds to the original GRPO paper."
+        },
+    )
+
     sync_ref_model: bool | None = Field(
         default=False,
         json_schema_extra={"description": "Whether to sync the reference model."},
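Illustrative only (an assumption about what the two levels mean, consistent with the description above): "token" keeps one importance ratio per token, while "sequence" collapses the token log-ratios into a single per-completion ratio, as GSPO does.

import torch

logp_new = torch.randn(4, 16)  # per-token log-probs under the current policy
logp_old = torch.randn(4, 16)  # per-token log-probs under the sampling policy
token_ratios = (logp_new - logp_old).exp()         # shape (batch, seq_len)
seq_ratios = (logp_new - logp_old).mean(-1).exp()  # one ratio per sequence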
@@ -644,6 +644,19 @@ class LoRAValidationMixin:
             )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_lora_dropout_parameters(cls, data):
+        if (
+            data.get("lora_dropout", 0.0)
+            and data.get("lora_dropout") > 0.0
+            and data.get("lora_target_parameters")
+        ):
+            # lora.ParamWrapper does not work with lora_dropout != 0
+            raise ValueError(
+                "`lora_dropout` does not work when using `lora_target_parameters`"
+            )
+
 
 class RLValidationMixin:
     """Validation methods related to RL training configuration."""
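A hedged example of a config this new validator rejects (the field values and target-parameter name are made up):

bad = {"lora_dropout": 0.05, "lora_target_parameters": ["mlp.experts.gate_up_proj"]}
# -> ValueError: `lora_dropout` does not work when using `lora_target_parameters`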
@@ -673,7 +686,7 @@ class RLValidationMixin:
             data.get("rl") == "grpo"
             and data.get("trl", {})
             and data.get("trl").get("use_liger_loss")
-            and data.get("sequence_parallel_degree", 1) > 1
+            and data.get("context_parallel_size", 1) > 1
         ):
             raise ValueError("GRPO + SP + Liger not currently supported")
         return data
@@ -881,17 +894,19 @@ class OptimizationValidationMixin:
         return self
 
     @model_validator(mode="after")
-    def check_fsdp_sharded_state_dict_w_safetensors(self):
+    def lr_groups_ao_optimizer(self):
         if (
-            hasattr(self, "fsdp_config")
-            and self.fsdp_config
-            and hasattr(self, "save_safetensors")
-            and self.save_safetensors
-            and self.fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
-            and str(getattr(self, "fsdp_version", "1")) != "2"
-        ):
+            self.loraplus_lr_ratio is not None
+            or self.embedding_lr_scale is not None
+            or self.embedding_lr is not None
+            or self.lr_groups is not None
+        ) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]:
+            # TODO(wing): remove this once ao>0.12.0
+            # requires https://github.com/pytorch/ao/pull/2606 in an ao release
             raise ValueError(
-                "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
+                "lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not "
+                "supported with ao low-bit optimizers until ao>0.12.0. "
+                "Please refer to https://github.com/pytorch/ao/pull/2606."
             )
         return self
@@ -900,31 +915,30 @@ class OptimizationValidationMixin:
     def check_tensor_parallel_size_update_ds_json(cls, data):
         tensor_parallel_size = data.get("tensor_parallel_size")
         if tensor_parallel_size is not None and tensor_parallel_size > 1:
-            if not data.get("deepspeed"):
-                raise ValueError(
-                    "Tensor parallelism (TP) is only supported with DeepSpeed"
-                )
-            with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
-                ds_config = json.load(ds_fin)
-            should_save = False
-            if "tensor_parallel" not in ds_config:
-                ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
-                should_save = True
-            if (
-                "gather_16bit_weights_on_model_save"
-                not in ds_config["zero_optimization"]
-            ):
-                ds_config["zero_optimization"][
-                    "gather_16bit_weights_on_model_save"
-                ] = True
-                should_save = True
-            if should_save:
-                temp_dir = tempfile.mkdtemp()
-                with open(
-                    Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
-                ) as ds_fout:
-                    json.dump(ds_config, ds_fout, indent=4)
-                data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
+            if data.get("deepspeed"):
+                with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
+                    ds_config = json.load(ds_fin)
+                should_save = False
+                if "tensor_parallel" not in ds_config:
+                    ds_config["tensor_parallel"] = {
+                        "autotp_size": tensor_parallel_size
+                    }
+                    should_save = True
+                if (
+                    "gather_16bit_weights_on_model_save"
+                    not in ds_config["zero_optimization"]
+                ):
+                    ds_config["zero_optimization"][
+                        "gather_16bit_weights_on_model_save"
+                    ] = True
+                    should_save = True
+                if should_save:
+                    temp_dir = tempfile.mkdtemp()
+                    with open(
+                        Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
+                    ) as ds_fout:
+                        json.dump(ds_config, ds_fout, indent=4)
+                    data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
 
         return data
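For reference, a hedged sketch of the JSON this validator ends up persisting (the starting config and values are assumed; the keys come from the code above):

ds_config = {"zero_optimization": {"stage": 1}}    # minimal user-supplied config
ds_config["tensor_parallel"] = {"autotp_size": 2}  # added for tensor_parallel_size=2
ds_config["zero_optimization"]["gather_16bit_weights_on_model_save"] = True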
@@ -1164,7 +1178,9 @@ class ComplexValidationMixin:
 
     @model_validator(mode="after")
     def check_relora(self):
-        if self.relora_steps:
+        if self.relora:
+            if not self.jagged_restart_steps:
+                raise ValueError("jagged_restart_steps must be set to use ReLoRA")
             if self.adapter not in ("lora", "qlora"):
                 raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
@@ -1203,13 +1219,18 @@ class ComplexValidationMixin:
         return self
 
     @model_validator(mode="after")
-    def check_sequence_parallel_degree(self):
-        if not self.sequence_parallel_degree:
-            self.sequence_parallel_degree = 1
-        elif self.sequence_parallel_degree > 1:
+    def check_context_parallel_size(self):
+        if self.sequence_parallel_degree and not self.context_parallel_size:
+            LOG.warning(
+                "`sequence_parallel_degree` is deprecated, use `context_parallel_size`"
+            )
+            self.context_parallel_size = self.sequence_parallel_degree
+        if not self.context_parallel_size:
+            self.context_parallel_size = 1
+        elif self.context_parallel_size > 1:
             if not self.flash_attention:
                 raise ValueError(
-                    "flash_attention: true must be set with sequence_parallel_degree > 1"
+                    "flash_attention: true must be set with context_parallel_size > 1"
                 )
 
         if self.sample_packing and self.micro_batch_size > 1:
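A small self-contained sketch of the deprecation shim's effect (illustrative; it mirrors the branch in the validator above rather than calling it):

from types import SimpleNamespace

def shim(cfg):
    # mirrors the deprecation branch of check_context_parallel_size
    if cfg.sequence_parallel_degree and not cfg.context_parallel_size:
        cfg.context_parallel_size = cfg.sequence_parallel_degree
    return cfg

cfg = SimpleNamespace(sequence_parallel_degree=4, context_parallel_size=None)
assert shim(cfg).context_parallel_size == 4  # legacy configs keep working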
@@ -1219,17 +1240,23 @@ class ComplexValidationMixin:
                )
 
             try:
+                import transformers.modeling_flash_attention_utils
+
+                # pylint: disable=protected-access
+                transformers.modeling_flash_attention_utils._flash_supports_window_size = (
+                    transformers.modeling_flash_attention_utils._flash_supports_window
+                )
                 import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
             except ImportError as exception:
                 raise ImportError(
-                    "sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
+                    "context_parallel_size > 1 but ring_flash_attn is not installed. "
                     "Please install it with `pip install axolotl[ring-flash-attn] "
                     "or `pip install ring-flash-attn>=0.1.4`."
                 ) from exception
 
             LOG.warning(
                 "Sequence parallelism (SP) is enabled with "
-                f"sequence_parallel_degree={self.sequence_parallel_degree}. "
+                f"context_parallel_size={self.context_parallel_size}. "
                 "Please note that logged losses may differ slightly to the non-SP "
                 "losses due to transformers Trainer implementation details. "
                 "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
@@ -1240,7 +1267,7 @@ class ComplexValidationMixin:
 
     @model_validator(mode="after")
     def validate_ring_attn_func(self):
-        if getattr(self, "sequence_parallel_degree", 1) == 1:
+        if getattr(self, "context_parallel_size", 1) == 1:
             return self
 
         if self.ring_attn_func is not None:
@@ -1257,6 +1284,33 @@
         return self
 
 
+class DistributedValidationMixin:
+    """validation for distributed training."""
+
+    @model_validator(mode="after")
+    def check_tensor_parallel_optimizer(self):
+        if self.tensor_parallel_size > 1:
+            if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]:
+                raise ValueError(
+                    "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers"
+                )
+
+        return self
+
+
+class GRPOVllmValidationMixin:
+    """Validation mixin for vllm when using GRPO."""
+
+    @model_validator(mode="after")
+    def check_vllm_mode_set(self):
+        if self.trl and self.trl.use_vllm and not self.trl.vllm_mode:
+            LOG.warning(
+                "vllm_mode must be set to either `server` or `colocate` when using vllm, using default value `server`"
+            )
+            self.trl.vllm_mode = "server"
+        return self
+
+
 # pylint: disable=too-many-ancestors
 class ValidationMixin(
     DatasetValidationMixin,
@@ -1270,5 +1324,6 @@ class ValidationMixin(
     PretrainingValidationMixin,
     ModelCompatibilityValidationMixin,
     ComplexValidationMixin,
+    GRPOVllmValidationMixin,
 ):
     """Full validation mixin for Axolotl configuration."""
@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 - 1
             )
             * cfg.num_epochs
-            * cfg.sequence_parallel_degree
+            * cfg.context_parallel_size
             * cfg.tensor_parallel_size
         )
         LOG.debug(
@@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 math.floor(
                     data_loader_len
                     * cfg.num_epochs
-                    * cfg.sequence_parallel_degree
+                    * cfg.context_parallel_size
                    * cfg.tensor_parallel_size
                 )
             )
@@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             math.ceil(
                 len(train_dataset)
                 * cfg.num_epochs
-                * cfg.sequence_parallel_degree
+                * cfg.context_parallel_size
                 * cfg.tensor_parallel_size
                 / cfg.batch_size
             )
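A hedged arithmetic check of the last formula (the numbers are invented): with 1,000 samples, 2 epochs, context_parallel_size=2, tensor_parallel_size=1, and batch_size=8, each global batch spans the context-parallel ranks, so the step count scales back up by that factor.

import math

total_num_steps = math.ceil(1000 * 2 * 2 * 1 / 8)  # == 500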
@@ -17,16 +17,23 @@ class BaseCliTest:
         command: Command to test (train/evaluate)
         """
         # Test missing config file
-        result = cli_runner.invoke(cli, [command, "--no-accelerate"])
+        result = cli_runner.invoke(cli, [command, "--launcher", "python"])
         assert result.exit_code != 0
 
         # Test non-existent config file
-        result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
+        result = cli_runner.invoke(
+            cli, [command, "nonexistent.yml", "--launcher", "python"]
+        )
         assert result.exit_code != 0
         assert "Error: Invalid value for 'CONFIG'" in result.output
 
     def _test_basic_execution(
-        self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
+        self,
+        cli_runner,
+        tmp_path: Path,
+        valid_test_config: str,
+        command: str,
+        train: bool = True,
     ):
         """Test basic execution with accelerate.
 
@@ -35,6 +42,7 @@ class BaseCliTest:
             tmp_path: Temporary path fixture
             valid_test_config: Valid config fixture
             command: Command to test (train/evaluate)
+            train: Whether to test training (default) or evaluation
         """
         config_path = tmp_path / "config.yml"
         config_path.write_text(valid_test_config)
@@ -43,15 +51,21 @@ class BaseCliTest:
             result = cli_runner.invoke(cli, [command, str(config_path)])
 
             assert mock.called
-            assert mock.call_args.args[0] == [
+
+            expected = [
                 "accelerate",
                 "launch",
                 "-m",
                 f"axolotl.cli.{command}",
                 str(config_path),
-                "--debug-num-examples",
-                "0",
+                "--debug=False",
+                "--debug-text-only=False",
+                "--debug-num-examples=0",
             ]
+            if train:
+                expected.append("--shard=False")
+
+            assert mock.call_args.args[0] == expected
             assert mock.call_args.kwargs == {"check": True}
             assert result.exit_code == 0
@@ -1,5 +1,7 @@
 """Tests for evaluate CLI command."""
 
+# pylint: disable=duplicate-code
+
 from unittest.mock import patch
 
 from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestEvaluateCommand(BaseCliTest):
 
     def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
         """Test basic successful execution"""
-        self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")
+        self._test_basic_execution(
+            cli_runner, tmp_path, valid_test_config, "evaluate", train=False
+        )
 
     def test_evaluate_basic_execution_no_accelerate(
         self, cli_runner, tmp_path, valid_test_config
@@ -27,13 +31,15 @@ class TestEvaluateCommand(BaseCliTest):
         config_path = tmp_path / "config.yml"
         config_path.write_text(valid_test_config)
 
-        # pylint: disable=duplicate-code
         with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
             result = cli_runner.invoke(
                 cli,
                 [
                     "evaluate",
                     str(config_path),
-                    "--no-accelerate",
+                    "--launcher",
+                    "python",
                 ],
                 catch_exceptions=False,
             )
@@ -55,7 +61,8 @@ class TestEvaluateCommand(BaseCliTest):
                     "2",
                     "--sequence-len",
                     "128",
-                    "--no-accelerate",
+                    "--launcher",
+                    "python",
                 ],
                 catch_exceptions=False,
             )
@@ -65,3 +72,104 @@ class TestEvaluateCommand(BaseCliTest):
         cfg = mock_evaluate.call_args[0][0]
         assert cfg.micro_batch_size == 2
         assert cfg.sequence_len == 128
+
+    def test_evaluate_with_launcher_args_torchrun(
+        self, cli_runner, tmp_path, valid_test_config
+    ):
+        """Test evaluate with torchrun launcher arguments"""
+        config_path = tmp_path / "config.yml"
+        config_path.write_text(valid_test_config)
+
+        with patch("subprocess.run") as mock_subprocess:
+            result = cli_runner.invoke(
+                cli,
+                [
+                    "evaluate",
+                    str(config_path),
+                    "--launcher",
+                    "torchrun",
+                    "--",
+                    "--nproc_per_node=2",
+                    "--nnodes=1",
+                ],
+                catch_exceptions=False,
+            )
+
+            assert result.exit_code == 0
+            mock_subprocess.assert_called_once()
+
+            # Verify launcher args are passed to torchrun
+            called_cmd = mock_subprocess.call_args.args[0]
+            assert called_cmd[0] == "torchrun"
+            assert "--nproc_per_node=2" in called_cmd
+            assert "--nnodes=1" in called_cmd
+            assert "-m" in called_cmd
+            assert "axolotl.cli.evaluate" in called_cmd
+
+    def test_evaluate_with_launcher_args_accelerate(
+        self, cli_runner, tmp_path, valid_test_config
+    ):
+        """Test evaluate with accelerate launcher arguments"""
+        config_path = tmp_path / "config.yml"
+        config_path.write_text(valid_test_config)
+
+        with patch("subprocess.run") as mock_subprocess:
+            result = cli_runner.invoke(
+                cli,
+                [
+                    "evaluate",
+                    str(config_path),
+                    "--launcher",
+                    "accelerate",
+                    "--",
+                    "--config_file=accelerate_config.yml",
+                    "--num_processes=4",
+                ],
+                catch_exceptions=False,
+            )
+
+            assert result.exit_code == 0
+            mock_subprocess.assert_called_once()
+
+            # Verify launcher args are passed to accelerate
+            called_cmd = mock_subprocess.call_args.args[0]
+            assert called_cmd[0] == "accelerate"
+            assert called_cmd[1] == "launch"
+            assert "--config_file=accelerate_config.yml" in called_cmd
+            assert "--num_processes=4" in called_cmd
+            assert "-m" in called_cmd
+            assert "axolotl.cli.evaluate" in called_cmd
+
+    def test_evaluate_backward_compatibility_no_launcher_args(
+        self, cli_runner, tmp_path, valid_test_config
+    ):
+        """Test that existing evaluate commands work without launcher args"""
+        config_path = tmp_path / "config.yml"
+        config_path.write_text(valid_test_config)
+
+        with patch("subprocess.run") as mock_subprocess:
+            result = cli_runner.invoke(
+                cli,
+                [
+                    "evaluate",
+                    str(config_path),
+                    "--launcher",
+                    "accelerate",
+                    "--micro-batch-size",
+                    "2",
+                ],
+                catch_exceptions=False,
+            )
+
+            assert result.exit_code == 0
+            mock_subprocess.assert_called_once()
+
+            # Verify no launcher args contamination
+            called_cmd = mock_subprocess.call_args.args[0]
+            assert called_cmd[0] == "accelerate"
+            assert called_cmd[1] == "launch"
+            # Should not contain any extra launcher args
+            launcher_section = called_cmd[2 : called_cmd.index("-m")]
+            assert (
+                len(launcher_section) == 0
+            )  # No launcher args between 'launch' and '-m'
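These tests pin down the command shape the CLI builds; a hedged sketch of that shape (module path taken from the tests, the assembly order itself is an inference):

cmd = ["torchrun", "--nproc_per_node=2", "--nnodes=1", "-m", "axolotl.cli.evaluate", "config.yml"]
# launcher args sit between the launcher and "-m"; axolotl args follow the module.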
@@ -1,5 +1,7 @@
 """pytest tests for axolotl CLI inference command."""
 
+# pylint: disable=duplicate-code
+
 from unittest.mock import patch
 
 from axolotl.cli.main import cli
@@ -10,7 +12,7 @@ def test_inference_basic(cli_runner, config_path):
     with patch("axolotl.cli.inference.do_inference") as mock:
         result = cli_runner.invoke(
             cli,
-            ["inference", str(config_path), "--no-accelerate"],
+            ["inference", str(config_path), "--launcher", "python"],
             catch_exceptions=False,
         )
 
@@ -23,9 +25,124 @@ def test_inference_gradio(cli_runner, config_path):
     with patch("axolotl.cli.inference.do_inference_gradio") as mock:
         result = cli_runner.invoke(
             cli,
-            ["inference", str(config_path), "--no-accelerate", "--gradio"],
+            ["inference", str(config_path), "--launcher", "python", "--gradio"],
             catch_exceptions=False,
         )
 
         assert mock.called
         assert result.exit_code == 0
 
 
+def test_inference_with_launcher_args_torchrun(cli_runner, config_path):
+    """Test inference with torchrun launcher arguments"""
+    with patch("subprocess.run") as mock_subprocess:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "inference",
+                str(config_path),
+                "--launcher",
+                "torchrun",
+                "--",
+                "--nproc_per_node=2",
+                "--nnodes=1",
+            ],
+            catch_exceptions=False,
+        )
+
+        assert result.exit_code == 0
+        mock_subprocess.assert_called_once()
+
+        # Verify launcher args are passed to torchrun
+        called_cmd = mock_subprocess.call_args.args[0]
+        assert called_cmd[0] == "torchrun"
+        assert "--nproc_per_node=2" in called_cmd
+        assert "--nnodes=1" in called_cmd
+        assert "-m" in called_cmd
+        assert "axolotl.cli.inference" in called_cmd
+
+
+def test_inference_with_launcher_args_accelerate(cli_runner, config_path):
+    """Test inference with accelerate launcher arguments"""
+    with patch("subprocess.run") as mock_subprocess:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "inference",
+                str(config_path),
+                "--launcher",
+                "accelerate",
+                "--",
+                "--config_file=accelerate_config.yml",
+                "--num_processes=4",
+            ],
+            catch_exceptions=False,
+        )
+
+        assert result.exit_code == 0
+        mock_subprocess.assert_called_once()
+
+        # Verify launcher args are passed to accelerate
+        called_cmd = mock_subprocess.call_args.args[0]
+        assert called_cmd[0] == "accelerate"
+        assert called_cmd[1] == "launch"
+        assert "--config_file=accelerate_config.yml" in called_cmd
+        assert "--num_processes=4" in called_cmd
+        assert "-m" in called_cmd
+        assert "axolotl.cli.inference" in called_cmd
+
+
+def test_inference_gradio_with_launcher_args(cli_runner, config_path):
+    """Test inference with gradio and launcher arguments"""
+    with patch("subprocess.run") as mock_subprocess:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "inference",
+                str(config_path),
+                "--launcher",
+                "accelerate",
+                "--gradio",
+                "--",
+                "--num_processes=2",
+            ],
+            catch_exceptions=False,
+        )
+
+        assert result.exit_code == 0
+        mock_subprocess.assert_called_once()
+
+        # Verify both gradio flag and launcher args are present
+        called_cmd = mock_subprocess.call_args.args[0]
+        assert called_cmd[0] == "accelerate"
+        assert called_cmd[1] == "launch"
+        assert "--num_processes=2" in called_cmd
+        assert "--gradio" in called_cmd
+        assert "-m" in called_cmd
+        assert "axolotl.cli.inference" in called_cmd
+
+
+def test_inference_backward_compatibility_no_launcher_args(cli_runner, config_path):
+    """Test that existing inference commands work without launcher args"""
+    with patch("subprocess.run") as mock_subprocess:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "inference",
+                str(config_path),
+                "--launcher",
+                "accelerate",
+            ],
+            catch_exceptions=False,
+        )
+
+        assert result.exit_code == 0
+        mock_subprocess.assert_called_once()
+
+        # Verify no launcher args contamination
+        called_cmd = mock_subprocess.call_args.args[0]
+        assert called_cmd[0] == "accelerate"
+        assert called_cmd[1] == "launch"
+        # Should not contain any extra launcher args
+        launcher_section = called_cmd[2 : called_cmd.index("-m")]
+        assert len(launcher_section) == 0  # No launcher args between 'launch' and '-m'
@@ -18,11 +18,10 @@ def test_build_command():
     assert result == [
         "accelerate",
         "launch",
-        "--learning-rate",
-        "0.0001",
-        "--batch-size",
-        "8",
-        "--debug",
+        "--learning-rate=0.0001",
+        "--batch-size=8",
+        "--debug=True",
+        "--use-fp16=False",
     ]
 
 
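The change the test asserts: options now serialize as single "--key=value" tokens, with booleans made explicit. A hedged sketch of that serialization (the helper name and exact signature of the real build_command are not shown here, so this is illustrative only):

def serialize(opts: dict) -> list[str]:
    # each option becomes one "--key=value" token, underscores become dashes
    return [f"--{k.replace('_', '-')}={v}" for k, v in opts.items()]

serialize({"learning_rate": 0.0001, "batch_size": 8, "debug": True, "use_fp16": False})
# -> ["--learning-rate=0.0001", "--batch-size=8", "--debug=True", "--use-fp16=False"]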
@@ -38,7 +37,7 @@ def test_invalid_command_options(cli_runner):
         ],
     )
     assert result.exit_code != 0
-    assert "No such option" in result.output
+    assert "does not exist" in result.output
 
 
 def test_required_config_argument(cli_runner):
@@ -11,9 +11,101 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
     """Test merge_sharded_fsdp_weights command without accelerate"""
     with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
         result = cli_runner.invoke(
-            cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
+            cli,
+            ["merge-sharded-fsdp-weights", str(config_path), "--launcher", "python"],
         )
 
         assert mock.called
         assert mock.call_args.kwargs["config"] == str(config_path)
         assert result.exit_code == 0
 
 
+def test_merge_sharded_fsdp_weights_with_launcher_args_torchrun(
+    cli_runner, config_path
+):
+    """Test merge-sharded-fsdp-weights with torchrun launcher arguments"""
+    with patch("subprocess.run") as mock_subprocess:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "merge-sharded-fsdp-weights",
+                str(config_path),
+                "--launcher",
+                "torchrun",
+                "--",
+                "--nproc_per_node=2",
+                "--nnodes=1",
+            ],
+            catch_exceptions=False,
+        )
+
+        assert result.exit_code == 0
+        mock_subprocess.assert_called_once()
+
+        # Verify launcher args are passed to torchrun
+        called_cmd = mock_subprocess.call_args.args[0]
+        assert called_cmd[0] == "torchrun"
+        assert "--nproc_per_node=2" in called_cmd
+        assert "--nnodes=1" in called_cmd
+        assert "-m" in called_cmd
+        assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
+
+
+def test_merge_sharded_fsdp_weights_with_launcher_args_accelerate(
+    cli_runner, config_path
+):
+    """Test merge-sharded-fsdp-weights with accelerate launcher arguments"""
+    with patch("subprocess.run") as mock_subprocess:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "merge-sharded-fsdp-weights",
+                str(config_path),
+                "--launcher",
+                "accelerate",
+                "--",
+                "--config_file=accelerate_config.yml",
+                "--num_processes=4",
+            ],
+            catch_exceptions=False,
+        )
+
+        assert result.exit_code == 0
+        mock_subprocess.assert_called_once()
+
+        # Verify launcher args are passed to accelerate
+        called_cmd = mock_subprocess.call_args.args[0]
+        assert called_cmd[0] == "accelerate"
+        assert called_cmd[1] == "launch"
+        assert "--config_file=accelerate_config.yml" in called_cmd
+        assert "--num_processes=4" in called_cmd
+        assert "-m" in called_cmd
+        assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
+
+
+def test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args(
+    cli_runner, config_path
+):
+    """Test that existing merge-sharded-fsdp-weights commands work without launcher args"""
+    with patch("subprocess.run") as mock_subprocess:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "merge-sharded-fsdp-weights",
+                str(config_path),
+                "--launcher",
+                "accelerate",
+            ],
+            catch_exceptions=False,
+        )
+
+        assert result.exit_code == 0
+        mock_subprocess.assert_called_once()
+
+        # Verify no launcher args contamination
+        called_cmd = mock_subprocess.call_args.args[0]
+        assert called_cmd[0] == "accelerate"
+        assert called_cmd[1] == "launch"
+        # Should not contain any extra launcher args
+        launcher_section = called_cmd[2 : called_cmd.index("-m")]
+        assert len(launcher_section) == 0  # No launcher args between 'launch' and '-m'
@@ -2,7 +2,7 @@
 unit tests for generating sweep configurations
 """
 
-from axolotl.cli.main import generate_sweep_configs
+from axolotl.cli.utils import generate_sweep_configs
 
 
 def test_generate_sweep_configs_no_pairs():
@@ -1,5 +1,7 @@
 """Tests for train CLI command."""
 
+# pylint: disable=duplicate-code
+
 from unittest.mock import MagicMock, patch
 
 from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestTrainCommand(BaseCliTest):
 
     def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
         """Test basic successful execution"""
-        self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train")
+        self._test_basic_execution(
+            cli_runner, tmp_path, valid_test_config, "train", train=True
+        )
 
     def test_train_basic_execution_no_accelerate(
         self, cli_runner, tmp_path, valid_test_config
@@ -37,7 +41,8 @@ class TestTrainCommand(BaseCliTest):
                 [
                     "train",
                     str(config_path),
-                    "--no-accelerate",
+                    "--launcher",
+                    "python",
                 ],
                 catch_exceptions=False,
             )
@@ -59,11 +64,10 @@ class TestTrainCommand(BaseCliTest):
                 [
                     "train",
                     str(config_path),
-                    "--learning-rate",
-                    "1e-4",
-                    "--micro-batch-size",
-                    "2",
-                    "--no-accelerate",
+                    "--learning-rate=1e-4",
+                    "--micro-batch-size=2",
+                    "--launcher",
+                    "python",
                 ],
                 catch_exceptions=False,
             )
@@ -73,3 +77,174 @@ class TestTrainCommand(BaseCliTest):
         cfg = mock_train.call_args[1]["cfg"]
         assert cfg["learning_rate"] == 1e-4
         assert cfg["micro_batch_size"] == 2
+
+    def test_train_with_launcher_args_torchrun(
+        self, cli_runner, tmp_path, valid_test_config
+    ):
+        """Test train with torchrun launcher arguments"""
+        config_path = tmp_path / "config.yml"
+        config_path.write_text(valid_test_config)
+
+        with patch("subprocess.run") as mock_subprocess:
+            result = cli_runner.invoke(
+                cli,
+                [
+                    "train",
+                    str(config_path),
+                    "--launcher",
+                    "torchrun",
+                    "--",
+                    "--nproc_per_node=2",
+                    "--nnodes=1",
+                ],
+                catch_exceptions=False,
+            )
+
+            assert result.exit_code == 0
+            mock_subprocess.assert_called_once()
+
+            # Verify launcher args are passed to torchrun
+            called_cmd = mock_subprocess.call_args.args[0]
+            assert called_cmd[0] == "torchrun"
+            assert "--nproc_per_node=2" in called_cmd
+            assert "--nnodes=1" in called_cmd
+            assert "-m" in called_cmd
+            assert "axolotl.cli.train" in called_cmd
+
+    def test_train_with_launcher_args_accelerate(
+        self, cli_runner, tmp_path, valid_test_config
+    ):
+        """Test train with accelerate launcher arguments"""
+        config_path = tmp_path / "config.yml"
+        config_path.write_text(valid_test_config)
+
+        with patch("subprocess.run") as mock_subprocess:
+            result = cli_runner.invoke(
+                cli,
+                [
+                    "train",
+                    str(config_path),
+                    "--launcher",
+                    "accelerate",
+                    "--",
+                    "--config_file=accelerate_config.yml",
+                    "--num_processes=4",
+                ],
+                catch_exceptions=False,
+            )
+
+            assert result.exit_code == 0
+            mock_subprocess.assert_called_once()
+
+            # Verify launcher args are passed to accelerate
+            called_cmd = mock_subprocess.call_args.args[0]
+            assert called_cmd[0] == "accelerate"
+            assert called_cmd[1] == "launch"
+            assert "--config_file=accelerate_config.yml" in called_cmd
+            assert "--num_processes=4" in called_cmd
+            assert "-m" in called_cmd
+            assert "axolotl.cli.train" in called_cmd
+
+    def test_train_backward_compatibility_no_launcher_args(
+        self, cli_runner, tmp_path, valid_test_config
+    ):
+        """Test that existing train commands work without launcher args"""
+        config_path = tmp_path / "config.yml"
+        config_path.write_text(valid_test_config)
+
+        with patch("subprocess.run") as mock_subprocess:
+            result = cli_runner.invoke(
+                cli,
+                [
+                    "train",
+                    str(config_path),
+                    "--launcher",
+                    "accelerate",
+                    "--learning-rate",
+                    "1e-4",
+                ],
+                catch_exceptions=False,
+            )
+
+            assert result.exit_code == 0
+            mock_subprocess.assert_called_once()
+
+            # Verify no launcher args contamination
+            called_cmd = mock_subprocess.call_args.args[0]
+            assert called_cmd[0] == "accelerate"
+            assert called_cmd[1] == "launch"
+            # Should not contain any extra launcher args
+            launcher_section = called_cmd[2 : called_cmd.index("-m")]
+            assert (
+                len(launcher_section) == 0
+            )  # No launcher args between 'launch' and '-m'
+
+    def test_train_mixed_args_with_launcher_args(
+        self, cli_runner, tmp_path, valid_test_config
+    ):
+        """Test train with both regular CLI args and launcher args"""
+        config_path = tmp_path / "config.yml"
+        config_path.write_text(valid_test_config)
+
+        with patch("subprocess.run") as mock_subprocess:
+            result = cli_runner.invoke(
+                cli,
+                [
+                    "train",
+                    str(config_path),
+                    "--launcher",
+                    "torchrun",
+                    "--learning-rate",
+                    "2e-4",
+                    "--micro-batch-size",
+                    "4",
+                    "--",
+                    "--nproc_per_node=8",
+                ],
+                catch_exceptions=False,
+            )
+
+            assert result.exit_code == 0
+            mock_subprocess.assert_called_once()
+
+            called_cmd = mock_subprocess.call_args.args[0]
+            # Verify launcher args
+            assert "--nproc_per_node=8" in called_cmd
+            # Verify axolotl args are also present
+            assert "--learning-rate=2e-4" in called_cmd
+            assert "--micro-batch-size=4" in called_cmd
+
+    def test_train_cloud_with_launcher_args(
+        self, cli_runner, tmp_path, valid_test_config
+    ):
+        """Test train with cloud and launcher arguments"""
+        config_path = tmp_path / "config.yml"
+        config_path.write_text(valid_test_config)
+
+        cloud_path = tmp_path / "cloud.yml"
+        cloud_path.write_text("provider: modal\ngpu: a100")
+
+        with patch("axolotl.cli.cloud.do_cli_train") as mock_cloud_train:
+            result = cli_runner.invoke(
+                cli,
+                [
+                    "train",
+                    str(config_path),
+                    "--cloud",
+                    str(cloud_path),
+                    "--launcher",
+                    "torchrun",
+                    "--",
+                    "--nproc_per_node=4",
+                    "--nnodes=2",
+                ],
+                catch_exceptions=False,
+            )
+
+            assert result.exit_code == 0
+            mock_cloud_train.assert_called_once()
+
+            # Verify cloud training was called with launcher args
+            call_kwargs = mock_cloud_train.call_args.kwargs
+            assert call_kwargs["launcher"] == "torchrun"
+            assert call_kwargs["launcher_args"] == ["--nproc_per_node=4", "--nnodes=2"]
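The mixed-args test implies a simple split rule; a hedged sketch of it (the real parser lives in the CLI and may differ): everything after a bare "--" goes to the launcher, everything before it to axolotl.

argv = ["--learning-rate=2e-4", "--micro-batch-size=4", "--", "--nproc_per_node=8"]
sep = argv.index("--")
axolotl_args, launcher_args = argv[:sep], argv[sep + 1 :]
assert launcher_args == ["--nproc_per_node=8"]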
@@ -72,3 +72,160 @@ def test_fetch_from_github_network_error():
     with patch("requests.get", side_effect=requests.RequestException):
         with pytest.raises(requests.RequestException):
             fetch_from_github("examples/", None)
+
+
+def assert_launcher_args_in_command(
+    mock_subprocess_call,
+    launcher: str,
+    expected_launcher_args: list[str],
+    command_module: str,
+):
+    """
+    Helper function to verify launcher arguments are properly passed in subprocess calls.
+
+    Args:
+        mock_subprocess_call: The mock subprocess.run call
+        launcher: Expected launcher ("accelerate", "torchrun", etc.)
+        expected_launcher_args: List of expected launcher arguments
+        command_module: Expected module name (e.g., "axolotl.cli.train")
+    """
+    assert mock_subprocess_call.called, "subprocess.run should have been called"
+    called_cmd = mock_subprocess_call.call_args.args[0]
+
+    # Verify launcher
+    assert (
+        called_cmd[0] == launcher
+    ), f"Expected launcher {launcher}, got {called_cmd[0]}"
+
+    # Verify launcher args are present
+    for arg in expected_launcher_args:
+        assert (
+            arg in called_cmd
+        ), f"Expected launcher arg '{arg}' not found in command: {called_cmd}"
+
+    # Verify module is present
+    assert "-m" in called_cmd, "Expected -m flag for module execution"
+    assert (
+        command_module in called_cmd
+    ), f"Expected module {command_module} not found in command: {called_cmd}"
+
+
+def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
+    """
+    Helper function to verify no unwanted launcher arguments are present.
+
+    Args:
+        mock_subprocess_call: The mock subprocess.run call
+        launcher: Expected launcher ("accelerate", "torchrun", etc.)
+    """
+    assert mock_subprocess_call.called, "subprocess.run should have been called"
+    called_cmd = mock_subprocess_call.call_args.args[0]
+
+    if launcher == "accelerate":
+        # For accelerate, launcher args should be between 'launch' and '-m'
+        launch_idx = called_cmd.index("launch")
+        m_idx = called_cmd.index("-m")
+        launcher_section = called_cmd[launch_idx + 1 : m_idx]
+        assert (
+            len(launcher_section) == 0
+        ), f"Unexpected launcher args found: {launcher_section}"
+    elif launcher == "torchrun":
+        # For torchrun, launcher args should be between 'torchrun' and '-m'
+        torchrun_idx = called_cmd.index("torchrun")
+        m_idx = called_cmd.index("-m")
+        launcher_section = called_cmd[torchrun_idx + 1 : m_idx]
+        assert (
+            len(launcher_section) == 0
+        ), f"Unexpected launcher args found: {launcher_section}"
+
+
+@pytest.fixture
+def common_launcher_args():
+    """Fixture providing common launcher argument combinations for testing."""
+    return {
+        "torchrun": ["--nproc_per_node=2", "--nnodes=1"],
+        "accelerate": ["--config_file=accelerate_config.yml", "--num_processes=4"],
+    }
+
+
+def test_add_default_rdzv_args_with_endpoint():
+    """Test that default RDZV args are added when rdzv_endpoint is present."""
+    from axolotl.cli.utils.train import _add_default_rdzv_args
+
+    launcher_args = ["--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"]
+    result = _add_default_rdzv_args(launcher_args)
+
+    # Should have added rdzv_backend
+    assert "--rdzv_backend" in result
+    assert "c10d" in result
+
+    # Original args should still be present
+    assert "--nnodes=2" in result
+    assert "--rdzv_endpoint=127.0.0.1:29400" in result
+
+
+def test_add_default_rdzv_args_with_existing_backend():
+    """Test that existing rdzv_backend is not overridden."""
+    from axolotl.cli.utils.train import _add_default_rdzv_args
+
+    launcher_args = [
+        "--nnodes=2",
+        "--rdzv_endpoint=127.0.0.1:29400",
+        "--rdzv_backend=static",
+    ]
+    result = _add_default_rdzv_args(launcher_args)
+
+    # Should not add another rdzv_backend
+    backend_count = sum(1 for arg in result if "--rdzv_backend" in arg)
+    assert backend_count == 1
+    assert "--rdzv_backend=static" in result
+
+
+def test_add_default_rdzv_args_with_existing_id():
+    """Test that existing rdzv_id is not overridden."""
+    from axolotl.cli.utils.train import _add_default_rdzv_args
+
+    launcher_args = [
+        "--nnodes=2",
+        "--rdzv_endpoint=127.0.0.1:29400",
+        "--rdzv_id=my_job_123",
+    ]
+    result = _add_default_rdzv_args(launcher_args)
+
+    # Should not add another rdzv_id
+    id_count = sum(1 for arg in result if "--rdzv_id" in arg)
+    assert id_count == 1
+    assert "--rdzv_id=my_job_123" in result
+
+    # Should still add rdzv_backend
+    assert "--rdzv_backend" in result
+    assert "c10d" in result
+
+
+def test_add_default_rdzv_args_without_endpoint():
+    """Test that no RDZV args are added when rdzv_endpoint is not present."""
+    from axolotl.cli.utils.train import _add_default_rdzv_args
+
+    launcher_args = ["--nnodes=2", "--nproc_per_node=4"]
+    result = _add_default_rdzv_args(launcher_args)
+
+    # Should not add any rdzv args
+    assert "--rdzv_backend" not in result
+    assert result == launcher_args
+
+
+def test_add_default_rdzv_args_with_all_existing():
+    """Test that no defaults are added when all RDZV args are present."""
+    from axolotl.cli.utils.train import _add_default_rdzv_args
+
+    launcher_args = [
+        "--nnodes=2",
+        "--rdzv_endpoint=127.0.0.1:29400",
+        "--rdzv_backend=static",
+        "--rdzv_id=existing_job",
+    ]
+    result = _add_default_rdzv_args(launcher_args)
+
+    # Should not add any additional args
+    assert len(result) == len(launcher_args)
+    assert result == launcher_args
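A hedged reimplementation consistent with the tests above (the real helper lives in axolotl.cli.utils.train and is not shown in this diff; this sketch is illustrative only, and the default rdzv_id value is assumed):

def _add_default_rdzv_args_sketch(launcher_args):
    args = list(launcher_args)
    # only act when a rendezvous endpoint was requested
    if any("--rdzv_endpoint" in a for a in args):
        if not any("--rdzv_backend" in a for a in args):
            args += ["--rdzv_backend", "c10d"]   # torchrun's recommended backend
        if not any("--rdzv_id" in a for a in args):
            args += ["--rdzv_id", "axolotl-job"]  # default id value is assumed
    return args

# e.g. _add_default_rdzv_args_sketch(["--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"])
# gains "--rdzv_backend"/"c10d" tokens, matching the first test's assertions.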