Compare commits
20 Commits
lora_kerne
...
chat-templ
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b5198d8734 | ||
|
|
4ab6a1bd7e | ||
|
|
5639552064 | ||
|
|
cda3c82351 | ||
|
|
7c3b428f23 | ||
|
|
01a6bd1a0e | ||
|
|
41709822a7 | ||
|
|
02a37199ee | ||
|
|
7026cd5e9e | ||
|
|
eb0a8a7775 | ||
|
|
294c7fe7a6 | ||
|
|
7b68dfafd7 | ||
|
|
32a7890231 | ||
|
|
563f5eed7a | ||
|
|
6ec282094d | ||
|
|
09dda462ab | ||
|
|
bb1cae1a20 | ||
|
|
22810c97b7 | ||
|
|
2eb7ff95af | ||
|
|
90e5598930 |
4
.github/workflows/preview-docs.yml
vendored
4
.github/workflows/preview-docs.yml
vendored
@@ -53,7 +53,7 @@ jobs:
|
||||
|
||||
- name: Netlify Publish
|
||||
uses: nwtgck/actions-netlify@v3.0
|
||||
if: ${{ secrets.NETLIFY_AUTH_TOKEN != '' }}
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
id: netlify
|
||||
with:
|
||||
publish-dir: './_site'
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
|
||||
|
||||
- name: Update PR with preview link
|
||||
if: ${{ steps.netlify.outcome == 'success' && secrets.NETLIFY_AUTH_TOKEN != '' }}
|
||||
if: ${{ steps.netlify.outcome == 'success' }}
|
||||
uses: marocchino/sticky-pull-request-comment@v2
|
||||
with:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2025/07: Voxtral with mistral-common tokenizer support has been integrated in Axolotl. Read the [docs](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral)!
|
||||
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
|
||||
15
_quarto.yml
15
_quarto.yml
@@ -35,25 +35,30 @@ quartodoc:
|
||||
- cli.train
|
||||
- cli.evaluate
|
||||
- cli.args
|
||||
- cli.art
|
||||
- cli.checks
|
||||
- cli.config
|
||||
- cli.delinearize_llama4
|
||||
- cli.inference
|
||||
- cli.merge_lora
|
||||
- cli.merge_sharded_fsdp_weights
|
||||
- cli.preprocess
|
||||
- cli.sweeps
|
||||
- cli.utils
|
||||
- cli.quantize
|
||||
- cli.vllm_serve
|
||||
- cli.cloud.base
|
||||
- cli.cloud.modal_
|
||||
- cli.quantize
|
||||
- cli.utils
|
||||
- cli.utils.args
|
||||
- cli.utils.fetch
|
||||
- cli.utils.load
|
||||
- cli.utils.sweeps
|
||||
- cli.utils.train
|
||||
- title: Trainers
|
||||
desc: Training implementations
|
||||
contents:
|
||||
- core.trainers.base
|
||||
- core.trainers.trl
|
||||
- core.trainers.mamba
|
||||
- core.trainers.relora
|
||||
- core.trainers.dpo.trainer
|
||||
- core.trainers.grpo.trainer
|
||||
- core.trainers.grpo.sampler
|
||||
@@ -269,7 +274,6 @@ website:
|
||||
- docs/dataset_preprocessing.qmd
|
||||
- docs/multipack.qmd
|
||||
- docs/mixed_precision.qmd
|
||||
- docs/gradient_accumulation.qmd
|
||||
|
||||
- section: "Advanced Features"
|
||||
contents:
|
||||
@@ -279,6 +283,7 @@ website:
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
- docs/gradient_checkpointing.qmd
|
||||
- docs/nd_parallelism.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
set -e
|
||||
|
||||
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
||||
pytest -v -n2 \
|
||||
pytest -v --durations=10 -n2 \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||
/workspace/axolotl/tests/e2e/multigpu/ \
|
||||
|
||||
@@ -65,6 +65,9 @@ GPU_CONFIG = f"L40S:{N_GPUS}"
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
import subprocess # nosec
|
||||
|
||||
sp_env = os.environ.copy()
|
||||
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
@@ -16,7 +16,10 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
|
||||
ibverbs-providers ibverbs-utils infiniband-diags \
|
||||
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
|
||||
&& rm -rf /var/cache/apt/archives \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
|
||||
@@ -15,7 +15,7 @@ COPY scripts/motd /etc/motd
|
||||
RUN pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
RUN apt update && \
|
||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
|
||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
mkdir -p ~/.ssh && \
|
||||
|
||||
26
docs/cli.qmd
26
docs/cli.qmd
@@ -23,6 +23,20 @@ axolotl <command> [config.yml] [options]
|
||||
|
||||
The config file can be local or a URL to a raw YAML file.
|
||||
|
||||
### Launcher Arguments
|
||||
|
||||
For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator:
|
||||
|
||||
```bash
|
||||
# Pass torchrun arguments
|
||||
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
|
||||
|
||||
# Pass accelerate arguments
|
||||
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4
|
||||
```
|
||||
|
||||
Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.).
|
||||
|
||||
## Command Reference
|
||||
|
||||
### fetch
|
||||
@@ -80,7 +94,11 @@ axolotl train config.yml \
|
||||
--num-epochs 3
|
||||
|
||||
# Training without accelerate
|
||||
axolotl train config.yml --no-accelerate
|
||||
axolotl train config.yml --launcher python
|
||||
|
||||
# Pass launcher-specific arguments using -- separator
|
||||
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
|
||||
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml
|
||||
|
||||
# Resume training from checkpoint
|
||||
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
|
||||
@@ -175,6 +193,9 @@ Evaluates a model's performance (loss etc) on the train and eval datasets.
|
||||
```bash
|
||||
# Basic evaluation
|
||||
axolotl evaluate config.yml
|
||||
|
||||
# Evaluation with launcher arguments
|
||||
axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2
|
||||
```
|
||||
|
||||
### lm-eval
|
||||
@@ -287,9 +308,6 @@ axolotl preprocess config.yml --cloud cloud_config.yml
|
||||
# Train on cloud
|
||||
axolotl train config.yml --cloud cloud_config.yml
|
||||
|
||||
# Train without accelerate on cloud
|
||||
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
|
||||
|
||||
# Run lm-eval on cloud
|
||||
axolotl lm-eval config.yml --cloud cloud_config.yml
|
||||
```
|
||||
|
||||
@@ -69,11 +69,19 @@ export NCCL_BUFFSIZE=2097152
|
||||
|
||||
Run the following on each node:
|
||||
|
||||
### Option 1: New Axolotl CLI with launcher args (Recommended)
|
||||
|
||||
```bash
|
||||
axolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port"
|
||||
```
|
||||
|
||||
### Option 2: Direct torchrun (Legacy)
|
||||
|
||||
```bash
|
||||
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
|
||||
```
|
||||
|
||||
Please make sure to substitute the placeholder variables.
|
||||
Please make sure to substitute the placeholder variables:
|
||||
|
||||
- `num_nodes`: Number of nodes (containing GPUs)
|
||||
- `gpu_per_node`: Number of gpus per node
|
||||
@@ -81,8 +89,6 @@ Please make sure to substitute the placeholder variables.
|
||||
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
|
||||
- `rdzv_id`: A unique job ID that is used by the job across nodes.
|
||||
|
||||
::: {.callout-note}
|
||||
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
|
||||
:::
|
||||
The new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features.
|
||||
|
||||
More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)
|
||||
|
||||
102
docs/nd_parallelism.qmd
Normal file
102
docs/nd_parallelism.qmd
Normal file
@@ -0,0 +1,102 @@
|
||||
# N-D Parallelism
|
||||
|
||||
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
|
||||
|
||||
- A model's weights are too large to fit on a single GPU's memory.
|
||||
- A model's activations, especially with very long contexts, are too large for a single GPU.
|
||||
- You want to accelerate training by using multiple GPUs or nodes.
|
||||
|
||||
or combinations of the above!
|
||||
|
||||
## Core Concepts
|
||||
|
||||
Parallelism strategies can be combined. The key is understanding how each one divides the workload. PyTorch's `DeviceMesh` is the modern way to manage these combinations, creating a logical grid of your GPUs and assigning different parallel strategies to different dimensions of the grid.
|
||||
|
||||
### Data Parallelism {#sec-dp}
|
||||
|
||||
Data Parallelism focuses on splitting the global data batch across GPUs.
|
||||
|
||||
- Distributed Data Parallel (DDP): The classic approach. The full model is replicated on every GPU. Each GPU processes a different slice of the data batch. Gradients are then averaged across all GPUs after the backward pass to keep the models synchronized. This can substantially improve data throughput compared to single-device training, but requires that each GPU is able to hold the entire model, its gradients, and optimizer states.
|
||||
|
||||
- [Fully Sharded Data Parallel (FSDP)](multi-gpu.qmd#fully-sharded-data-parallel-(fsdp)): A highly memory-efficient form of data parallelism (inspired by DeepSpeed's ZeRO). Instead of replicating the model, FSDP shards the model's *parameters, gradients, and optimizer states* across the GPUs in the data-parallel group. During computation, each GPU receives the specific parameters it needs via an `all_gather` operation just before they are used, and they can be discarded immediately after (`reshard-after-forward`).
|
||||
- FSDP maps to ZeRO stages:
|
||||
- ZeRO-2 (`reshard_after_forward=False`): Shards gradients and optimizer states. Model weights are replicated on each GPU.
|
||||
- ZeRO-3 (`reshard_after_forward=True`): Shards gradients, optimizer states, AND model parameters. This provides the most memory savings at the cost of more communication (re-gathering parameters for both forward and backward passes).
|
||||
|
||||
### [Experimental] Tensor Parallelism (TP) {#sec-tp}
|
||||
|
||||
Also known as "horizontal model parallelism," as described in the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Instead of splitting the batch, TP splits the model's layers themselves across GPUs.
|
||||
|
||||
- How it works: For a linear layer `Y = XA`, the weight matrix `A` is split column-wise (`A = [A_1, A_2]`). The computation becomes `Y_1 = XA_1` and `Y_2 = XA_2`, which can happen in parallel on different GPUs. The final output `Y` is simply the concatenation of `Y_1` and `Y_2`. Check [this comment](https://github.com/huggingface/transformers/issues/10321#issuecomment-783543530) for more detailed info.
|
||||
- Requirement: TP involves frequent, small communications within a forward/backward pass. It requires a very fast interconnect between GPUs (e.g., NVLink) and is typically not recommended across different nodes.
|
||||
|
||||
### Context Parallelism (CP) {#sec-cp}
|
||||
|
||||
Context Parallelism, also called [Sequence Parallelism](sequence_parallelism.qmd), addresses the memory bottleneck from long sequences. The input sequence itself is split along the sequence length dimension and distributed across GPUs.
|
||||
|
||||
- How it works: If you have a sequence of 8192 tokens and a `context_parallel_size` of 4, each GPU will only handle a chunk of 2048 tokens.
|
||||
- The Challenge: Attention is not local; every token needs to "attend to" every other token. Splitting the sequence breaks this.
|
||||
- The Solution (`ring-flash-attention`): An efficient communication protocol is used. To compute attention for its local sequence chunk, each GPU passes its Key-Value (KV) cache to its neighbor in a "ring." After `N-1` steps, every GPU has seen the KV-cache from all other GPUs, allowing it to compute the correct attention values for its chunk. This is implemented using the highly optimized `flash-attention` kernel at each step.
|
||||
|
||||
### Hybrid Sharding Data Parallel (HSDP) {#sec-hsdp}
|
||||
|
||||
HSDP is a 2D strategy that intelligently combines FSDP and DDP, typically for multi-node training.
|
||||
|
||||
- Intra-Node (within a machine): Use FSDP. This is efficient because GPUs on the same node have fast interconnects (NVLink), making the `all_gather` operations for sharded parameters fast.
|
||||
- Inter-Node (across machines): Use DDP. The gradient synchronization between nodes is less frequent than FSDP's parameter gathering, making it a better fit for the slower node-to-node network (e.g., Ethernet/Infiniband).
|
||||
- Example: With 2 nodes of 8 GPUs each (16 total), you could have `dp_shard_size=8` (FSDP within each node) and `dp_replicate_size=2` (DDP across the two nodes).
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
# FSDP config. See https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
# ...
|
||||
|
||||
# The number of GPUs to shard the model parameters across (FSDP dimension).
|
||||
dp_shard_size: 4
|
||||
|
||||
# The number of times to replicate the sharded model (DDP dimension).
|
||||
dp_replicate_size: 2
|
||||
|
||||
# Number of GPUs for Tensor Parallelism.
|
||||
tensor_parallel_size: 1 # (default is 1, no TP)
|
||||
|
||||
# Number of GPUs for Context/Sequence Parallelism.
|
||||
context_parallel_size: 1 # (default is 1, no CP)
|
||||
```
|
||||
|
||||
Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size`.
|
||||
|
||||
## Examples
|
||||
|
||||
1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
|
||||
- You want FSDP within each node and DDP across nodes.
|
||||
- Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
|
||||
|
||||
2. FSDP + TP on a single 8-GPU node:
|
||||
- You want to split the model across 4 GPUs using FSDP, and further split each layer across 2 GPUs with TP.
|
||||
- Set `dp_shard_size: 4` and `tensor_parallel_size: 2`.
|
||||
|
||||
3. FSDP + CP on a single 8-GPU node for long context:
|
||||
- You want to shard the model across all 8 GPUs and also split the sequence length across all 8 GPUs.
|
||||
- Set `dp_shard_size: 8` and `context_parallel_size: 8`. Note: this means the data parallel group and context parallel group are the same. A more common setup might be to shard across a smaller group.
|
||||
|
||||
## Support Matrix
|
||||
|
||||
This matrix describes how different parallelism methods can be combined in Axolotl.
|
||||
|
||||
| Combination | `dp_replicate_size` | `dp_shard_size` | `tp_size` | `cp_size` | Status & Notes |
|
||||
| --- | :---: | :---: |:---:|:---:|---|
|
||||
| **FSDP** (ZeRO-3) | 1 | >1 | 1 | 1 | ✅ Fully supported. Shards model across all GPUs. |
|
||||
| **HSDP** | >1 | >1 | 1 | 1 | ✅ Fully supported. FSDP intra-node, DDP inter-node. |
|
||||
| **FSDP + TP** | 1 | >1 | >1 | 1 | ✅ **2D Parallelism**. Shards the model across a `dp_shard` group, and TP-splits layers within the `tp` group. |
|
||||
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
|
||||
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
|
||||
| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
|
||||
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP/CP without FSDP is inefficient and complex. You should use FSDP instead (`dp_shard_size > 1`). |
|
||||
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
|
||||
|
||||
- `tp_size` refers to `tensor_parallel_size`
|
||||
- `cp_size` refers to `context_parallel_size`
|
||||
@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
|
||||
|
||||
```yaml
|
||||
# Set to a divisor (> 1) of the number of GPUs available
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
context_parallel_size: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
@@ -30,7 +30,7 @@ heads_k_stride: 1
|
||||
ring_attn_func:
|
||||
```
|
||||
|
||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||
- With 4 GPUs, valid values would be 2 or 4
|
||||
@@ -66,7 +66,7 @@ sequence_len: 8192
|
||||
|
||||
...
|
||||
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
|
||||
|
||||
## Effect on Batch Size
|
||||
|
||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
|
||||
|
||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
|
||||
- The number of batches processed per step decreases
|
||||
|
||||
For example:
|
||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
@@ -20,7 +20,7 @@ min_sample_len: 200_000
|
||||
sample_packing: true
|
||||
|
||||
tiled_mlp: true
|
||||
sequence_parallel_degree: 8
|
||||
context_parallel_size: 8
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ flash_optimum:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
|
||||
warmup_steps: 32
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
save_total_limit:
|
||||
|
||||
@@ -43,7 +43,7 @@ xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -47,7 +47,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -48,7 +48,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -47,7 +47,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -48,7 +48,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -47,7 +47,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -48,7 +48,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -54,7 +54,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -47,7 +47,7 @@ xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 40
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -77,7 +77,7 @@ xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.000001
|
||||
|
||||
@@ -44,7 +44,7 @@ xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 40
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -40,7 +40,7 @@ xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -41,7 +41,7 @@ xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -42,7 +42,7 @@ logging_steps: 5
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0001
|
||||
|
||||
@@ -42,7 +42,7 @@ logging_steps: 1
|
||||
flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -50,7 +50,7 @@ logging_steps: 1
|
||||
flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -43,7 +43,7 @@ logging_steps: 1
|
||||
flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -49,7 +49,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -49,7 +49,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -45,7 +45,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -48,7 +48,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -43,7 +43,7 @@ logging_steps: 5
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0001
|
||||
|
||||
@@ -41,7 +41,7 @@ logging_steps: 1
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0
|
||||
|
||||
@@ -50,7 +50,7 @@ flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_steps: 100
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -48,7 +48,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
eval_steps:
|
||||
saves_per_epoch: 4
|
||||
|
||||
@@ -49,7 +49,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: false
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -47,7 +47,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -38,7 +38,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -49,7 +49,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -75,7 +75,7 @@ xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -20,7 +20,7 @@ special_tokens:
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
|
||||
# Iterations
|
||||
num_epochs: 1
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@631d646\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -51,7 +51,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -51,7 +51,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -37,7 +37,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -61,7 +61,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,19 +1,65 @@
|
||||
# Gemma-3n
|
||||
# Finetune Gemma-3n with Axolotl
|
||||
|
||||
## Requirements
|
||||
Gemma-3n is a family of multimodal models from Google found on [HuggingFace](https://huggingface.co/collections/google/gemma-3n-685065323f5984ef315c93f4). This guide shows how to fine-tune it with Axolotl.
|
||||
|
||||
In addition to Axolotl's requirements, Gemma-3n requires
|
||||
## Getting started
|
||||
|
||||
```
|
||||
pip3 install timm
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Gemma3n is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min recommended)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
```
|
||||
|
||||
If you will load audio datasets, please also install
|
||||
2. In addition to Axolotl's requirements, Gemma-3n requires:
|
||||
|
||||
```
|
||||
pip3 install librosa
|
||||
```bash
|
||||
pip3 install timm==1.0.17
|
||||
|
||||
# for loading audio data
|
||||
pip3 install librosa==0.11.0
|
||||
```
|
||||
|
||||
## Usage
|
||||
3. Run the finetuning example:
|
||||
|
||||
See example configs and the [multimodal doc](https://docs.axolotl.ai/docs/multimodal.html).
|
||||
```bash
|
||||
# text only
|
||||
axolotl train examples/gemma3n/gemma-3n-e2b-qlora.yml
|
||||
|
||||
# text + vision
|
||||
axolotl train examples/gemma3n/gemma-3n-e2b-vision-qlora.yml
|
||||
|
||||
# text + vision + audio
|
||||
axolotl train examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml
|
||||
```
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
WARNING: The loss and grad norm will be much higher than normal. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
|
||||
|
||||
### TIPS
|
||||
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- The multimodal dataset format follows the OpenAI multi-content Messages format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Gemma 3n Blog](https://ai.google.dev/gemma/docs/gemma-3n)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
|
||||
@@ -34,8 +34,6 @@ eot_tokens:
|
||||
datasets:
|
||||
- path: Nanobit/text-vision-audio-2k-test
|
||||
type: chat_template
|
||||
data_files:
|
||||
- dataset.jsonl
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -55,7 +55,7 @@ flash_attention: true
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -49,7 +49,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -48,7 +48,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ gradient_checkpointing_kwargs:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -48,7 +48,7 @@ flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_steps: 100
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ logging_steps: 1
|
||||
flash_attention:
|
||||
sdp_attention:
|
||||
flash_optimum:
|
||||
warmup_steps: 100
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -52,7 +52,7 @@ flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_steps: 100
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -47,7 +47,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -47,7 +47,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -50,7 +50,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -48,7 +48,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -25,9 +25,12 @@ lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
relora_steps: 150
|
||||
relora_warmup_steps: 10
|
||||
relora: true
|
||||
relora_prune_ratio: 0.9
|
||||
relora_cpu_offload: false
|
||||
jagged_restart_steps: 150
|
||||
jagged_restart_warmup_steps: 10
|
||||
jagged_restart_anneal_steps: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
@@ -50,7 +53,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -58,7 +58,7 @@ logging_steps: 1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
|
||||
@@ -9,6 +9,7 @@ liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
|
||||
chat_template: llama3
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
@@ -50,7 +51,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -36,7 +36,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -67,7 +67,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -58,7 +58,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -79,7 +79,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -55,7 +55,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -15,6 +15,7 @@ lora_model_dir:
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
# Currently, we don't support dropout with our custom Triton kernels
|
||||
@@ -58,7 +59,7 @@ flash_attention: true
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -53,7 +53,7 @@ flash_attention: true
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -54,7 +54,7 @@ flash_attention: true
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -51,7 +51,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -55,7 +55,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -56,7 +56,7 @@ flash_attention: true
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -41,7 +41,7 @@ gradient_checkpointing_kwargs:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -50,7 +50,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -48,7 +48,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -47,7 +47,7 @@ logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
|
||||
@@ -66,7 +66,7 @@ gradient_checkpointing: offload
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
@@ -84,7 +84,7 @@ fsdp_config:
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
pad_token: <|finetune_right_pad|>
|
||||
eos_token: <|eot|>
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
@@ -69,7 +69,7 @@ tf32: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
@@ -88,7 +88,7 @@ fsdp_config:
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_activation_checkpointing: true
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
pad_token: <|finetune_right_pad|>
|
||||
eos_token: <|eot|>
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
@@ -76,12 +76,12 @@ gradient_checkpointing: offload
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
pad_token: <|finetune_right_pad|>
|
||||
eos_token: <|eot|>
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
@@ -65,7 +65,7 @@ tf32: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
@@ -84,7 +84,7 @@ fsdp_config:
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_activation_checkpointing: true
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
pad_token: <|finetune_right_pad|>
|
||||
eos_token: <|eot|>
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
@@ -64,7 +64,7 @@ flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
@@ -82,7 +82,7 @@ fsdp_config:
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_activation_checkpointing: true
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
pad_token: <|finetune_right_pad|>
|
||||
eos_token: <|eot|>
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
@@ -74,13 +74,13 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
logging_steps: 1
|
||||
warmup_steps: 20
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
pad_token: <|finetune_right_pad|>
|
||||
eos_token: <|eot|>
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
@@ -67,7 +67,7 @@ flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
@@ -85,7 +85,7 @@ fsdp_config:
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_activation_checkpointing: true
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
pad_token: <|finetune_right_pad|>
|
||||
eos_token: <|eot|>
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Finetune Magistral Small with Axolotl
|
||||
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on [HuggingFace](https://huggingface.co/mistralai/Magistral-Small-2506). This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506) and [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
|
||||
|
||||
@@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 recommended)
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
@@ -31,12 +31,37 @@ This config uses about 24GB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Thinking
|
||||
|
||||
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities. The model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||
|
||||
Example format:
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [{ "type": "text", "text": "..."}]},
|
||||
{"role": "assistant", "content": [{ "type": "thinking", "thinking": "..."}, { "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
Example config: `./magistral-small-think-qlora.yaml`.
|
||||
|
||||
The `thinking` section also supports an optional arg `closed: bool` (`True` default) which controls adding the closing `[/THINK]` tag.
|
||||
|
||||
Limitations:
|
||||
- You cannot mix `content: str` with `content: list[dict]` as the `dataset.load_dataset` may complain about different types for `content` key.
|
||||
- This mode does not work with custom `train_detail` and `training` at the moment.
|
||||
|
||||
### TIPS
|
||||
|
||||
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
|
||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
|
||||
@@ -6,6 +6,9 @@ tokenizer_use_mistral_common: true
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
|
||||
@@ -6,6 +6,9 @@ tokenizer_use_mistral_common: true
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
|
||||
68
examples/magistral/magistral-small-think-qlora.yaml
Normal file
68
examples/magistral/magistral-small-think-qlora.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: mistralai/Magistral-Small-2507
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: Nanobit/text-think-2k-test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -41,7 +41,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -38,7 +38,7 @@ resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -59,7 +59,7 @@ sdp_attention: true
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -59,7 +59,7 @@ flash_attention: true
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user