Compare commits
22 Commits
main
...
activeblue
| Author | SHA1 | Date | |
|---|---|---|---|
| c6da9b9e92 | |||
| c7c4885369 | |||
| 981a13e110 | |||
| 74f2263ac7 | |||
| 8693a1f61b | |||
| 71c6a56e7a | |||
| 38adf5cd37 | |||
| 3f29fa017b | |||
| c02a76f132 | |||
| b9ceebfe7e | |||
| e9a3fd483f | |||
| eadd15c960 | |||
| 396ce4a9dd | |||
|
|
b7ec06b8a1 | ||
|
|
e2f01de0e8 | ||
|
|
5352d41d32 | ||
|
|
c15f6cffe2 | ||
|
|
e4032fc90f | ||
|
|
6136ae627b | ||
|
|
e662972a29 | ||
|
|
ebbd7fa847 | ||
|
|
ac77da96da |
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
|||||||
exclude:
|
exclude:
|
||||||
- python_version: "3.14"
|
- python_version: "3.14"
|
||||||
pytorch_version: "2.9.1"
|
pytorch_version: "2.9.1"
|
||||||
timeout-minutes: 20
|
timeout-minutes: 25
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: cleanup node
|
- name: cleanup node
|
||||||
|
|||||||
@@ -29,6 +29,9 @@
|
|||||||
|
|
||||||
## 🎉 Latest Updates
|
## 🎉 Latest Updates
|
||||||
|
|
||||||
|
- 2026/04:
|
||||||
|
- New model support has been added in Axolotl for [Mistral Medium 3.5](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral-medium-3_5) and [Gemma 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma4).
|
||||||
|
- Axolotl is now [uv-first](https://github.com/axolotl-ai-cloud/axolotl/pull/3545) and has [SonicMoE fused LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3519) support.
|
||||||
- 2026/03:
|
- 2026/03:
|
||||||
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
|
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
|
||||||
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
|
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
|
||||||
|
|||||||
273
SETUP_MIAAI.md
Normal file
273
SETUP_MIAAI.md
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
# Axolotl Setup — miaai (RTX 5080, CUDA 13.2)
|
||||||
|
|
||||||
|
## System Info
|
||||||
|
- GPU: NVIDIA RTX 5080 (16GB VRAM, sm_120 / Blackwell)
|
||||||
|
- Driver: 580.126.09 — max CUDA 13.0 shown by nvidia-smi, but nvcc from conda is 13.2
|
||||||
|
- OS: Ubuntu 25.10 (Python 3.13 system — do NOT use system Python for ML)
|
||||||
|
- Axolotl repo: `/home/tocmo0nlord/axolotl` (branch: `activeblue/main`)
|
||||||
|
- Conda env: `axolotl` at `/opt/miniconda3/envs/axolotl`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Starting from Bare Ubuntu 25.10
|
||||||
|
|
||||||
|
If rebuilding from scratch, complete these steps first before anything else.
|
||||||
|
|
||||||
|
### A. System packages
|
||||||
|
```bash
|
||||||
|
sudo apt update && sudo apt upgrade -y
|
||||||
|
sudo apt install -y \
|
||||||
|
build-essential cmake git curl wget \
|
||||||
|
python3-dev libssl-dev zlib1g-dev \
|
||||||
|
ca-certificates gnupg lsb-release
|
||||||
|
```
|
||||||
|
|
||||||
|
### B. NVIDIA driver (580.xx)
|
||||||
|
Ubuntu 25.10 is too new for NVIDIA's apt repo. Install via ubuntu-drivers:
|
||||||
|
```bash
|
||||||
|
sudo ubuntu-drivers autoinstall
|
||||||
|
sudo reboot
|
||||||
|
```
|
||||||
|
|
||||||
|
After reboot, verify:
|
||||||
|
```bash
|
||||||
|
nvidia-smi
|
||||||
|
# Must show: NVIDIA GeForce RTX 5080, Driver Version: 580.x
|
||||||
|
```
|
||||||
|
|
||||||
|
If ubuntu-drivers installs the wrong version, force the right one:
|
||||||
|
```bash
|
||||||
|
sudo apt install -y nvidia-driver-580
|
||||||
|
sudo reboot
|
||||||
|
```
|
||||||
|
|
||||||
|
### C. Install Ollama
|
||||||
|
```bash
|
||||||
|
curl -fsSL https://ollama.com/install.sh | sh
|
||||||
|
|
||||||
|
# Verify it's running
|
||||||
|
systemctl status ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
### D. HuggingFace CLI
|
||||||
|
```bash
|
||||||
|
pip3 install huggingface_hub
|
||||||
|
huggingface-cli login
|
||||||
|
# Paste your HF token — required for gated models like meta-llama
|
||||||
|
```
|
||||||
|
|
||||||
|
Once steps A–D are done, continue with the One-time Setup below.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Pre-Training Checklist (every session)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Stop Ollama — if it receives a request mid-training it will compete for VRAM
|
||||||
|
sudo systemctl stop ollama
|
||||||
|
|
||||||
|
# 2. Activate conda env
|
||||||
|
export PATH="/opt/miniconda3/bin:$PATH"
|
||||||
|
conda activate axolotl
|
||||||
|
|
||||||
|
# 3. Set env vars
|
||||||
|
export CUDA_HOME=$CONDA_PREFIX
|
||||||
|
export PATH=$CUDA_HOME/bin:$PATH
|
||||||
|
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||||
|
|
||||||
|
# 4. Confirm GPU is clear (should show no processes before training)
|
||||||
|
nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv
|
||||||
|
|
||||||
|
# 5. Go to axolotl directory
|
||||||
|
cd /home/tocmo0nlord/axolotl
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run Training
|
||||||
|
```bash
|
||||||
|
axolotl train ~/human_chat_qlora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
## After Training
|
||||||
|
```bash
|
||||||
|
# Restart Ollama
|
||||||
|
sudo systemctl start ollama
|
||||||
|
|
||||||
|
# Test the adapter interactively
|
||||||
|
axolotl inference ~/human_chat_qlora.yml \
|
||||||
|
--lora-model-dir ~/outputs/llama31-8b-humanchat \
|
||||||
|
--prompter chat
|
||||||
|
|
||||||
|
# (Optional) Merge adapter into base model for standalone deployment
|
||||||
|
axolotl merge-lora ~/human_chat_qlora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## One-time Setup (fresh machine — after bare Ubuntu steps above)
|
||||||
|
|
||||||
|
### 1. Install Miniconda
|
||||||
|
```bash
|
||||||
|
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
|
||||||
|
bash miniconda.sh -b -p /opt/miniconda3
|
||||||
|
/opt/miniconda3/bin/conda init bash
|
||||||
|
source ~/.bashrc
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Create Python 3.11 environment
|
||||||
|
```bash
|
||||||
|
conda create -n axolotl python=3.11 -y
|
||||||
|
conda activate axolotl
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Clone axolotl repo
|
||||||
|
```bash
|
||||||
|
git clone https://git.activeblue.net/tocmo0nlord/axolotl.git /home/tocmo0nlord/axolotl
|
||||||
|
cd /home/tocmo0nlord/axolotl
|
||||||
|
git remote add upstream https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
git fetch upstream
|
||||||
|
git rebase upstream/main # keeps activeblue patches on top
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Install CUDA toolkit (needed to compile flash-attn and bitsandbytes)
|
||||||
|
```bash
|
||||||
|
conda install -y -c "nvidia/label/cuda-12.8.0" cuda-toolkit
|
||||||
|
export CUDA_HOME=$CONDA_PREFIX
|
||||||
|
export PATH=$CUDA_HOME/bin:$PATH
|
||||||
|
```
|
||||||
|
|
||||||
|
> NOTE: Despite installing from the cuda-12.8.0 channel, conda resolves nvcc to **13.2.78**.
|
||||||
|
> This is fine — use cu132 everywhere to match.
|
||||||
|
|
||||||
|
### 5. Install PyTorch — use cu132 (matches nvcc from conda)
|
||||||
|
```bash
|
||||||
|
# torchaudio has no cu132 wheel — skip it, not needed for LLM training
|
||||||
|
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu132
|
||||||
|
python -c "import torch; print('CUDA:', torch.version.cuda); print('GPU:', torch.cuda.get_device_name(0))"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. Install Axolotl
|
||||||
|
```bash
|
||||||
|
cd /home/tocmo0nlord/axolotl
|
||||||
|
pip install -e "."
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. Install flash-attn
|
||||||
|
> Compiles CUDA kernels from source — takes 15–25 min on 10 cores of i7-14700K.
|
||||||
|
```bash
|
||||||
|
MAX_JOBS=10 pip install flash-attn --no-build-isolation
|
||||||
|
```
|
||||||
|
|
||||||
|
### 8. Compile bitsandbytes from source for sm_120 (RTX 5080 / Blackwell)
|
||||||
|
|
||||||
|
Prebuilt wheels do not include sm_120. CUDA 13.2 also dropped sm_50–53.
|
||||||
|
Must compile from source with a patched CMakeLists.txt.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone bitsandbytes v0.49.1
|
||||||
|
git clone --branch v0.49.1 --depth 1 \
|
||||||
|
https://github.com/bitsandbytes-foundation/bitsandbytes.git /tmp/bnb_0491
|
||||||
|
|
||||||
|
# Patch CMakeLists.txt: insert sm_120 override before the foreach loop
|
||||||
|
# (cmake >= 3.23.0 uses its own built-in arch list which does not include sm_120)
|
||||||
|
sed -i '/ foreach(capability \${CMAKE_CUDA_ARCHITECTURES_ALL})/i\ # RTX 5080 sm_120 patch\n set(CMAKE_CUDA_ARCHITECTURES_ALL 120)' /tmp/bnb_0491/CMakeLists.txt
|
||||||
|
|
||||||
|
# Verify patch landed correctly — set() line must appear immediately before foreach
|
||||||
|
grep -n "ARCHITECTURES_ALL\|foreach" /tmp/bnb_0491/CMakeLists.txt | tail -5
|
||||||
|
|
||||||
|
# Configure — must point cmake at conda's nvcc explicitly
|
||||||
|
cmake \
|
||||||
|
-DCMAKE_CUDA_COMPILER=/opt/miniconda3/envs/axolotl/bin/nvcc \
|
||||||
|
-DCOMPUTE_BACKEND=cuda \
|
||||||
|
-S /tmp/bnb_0491 \
|
||||||
|
-B /tmp/bnb_0491/build 2>&1 | grep -E "(Capabilit|CUDA Ver|Error)"
|
||||||
|
# Must show: CUDA Capabilities Selected: 120
|
||||||
|
|
||||||
|
# Build (adjust -j to your CPU core count)
|
||||||
|
cmake --build /tmp/bnb_0491/build -j10
|
||||||
|
|
||||||
|
# Install into conda site-packages
|
||||||
|
cp -r /tmp/bnb_0491/bitsandbytes \
|
||||||
|
/opt/miniconda3/envs/axolotl/lib/python3.11/site-packages/
|
||||||
|
|
||||||
|
# Verify CUDA works
|
||||||
|
python3 -c "
|
||||||
|
import torch, bitsandbytes as bnb
|
||||||
|
x = torch.randn(64, 64, device='cuda')
|
||||||
|
l = bnb.nn.Linear8bitLt(64, 64).cuda()
|
||||||
|
print('bitsandbytes CUDA OK:', l(x).shape)
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 9. Copy training config to home
|
||||||
|
```bash
|
||||||
|
cp /home/tocmo0nlord/axolotl/human_chat_qlora.yml ~/human_chat_qlora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
### 10. Verify the full stack
|
||||||
|
```bash
|
||||||
|
python3 -c "
|
||||||
|
import torch, bitsandbytes as bnb, flash_attn, transformers
|
||||||
|
print('torch :', torch.__version__, '| CUDA:', torch.version.cuda)
|
||||||
|
print('bitsandbytes:', bnb.__version__)
|
||||||
|
print('flash_attn :', flash_attn.__version__)
|
||||||
|
print('transformers:', transformers.__version__)
|
||||||
|
print('GPU :', torch.cuda.get_device_name(0))
|
||||||
|
print('VRAM :', round(torch.cuda.get_device_properties(0).total_memory/1e9, 1), 'GB')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected output:
|
||||||
|
```
|
||||||
|
torch : 2.x.x | CUDA: 13.2
|
||||||
|
bitsandbytes: 0.50.0.dev0
|
||||||
|
flash_attn : 2.x.x
|
||||||
|
transformers: 5.x.x
|
||||||
|
GPU : NVIDIA GeForce RTX 5080
|
||||||
|
VRAM : 16.3 GB
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Training Config — human_chat_qlora.yml
|
||||||
|
|
||||||
|
Key settings tuned for RTX 5080 (16GB):
|
||||||
|
|
||||||
|
| Setting | Value | Notes |
|
||||||
|
|---|---|---|
|
||||||
|
| `sequence_len` | `2048` | 4096 OOMs during loss computation (logits x 128k vocab) |
|
||||||
|
| `micro_batch_size` | `1` | Effective batch = micro x grad_accum = 8 |
|
||||||
|
| `gradient_accumulation_steps` | `8` | Keeps effective batch size at 8 |
|
||||||
|
| `adapter` | `qlora` | 4-bit via bitsandbytes compiled from source |
|
||||||
|
| `attn_implementation` | `flash_attention_2` | Not the deprecated `flash_attention: true` |
|
||||||
|
| `type` (datasets) | `chat_template` | Not the deprecated `sharegpt` |
|
||||||
|
|
||||||
|
Expected training metrics (RTX 5080, ~65k samples, 2 epochs):
|
||||||
|
- VRAM: ~10–11 GB active, ~11 GB allocated
|
||||||
|
- Training duration: ~3.5 hours
|
||||||
|
- Initial eval loss: ~0.81, perplexity ~2.25
|
||||||
|
- Final loss target: ~0.55–0.60
|
||||||
|
|
||||||
|
To push VRAM to ~14GB and improve training: set `micro_batch_size: 2` and `gradient_accumulation_steps: 4`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Common Pitfalls
|
||||||
|
|
||||||
|
| Problem | Cause | Fix |
|
||||||
|
|---|---|---|
|
||||||
|
| `externally-managed-environment` | System Python 3.13 blocks pip | Use conda env, never system pip |
|
||||||
|
| `No module named torch` (flash-attn) | pip builds in isolated env | Use `--no-build-isolation` |
|
||||||
|
| `CUDA_HOME not set` | CUDA toolkit not installed | `conda install cuda-toolkit` from nvidia channel |
|
||||||
|
| `CUDA version mismatch 13.2 vs 12.8` | Conda nvcc is 13.2, torch was cu128 | Reinstall torch with `--index-url .../cu132` |
|
||||||
|
| `torchaudio` not found for cu132 | No cu132 wheel exists | Skip torchaudio — not needed |
|
||||||
|
| flash-attn compile is slow | Single-threaded by default | Set `MAX_JOBS=<cpu_count>` before pip install |
|
||||||
|
| `nvcc fatal: Unsupported gpu architecture 'compute_50'` | bitsandbytes CMakeLists.txt hardcodes sm_50; CUDA 13.2 dropped it | Patch CMakeLists.txt (see step 8 above) |
|
||||||
|
| `CUDA Capabilities Selected: 50;52;...` ignores -D flag | cmake >= 3.23 built-in arch list lacks sm_120; CMakeLists.txt overrides -D | Insert `set(CMAKE_CUDA_ARCHITECTURES_ALL 120)` before foreach loop |
|
||||||
|
| `BackendUnavailable: scikit_build_core` | pip install of bnb triggers cmake rebuild | Copy .so directly to site-packages instead |
|
||||||
|
| `torch.OutOfMemoryError` during eval | logits tensor (batch x 4096 x 128k vocab) too large | Set `sequence_len: 2048`, `micro_batch_size: 1` |
|
||||||
|
| `type: sharegpt` deprecation warning | axolotl removed sharegpt type | Use `type: chat_template` with field mappings |
|
||||||
|
| `flash_attention: true` deprecation | Old config key removed | Use `attn_implementation: flash_attention_2` |
|
||||||
|
| Capybara dataset `field_messages null` | Capybara uses input/output format, not conversations | Switch to SlimOrca or OpenHermes-2.5 |
|
||||||
|
| Ollama loads model mid-training | Ollama is enabled and receives a request | `sudo systemctl stop ollama` before training |
|
||||||
|
| Training much slower than eval speed | The fast it/s on screen is the eval loop (forward only) | Normal — training includes backward pass and optimizer (~3.5h total) |
|
||||||
|
| ubuntu-drivers installs wrong NVIDIA version | Multiple driver candidates available | Force with `apt install nvidia-driver-580` |
|
||||||
@@ -311,6 +311,7 @@ website:
|
|||||||
- docs/dataset_loading.qmd
|
- docs/dataset_loading.qmd
|
||||||
- docs/qat.qmd
|
- docs/qat.qmd
|
||||||
- docs/quantize.qmd
|
- docs/quantize.qmd
|
||||||
|
- docs/1_58bit_finetuning.qmd
|
||||||
- docs/optimizations.qmd
|
- docs/optimizations.qmd
|
||||||
|
|
||||||
- section: "Core Concepts"
|
- section: "Core Concepts"
|
||||||
|
|||||||
@@ -1,16 +1,15 @@
|
|||||||
ARG CUDA_VERSION="12.8.1"
|
ARG CUDA_VERSION="12.8.2"
|
||||||
ARG CUDNN_VERSION="8"
|
|
||||||
ARG UBUNTU_VERSION="22.04"
|
ARG UBUNTU_VERSION="22.04"
|
||||||
ARG MAX_JOBS=4
|
ARG MAX_JOBS=4
|
||||||
|
|
||||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
FROM nvidia/cuda:12.8.2-devel-ubuntu22.04 AS base-builder
|
||||||
|
|
||||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
ENV PATH="/root/miniforge3/bin:${PATH}"
|
||||||
|
|
||||||
ARG PYTHON_VERSION="3.11"
|
ARG PYTHON_VERSION="3.11"
|
||||||
ARG PYTORCH_VERSION="next"
|
ARG PYTORCH_VERSION="next"
|
||||||
ARG CUDA="128"
|
ARG CUDA="128"
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0 12.0+PTX"
|
||||||
|
|
||||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||||
@@ -18,13 +17,13 @@ ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
|||||||
RUN apt-get update \
|
RUN apt-get update \
|
||||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||||
&& wget \
|
&& wget \
|
||||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh \
|
||||||
&& mkdir /root/.conda \
|
&& mkdir /root/.conda \
|
||||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
&& bash Miniforge3-Linux-x86_64.sh -b \
|
||||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
&& rm -f Miniforge3-Linux-x86_64.sh \
|
||||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
&& /root/miniforge3/bin/conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||||
|
|
||||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
ENV PATH="/root/miniforge3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
|||||||
70
docs/1_58bit_finetuning.qmd
Normal file
70
docs/1_58bit_finetuning.qmd
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
---
|
||||||
|
title: "1.58-bit Finetuning"
|
||||||
|
back-to-top-navigation: true
|
||||||
|
toc: true
|
||||||
|
toc-expand: 2
|
||||||
|
toc-depth: 4
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
1.58-bit finetuning allows you to finetune BitNet models when their prequantized weights are provided. In theory, it will be possible to fine-tune any LLM in 1.58bit format but the performance degradation will be dramatic.
|
||||||
|
|
||||||
|
Axolotl supports 1.58-bit finetuning via the [`onebitllms`](https://github.com/tiiuae/onebitllms) library, which replaces standard linear layers with BitNet-compatible counterparts ready to use for training.
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
LoRA is not supported for BitNet models
|
||||||
|
:::
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Install the `onebitllms` package before using this feature:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install onebitllms
|
||||||
|
```
|
||||||
|
|
||||||
|
Or from source:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install git+https://github.com/tiiuae/onebitllms
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported models
|
||||||
|
|
||||||
|
For now, only `Falcon-E` series of models are supported. Make sure to use their `-prequantized` version:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tiiuae/Falcon-E-3B-Base-prequantized
|
||||||
|
tiiuae/Falcon-E-1B-Base-prequantized
|
||||||
|
```
|
||||||
|
|
||||||
|
In theory, any other model would 'work' but the performance degradation will be huge. This remains an area of exploration.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
To enable 1.58-bit finetuning, set the following in your configuration file:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: tiiuae/Falcon-E-3B-Base-prequantized # A BitNet-compatible model
|
||||||
|
|
||||||
|
use_onebitllms: true
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
For BitNet models, it is recommended to use a higher learning rate than classic models (usually in the order of magnitude of 10x).
|
||||||
|
:::
|
||||||
|
|
||||||
|
## Considerations after training
|
||||||
|
|
||||||
|
Once your model has been trained with 1.58bit fine-tuning, you can convert the trained model in ternary format using the `onebitllms` CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
onebitllms quantize_to_1bit INPUT_PATH OUTPUT_PATH
|
||||||
|
```
|
||||||
|
|
||||||
|
After that, you can use supported packages such as `llama.cpp` or Apple MLX package to run the trained model.
|
||||||
|
|
||||||
|
## Example Configuration
|
||||||
|
|
||||||
|
You can find example configurations in `examples/falcon-e` which contain one configuration for SFT and one configuration for DPO.
|
||||||
@@ -121,11 +121,11 @@ Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2
|
|||||||
|
|
||||||
| Backend | Config | head_dim limit | torch_compile | Notes |
|
| Backend | Config | head_dim limit | torch_compile | Notes |
|
||||||
|---------|--------|---------------|---------------|-------|
|
|---------|--------|---------------|---------------|-------|
|
||||||
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
|
| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported |
|
||||||
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
|
| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
|
||||||
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
|
| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback |
|
||||||
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
|
| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
|
||||||
| eager | neither set | None | ✅ | Slowest, always works |
|
| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works |
|
||||||
|
|
||||||
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
|
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
|
|||||||
| Issue | Fix |
|
| Issue | Fix |
|
||||||
|-------|-----|
|
|-------|-----|
|
||||||
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
|
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
|
||||||
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
|
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
|
||||||
| Missing chat template error | Set `chat_template: chatml` explicitly |
|
| Missing chat template error | Set `chat_template: chatml` explicitly |
|
||||||
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
|
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
|
||||||
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
|
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
|
||||||
|
|||||||
@@ -3,28 +3,71 @@ title: Attention
|
|||||||
description: Supported attention modules in Axolotl
|
description: Supported attention modules in Axolotl
|
||||||
---
|
---
|
||||||
|
|
||||||
## SDP Attention
|
Axolotl routes attention via a single config field:
|
||||||
|
|
||||||
This is the default built-in attention in PyTorch.
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
sdp_attention: true
|
attn_implementation: <backend>
|
||||||
```
|
```
|
||||||
|
|
||||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
`attn_implementation` is passed through to `transformers` verbatim (via
|
||||||
|
`model.config._attn_implementation`). Accepted values are the HF-native
|
||||||
|
backends, axolotl-registered backends, or a hub-kernel path.
|
||||||
|
|
||||||
## Flash Attention
|
## Backends
|
||||||
|
|
||||||
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
|
| `attn_implementation` | Description |
|
||||||
based on your installed packages and GPU.
|
|---|---|
|
||||||
|
| `eager` | Plain PyTorch attention. No packing support. |
|
||||||
|
| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
|
||||||
|
| `flash_attention_2` | Dao-AILab Flash Attention 2. |
|
||||||
|
| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). |
|
||||||
|
| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
|
||||||
|
| `xformers` | xFormers memory-efficient attention. |
|
||||||
|
| `sage` | SageAttention (QK int8 / PV fp16). |
|
||||||
|
| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
|
||||||
|
| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
|
||||||
|
| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
|
||||||
|
| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
|
||||||
|
| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. |
|
||||||
|
|
||||||
|
Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted** —
|
||||||
|
set the canonical name above.
|
||||||
|
|
||||||
|
### Capability flags
|
||||||
|
|
||||||
|
Axolotl derives three boolean capability flags from `attn_implementation` and
|
||||||
|
exposes them on the validated config:
|
||||||
|
|
||||||
|
- `cfg.attn_supports_packing` — backend supports varlen sample packing via
|
||||||
|
`position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`.
|
||||||
|
- `cfg.attn_uses_flash_lib` — backend needs the `flash_attn` (Dao-AILab)
|
||||||
|
monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
|
||||||
|
- `cfg.attn_needs_dtype_cast` — backend requires fp16/bf16 embeddings
|
||||||
|
(everything except `eager` and `sdpa`).
|
||||||
|
|
||||||
|
These are **computed** — they cannot be overridden from YAML.
|
||||||
|
|
||||||
|
## Per-backend notes
|
||||||
|
|
||||||
|
### SDPA
|
||||||
|
|
||||||
|
Default PyTorch attention. See
|
||||||
|
[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
flash_attention: true
|
attn_implementation: sdpa
|
||||||
```
|
```
|
||||||
|
|
||||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
### Flash Attention
|
||||||
|
|
||||||
### Flash Attention 2
|
Axolotl supports FA2, FA3, and FA4. The best available version is used
|
||||||
|
automatically based on your installed packages and GPU.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
attn_implementation: flash_attention_2 # or flash_attention_3
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Flash Attention 2
|
||||||
|
|
||||||
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
|
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
|
||||||
|
|
||||||
@@ -39,23 +82,25 @@ Alternatively, try reinstall or downgrade a version.
|
|||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
### Flash Attention 3
|
#### Flash Attention 3
|
||||||
|
|
||||||
Requirements: Hopper only and CUDA 12.8 (recommended)
|
Requirements: Hopper only and CUDA 12.8 (recommended)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||||
cd flash-attention/hopper
|
cd flash-attention/hopper
|
||||||
|
|
||||||
python setup.py install
|
python setup.py install
|
||||||
```
|
```
|
||||||
|
|
||||||
### Flash Attention 4
|
#### Flash Attention 4
|
||||||
|
|
||||||
Requirements: Hopper or Blackwell GPUs
|
Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
|
||||||
|
is true and FA4 is importable.
|
||||||
|
|
||||||
|
FA4 is still a pre-release on PyPI, so `--pre` is required:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install flash-attn-4
|
pip install --pre flash-attn-4
|
||||||
```
|
```
|
||||||
|
|
||||||
Or from source:
|
Or from source:
|
||||||
@@ -63,7 +108,6 @@ Or from source:
|
|||||||
```bash
|
```bash
|
||||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||||
cd flash-attention/flash_attn/cute
|
cd flash-attention/flash_attn/cute
|
||||||
|
|
||||||
pip install -e .
|
pip install -e .
|
||||||
|
|
||||||
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
|
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
|
||||||
@@ -86,93 +130,113 @@ and falls back to FA2/3.
|
|||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
|
|
||||||
|
|
||||||
### AMD
|
### AMD
|
||||||
|
|
||||||
Requirements: ROCm 6.0 and above.
|
Requirements: ROCm 6.0 and above. See
|
||||||
|
[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
||||||
|
|
||||||
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
### Flex Attention
|
||||||
|
|
||||||
## Flex Attention
|
|
||||||
|
|
||||||
A flexible PyTorch API for attention used in combination with `torch.compile`.
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
flex_attention: true
|
attn_implementation: flex_attention
|
||||||
|
torch_compile: true # recommended
|
||||||
# recommended
|
|
||||||
torch_compile: true
|
|
||||||
```
|
```
|
||||||
|
|
||||||
::: {.callout-note}
|
Requires torch ≥ 2.6. See [PyTorch docs](https://pytorch.org/blog/flexattention/).
|
||||||
|
|
||||||
We recommend using latest stable version of PyTorch for best performance.
|
### SageAttention
|
||||||
|
|
||||||
:::
|
Requirements: Ampere, Ada, or Hopper GPUs.
|
||||||
|
|
||||||
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
|
|
||||||
|
|
||||||
## SageAttention
|
|
||||||
|
|
||||||
Attention kernels with QK Int8 and PV FP16 accumulator.
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
sage_attention: true
|
attn_implementation: sage
|
||||||
```
|
```
|
||||||
|
|
||||||
Requirements: Ampere, Ada, or Hopper GPUs
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install sageattention==2.2.0 --no-build-isolation
|
pip install sageattention==2.2.0 --no-build-isolation
|
||||||
```
|
```
|
||||||
|
|
||||||
::: {.callout-warning}
|
::: {.callout-warning}
|
||||||
|
|
||||||
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
Only LoRA/QLoRA recommended. Full finetuning has been observed to drop loss to 0. See
|
||||||
|
[GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
|
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention).
|
||||||
|
|
||||||
::: {.callout-note}
|
### xFormers
|
||||||
|
|
||||||
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
|
|
||||||
## xFormers
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
```
|
```
|
||||||
|
|
||||||
::: {.callout-tip}
|
::: {.callout-tip}
|
||||||
|
|
||||||
We recommend using with Turing GPUs or below (such as on Colab).
|
Recommended for Turing GPUs or below (e.g. Colab T4).
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
For more details: [xFormers](https://github.com/facebookresearch/xformers)
|
### Shifted Sparse Attention
|
||||||
|
|
||||||
## Shifted Sparse Attention
|
|
||||||
|
|
||||||
::: {.callout-warning}
|
::: {.callout-warning}
|
||||||
|
|
||||||
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
|
Planned for deprecation. Prefer one of the backends above.
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
Requirements: LLaMA model architecture
|
Requirements: LLaMA model architecture. Loaded as FA2 under the hood and
|
||||||
|
patched to implement shifted-sparse attention. Does not support sample packing.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
flash_attention: true
|
attn_implementation: s2
|
||||||
s2_attention: true
|
|
||||||
```
|
```
|
||||||
|
|
||||||
::: {.callout-tip}
|
### FP8
|
||||||
|
|
||||||
No sample packing support!
|
torchao low-precision attention. Loaded as SDPA and patched post-load.
|
||||||
|
|
||||||
|
Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17,
|
||||||
|
flash-attn with FA3. KV caching must be disabled.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
attn_implementation: fp8
|
||||||
|
```
|
||||||
|
|
||||||
|
### Hub kernels
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
attn_implementation: kernels-community/flash-attn3
|
||||||
|
```
|
||||||
|
|
||||||
|
Passed through to `transformers`; axolotl does not install the kernel itself.
|
||||||
|
For recognized hub paths the capability flags are set automatically; for
|
||||||
|
arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`,
|
||||||
|
`attn_uses_flash_lib=False`).
|
||||||
|
|
||||||
|
## Migrating from legacy boolean flags
|
||||||
|
|
||||||
|
The following legacy config fields are **deprecated** and will be removed in a
|
||||||
|
future release. Each emits a `DeprecationWarning` when set and is stripped from
|
||||||
|
the validated config.
|
||||||
|
|
||||||
|
| Legacy | Canonical |
|
||||||
|
|---|---|
|
||||||
|
| `flash_attention: true` | `attn_implementation: flash_attention_2` |
|
||||||
|
| `sdp_attention: true` | `attn_implementation: sdpa` |
|
||||||
|
| `xformers_attention: true` | `attn_implementation: xformers` |
|
||||||
|
| `flex_attention: true` | `attn_implementation: flex_attention` |
|
||||||
|
| `sage_attention: true` | `attn_implementation: sage` |
|
||||||
|
| `s2_attention: true` | `attn_implementation: s2` |
|
||||||
|
| `eager_attention: true` | `attn_implementation: eager` |
|
||||||
|
|
||||||
|
Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation:
|
||||||
|
flash_attention_2` **and** `flash_attention: true`) raises — pick one.
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
|
||||||
|
Existing example configs under `examples/` still use the legacy flags. They
|
||||||
|
continue to work with a deprecation warning; they will be migrated in a
|
||||||
|
follow-up pass.
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ gradient_accumulation_steps: 4
|
|||||||
max_steps: 20
|
max_steps: 20
|
||||||
learning_rate: 5.0e-6
|
learning_rate: 5.0e-6
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
output_dir: ./outputs/ebft-quickstart
|
output_dir: ./outputs/ebft-quickstart
|
||||||
```
|
```
|
||||||
@@ -304,7 +304,7 @@ lora_alpha: 32
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flex_attention: true
|
attn_implementation: flex_attention
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: true # Required with flex_attention
|
use_reentrant: true # Required with flex_attention
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ lr_scheduler: cosine
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
|
|
||||||
bf16: true
|
bf16: true
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
84
docs/multimodal_assistant_mask.md
Normal file
84
docs/multimodal_assistant_mask.md
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
# Multimodal assistant-only loss masking
|
||||||
|
|
||||||
|
## Correct placement
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Top-level: only train_on_inputs lives here.
|
||||||
|
train_on_inputs: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: data/train.jsonl
|
||||||
|
type: chat_template
|
||||||
|
roles_to_train: # per-dataset — this is what the MM scanner reads
|
||||||
|
- assistant
|
||||||
|
train_on_eos: turn # per-dataset — same
|
||||||
|
|
||||||
|
test_datasets:
|
||||||
|
- path: data/val.jsonl
|
||||||
|
type: chat_template
|
||||||
|
split: train
|
||||||
|
roles_to_train:
|
||||||
|
- assistant
|
||||||
|
train_on_eos: turn
|
||||||
|
```
|
||||||
|
|
||||||
|
## How to verify at runtime
|
||||||
|
|
||||||
|
`build_collator` logs the resolved knobs at INFO:
|
||||||
|
|
||||||
|
```text
|
||||||
|
MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
|
||||||
|
```
|
||||||
|
|
||||||
|
If `roles_to_train` logs as `None`, the YAML knobs are not reaching the
|
||||||
|
scanner — check that they are under `datasets[0]`, not at the root.
|
||||||
|
|
||||||
|
Each verified strategy additionally logs its resolved boundary token ids at
|
||||||
|
strategy init (e.g. `<|turn>model` → `[105, 4368]`, `<turn|>` → `[106]` for
|
||||||
|
Gemma 4). If a strategy emits the "has no built-in role boundaries ... only
|
||||||
|
pad and media tokens are masked" one-shot warning instead, it is on the
|
||||||
|
fallback path — declare per-role markers in YAML via `cfg.role_boundaries`
|
||||||
|
(below) to activate masking. The strategies currently on this path are
|
||||||
|
listed in the audit table above under `fallback + warn`.
|
||||||
|
|
||||||
|
## Config-based override: `cfg.role_boundaries`
|
||||||
|
|
||||||
|
For the "unverified" strategies above, or for custom chat templates that
|
||||||
|
don't match a built-in strategy's markers, users can declare role boundaries
|
||||||
|
directly in YAML without subclassing:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
role_boundaries:
|
||||||
|
- role: assistant
|
||||||
|
start: "<|turn>model"
|
||||||
|
end: "<turn|>"
|
||||||
|
- role: user
|
||||||
|
start: "<|turn>user"
|
||||||
|
end: "<turn|>"
|
||||||
|
# Optional keys:
|
||||||
|
# include_start: false # default False
|
||||||
|
# include_end: true # default True, respects cfg.train_on_eos
|
||||||
|
# end: eos_token # sentinel: resolves to tokenizer.eos_token_id
|
||||||
|
# end: null # span runs to end of sequence
|
||||||
|
```
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
|
||||||
|
- `start` and `end` are literal strings; axolotl encodes them at strategy
|
||||||
|
init via `tokenizer.encode(..., add_special_tokens=False)` and logs the
|
||||||
|
resolved token-id sequences at INFO level.
|
||||||
|
- The special value `end: eos_token` is the portable way to express
|
||||||
|
"Pixtral-style assistant turns end at EOS" without hard-coding an id.
|
||||||
|
- `role_boundaries` is an **opt-in override**. A non-empty list **replaces**
|
||||||
|
the strategy's built-in declarations wholesale (partial overlays are
|
||||||
|
intentionally unsupported — they're hard to reason about at review time).
|
||||||
|
Leaving the field unset *or* setting it to an empty list (`[]`) both mean
|
||||||
|
"use the strategy's built-ins." Writing `role_boundaries: []` is almost
|
||||||
|
always a typo or leftover — honoring it literally would produce all-masked
|
||||||
|
labels and zero gradient, so it is treated the same as unset.
|
||||||
|
- `cfg.roles_to_train` still governs which declared roles contribute to
|
||||||
|
loss. You can declare `user` and `assistant` boundaries and set
|
||||||
|
`roles_to_train: ["assistant"]` to have the scanner correctly identify
|
||||||
|
user spans as masking boundaries without training on their content.
|
||||||
|
- Invalid specs fail loudly at strategy init (missing `role`/`start`,
|
||||||
|
unencodable markers), not silently at loss-compute time.
|
||||||
@@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pac
|
|||||||
|
|
||||||
Using an optimized attention implementation is critical for training speed.
|
Using an optimized attention implementation is critical for training speed.
|
||||||
|
|
||||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
|
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`.
|
||||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
|
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation.
|
||||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
|
- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16.
|
||||||
|
|
||||||
*Note: You should only enable one attention backend.*
|
See [Attention](attention.qmd) for the full list of backends and the canonical values.
|
||||||
|
|
||||||
### LoRA Optimizations
|
### LoRA Optimizations
|
||||||
|
|
||||||
|
|||||||
@@ -1147,8 +1147,7 @@ datasets:
|
|||||||
type: ebft_strided_structured.transform
|
type: ebft_strided_structured.transform
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
|
|
||||||
flash_attention: false
|
attn_implementation: flex_attention # Strided mode uses flex_attention
|
||||||
flex_attention: true # Strided mode uses flex_attention
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: true # Required for flex_attention
|
use_reentrant: true # Required for flex_attention
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ examples:
|
|||||||
title: Arcee AFM
|
title: Arcee AFM
|
||||||
|
|
||||||
# MistralAI
|
# MistralAI
|
||||||
|
- name: mistral-medium-3_5
|
||||||
|
title: Mistral Medium 3.5
|
||||||
- name: ministral3/think
|
- name: ministral3/think
|
||||||
title: Ministral 3 Thinking
|
title: Ministral 3 Thinking
|
||||||
- name: ministral3/vision
|
- name: ministral3/vision
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ To use sequence parallelism, you need:
|
|||||||
|
|
||||||
## Limitations
|
## Limitations
|
||||||
|
|
||||||
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
|
- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML)
|
||||||
- May have a small performance overhead due to communication between GPUs
|
- May have a small performance overhead due to communication between GPUs
|
||||||
|
|
||||||
## Example
|
## Example
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with
|
|||||||
Reduces attention memory from O(n^2) to O(n):
|
Reduces attention memory from O(n^2) to O(n):
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
```
|
```
|
||||||
|
|
||||||
### Step 6: Offload with DeepSpeed
|
### Step 6: Offload with DeepSpeed
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ tf32: true
|
|||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 2
|
evals_per_epoch: 2
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 2
|
evals_per_epoch: 2
|
||||||
|
|||||||
@@ -50,8 +50,7 @@ tf32: true
|
|||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ activation_offloading: legacy
|
|||||||
|
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ activation_offloading: legacy
|
|||||||
|
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -59,8 +59,7 @@ gradient_checkpointing: false
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
|
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
sdp_attention:
|
|
||||||
flash_optimum:
|
flash_optimum:
|
||||||
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
|
|||||||
@@ -39,8 +39,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -43,8 +43,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -73,8 +73,7 @@ early_stopping_patience: 3
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
auto_resume_from_checkpoints: true
|
auto_resume_from_checkpoints: true
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -40,8 +40,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -36,8 +36,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -37,8 +37,7 @@ bf16: auto
|
|||||||
tf32: true
|
tf32: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 5
|
logging_steps: 5
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ bf16: auto
|
|||||||
tf32: true
|
tf32: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 5
|
logging_steps: 5
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ tf32: false
|
|||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ tf32: false
|
|||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ bf16: auto
|
|||||||
tf32: true
|
tf32: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 5
|
logging_steps: 5
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ tf32: true
|
|||||||
gradient_checkpointing:
|
gradient_checkpointing:
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
flash_attn_cross_entropy: false
|
flash_attn_cross_entropy: false
|
||||||
flash_attn_rms_norm: true
|
flash_attn_rms_norm: true
|
||||||
flash_attn_fuse_mlp: true
|
flash_attn_fuse_mlp: true
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
flash_attn_cross_entropy: false
|
flash_attn_cross_entropy: false
|
||||||
flash_attn_rms_norm: true
|
flash_attn_rms_norm: true
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: false
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 0
|
evals_per_epoch: 0
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -71,8 +71,7 @@ early_stopping_patience: 3
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
auto_resume_from_checkpoints: true
|
auto_resume_from_checkpoints: true
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ load_in_4bit: true
|
|||||||
sequence_len: 1024
|
sequence_len: 1024
|
||||||
bf16: auto
|
bf16: auto
|
||||||
tf32: false
|
tf32: false
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
special_tokens:
|
special_tokens:
|
||||||
bos_token: "<|startoftext|>"
|
bos_token: "<|startoftext|>"
|
||||||
eos_token: "<|endoftext|>"
|
eos_token: "<|endoftext|>"
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 2
|
evals_per_epoch: 2
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 2
|
evals_per_epoch: 2
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ lora_model_dir:
|
|||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0
|
lora_dropout: 0
|
||||||
@@ -51,8 +50,8 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
scaling_softmax: true
|
# scaling_softmax: true # needs flex_attention
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
loss_watchdog_threshold: 5.0
|
||||||
loss_watchdog_patience: 3
|
loss_watchdog_patience: 3
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ output_dir: ./outputs/ndp-out/
|
|||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ output_dir: ./outputs/ndp-out/
|
|||||||
|
|
||||||
sequence_len: 8192
|
sequence_len: 8192
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1 # must be 1 when using context parallel
|
micro_batch_size: 1 # must be 1 when using context parallel
|
||||||
|
|||||||
@@ -65,8 +65,7 @@ early_stopping_patience:
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention:
|
attn_implementation: flash_attention_2
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ lora_dropout: 0.05
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ lora_target_linear: true
|
|||||||
|
|
||||||
# --- Hardware ---
|
# --- Hardware ---
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
@@ -47,8 +47,7 @@ lora_dropout: 0.05
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
|
attn_implementation: flex_attention
|
||||||
flex_attention: true # fused flex_attention kernel compiles itself; don't set torch_compile: true
|
|
||||||
# (full-model compile conflicts with gradient checkpointing + flex_attention)
|
# (full-model compile conflicts with gradient checkpointing + flex_attention)
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ lora_dropout: 0.05
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
@@ -48,7 +48,6 @@ lora_target_linear: true
|
|||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
torch_dtype: bfloat16
|
torch_dtype: bfloat16
|
||||||
flash_attention: false
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
torch_compile: true
|
torch_compile: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ warmup_steps: 10
|
|||||||
weight_decay: 0.01
|
weight_decay: 0.01
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flash_attention: false # strided EBFT uses flex_attention at runtime
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ lora_dropout: 0.0
|
|||||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ lora_dropout: 0.0
|
|||||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ lora_dropout: 0.0
|
|||||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
93
examples/falcon-e/falcon-e-3b-dpo.yaml
Normal file
93
examples/falcon-e/falcon-e-3b-dpo.yaml
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
base_model: axolotl-ai-co/Falcon-E-1.2-3B-Exp-prequantized
|
||||||
|
output_dir: ./output
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
|
||||||
|
use_kernels: false
|
||||||
|
use_scattermoe: false
|
||||||
|
use_sonicmoe: false
|
||||||
|
use_onebitllms: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
|
||||||
|
chat_template: tokenizer_default
|
||||||
|
|
||||||
|
rl: dpo
|
||||||
|
datasets:
|
||||||
|
- path: allenai/Dolci-Think-DPO-7B
|
||||||
|
split: train
|
||||||
|
type: chatml.ultra
|
||||||
|
|
||||||
|
dataset_prepared_path: ./axolotl_dataset_cache
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
trust_remote_code: false
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4 # This can run on 4 GPUs
|
||||||
|
|
||||||
|
# Very important to enable gradient accumulation with FSDP
|
||||||
|
# https://github.com/huggingface/transformers/issues/29425
|
||||||
|
accelerator_config:
|
||||||
|
gradient_accumulation_kwargs:
|
||||||
|
sync_each_batch: True
|
||||||
|
|
||||||
|
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 3
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
# adamw hyperparams
|
||||||
|
adam_beta1: 0.9
|
||||||
|
adam_beta2: 0.95
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
logging_steps: 1
|
||||||
|
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 15.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 128
|
||||||
|
evals_per_epoch: 0
|
||||||
|
|
||||||
|
save_steps: 500
|
||||||
|
save_strategy: steps
|
||||||
|
|
||||||
|
weight_decay: 0.01
|
||||||
|
|
||||||
|
shuffle_merged_datasets: true
|
||||||
|
experimental_skip_move_to_device: true
|
||||||
|
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_config:
|
||||||
|
offload_params: false
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
reshard_after_forward: true
|
||||||
|
activation_checkpointing: true
|
||||||
|
|
||||||
|
# Comment to disable CP
|
||||||
|
# The number of GPUs to shard the model parameters across (FSDP dimension).
|
||||||
|
dp_shard_size: 1
|
||||||
|
|
||||||
|
# The number of times to replicate the sharded model (DDP dimension).
|
||||||
|
dp_replicate_size: 1
|
||||||
|
|
||||||
|
# Number of GPUs for Tensor Parallelism.
|
||||||
|
tensor_parallel_size: 1 # (default is 1, no TP)
|
||||||
|
|
||||||
|
# Number of GPUs for Context/Sequence Parallelism.
|
||||||
|
context_parallel_size: 1 # (default is 1, no CP)
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
eos_token: <|end_of_text|>
|
||||||
|
|
||||||
|
eot_tokens:
|
||||||
|
- <|im_end|>
|
||||||
100
examples/falcon-e/falcon-e-3b-ft.yaml
Normal file
100
examples/falcon-e/falcon-e-3b-ft.yaml
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
base_model: tiiuae/Falcon-E-3B-Base-prequantized
|
||||||
|
output_dir: ./output
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
|
||||||
|
use_kernels: false
|
||||||
|
use_scattermoe: false
|
||||||
|
use_sonicmoe: false
|
||||||
|
use_onebitllms: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
|
||||||
|
chat_template: tokenizer_default
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
dataset_prepared_path: ./axolotl_dataset_cache
|
||||||
|
|
||||||
|
sequence_len: 32768
|
||||||
|
trust_remote_code: false
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4 # This can run on 4 GPUs
|
||||||
|
|
||||||
|
# Very important to enable gradient accumulation with FSDP
|
||||||
|
# https://github.com/huggingface/transformers/issues/29425
|
||||||
|
accelerator_config:
|
||||||
|
gradient_accumulation_kwargs:
|
||||||
|
sync_each_batch: True
|
||||||
|
|
||||||
|
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 3
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 5.0e-4
|
||||||
|
# adamw hyperparams
|
||||||
|
adam_beta1: 0.9
|
||||||
|
adam_beta2: 0.95
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
logging_steps: 1
|
||||||
|
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 15.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 128
|
||||||
|
evals_per_epoch: 0
|
||||||
|
|
||||||
|
save_steps: 500
|
||||||
|
save_strategy: steps
|
||||||
|
|
||||||
|
weight_decay: 0.01
|
||||||
|
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
shuffle_merged_datasets: true
|
||||||
|
experimental_skip_move_to_device: true
|
||||||
|
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_config:
|
||||||
|
offload_params: false
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
reshard_after_forward: true
|
||||||
|
activation_checkpointing: true
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|
||||||
|
# Comment to disable CP
|
||||||
|
# The number of GPUs to shard the model parameters across (FSDP dimension).
|
||||||
|
dp_shard_size: 1
|
||||||
|
|
||||||
|
# The number of times to replicate the sharded model (DDP dimension).
|
||||||
|
dp_replicate_size: 1
|
||||||
|
|
||||||
|
# Number of GPUs for Tensor Parallelism.
|
||||||
|
tensor_parallel_size: 1 # (default is 1, no TP)
|
||||||
|
|
||||||
|
# Number of GPUs for Context/Sequence Parallelism.
|
||||||
|
context_parallel_size: 1 # (default is 1, no CP)
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
eos_token: <|end_of_text|>
|
||||||
|
|
||||||
|
eot_tokens:
|
||||||
|
- <|im_end|>
|
||||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -58,8 +58,7 @@ gradient_checkpointing: true
|
|||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -55,8 +55,7 @@ gradient_checkpointing: true
|
|||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ activation_offloading: true
|
|||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
|
|
||||||
# FA2 not supported
|
# FA2 not supported
|
||||||
sdp_attention: true
|
attn_implementation: sdpa
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ activation_offloading: true
|
|||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
|
|
||||||
# FA not supported
|
# FA not supported
|
||||||
flex_attention: true
|
attn_implementation: flex_attention
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ activation_offloading: true
|
|||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
|
|
||||||
# FA not supported
|
# FA not supported
|
||||||
sdp_attention: true
|
attn_implementation: sdpa
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ gradient_checkpointing: true
|
|||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
sdp_attention: true
|
attn_implementation: sdpa
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
loss_watchdog_threshold: 5.0
|
||||||
loss_watchdog_patience: 3
|
loss_watchdog_patience: 3
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ gradient_checkpointing: true
|
|||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
sdp_attention: true
|
attn_implementation: sdpa
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 0
|
evals_per_epoch: 0
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user