Compare commits

20 commits: `attn-imple` ... `activeblue`

| Author | SHA1 | Date |
|---|---|---|
| | 8693a1f61b | |
| | 71c6a56e7a | |
| | 38adf5cd37 | |
| | 3f29fa017b | |
| | c02a76f132 | |
| | b9ceebfe7e | |
| | e9a3fd483f | |
| | eadd15c960 | |
| | 396ce4a9dd | |
| | b7ec06b8a1 | |
| | e2f01de0e8 | |
| | 5352d41d32 | |
| | c15f6cffe2 | |
| | e4032fc90f | |
| | 6136ae627b | |
| | e662972a29 | |
| | ebbd7fa847 | |
| | ac77da96da | |
| | 798c8fba89 | |
| | 17fc747f99 | |
**.github/CONTRIBUTING.md** (5 changes, vendored)
@@ -31,10 +31,11 @@ PRs are **greatly welcome**!

Please run below to setup env

```bash
-# Install axolotl + dev and test dependencies from lockfile
+# Install axolotl + dev and test dependencies
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv sync --extra flash-attn --extra deepspeed --group dev --group test
+uv venv --no-project --relocatable
+source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
 pre-commit install

# test
**.github/workflows/base.yml** (16 changes, vendored)
@@ -30,14 +30,6 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: 2.9.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-base"
-            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -168,14 +160,6 @@ jobs:
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-uv-base"
             platforms: "linux/amd64,linux/arm64"
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: 2.9.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-uv-base"
-            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
**.github/workflows/main.yml** (12 changes, vendored)
@@ -18,12 +18,6 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 128
-            cuda_version: 12.8.1
-            python_version: "3.11"
-            pytorch: 2.9.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
@@ -180,12 +174,6 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 128
-            cuda_version: 12.8.1
-            python_version: "3.11"
-            pytorch: 2.9.0
-            axolotl_extras:
-            platforms: "linux/amd64,linux/arm64"
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
**.github/workflows/tests.yml** (2 changes, vendored)
@@ -72,7 +72,7 @@ jobs:
         exclude:
           - python_version: "3.14"
             pytorch_version: "2.9.1"
-    timeout-minutes: 20
+    timeout-minutes: 25

     steps:
       - name: cleanup node
@@ -1,142 +0,0 @@

# `attn-implementation-refactor` branch review

Review target: `attn-implementation-refactor` (5 commits ahead of main, merge base `69904781`).
Scope: 16 files, +682 / −96.

## 1. What the branch is trying to do

Collapse seven boolean attention flags (`flash_attention`, `sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`, `s2_attention`, `eager_attention`) into a single `attn_implementation` field, with derived capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) for the gates that used to be ad-hoc OR-chains.

Mechanism: `normalize_attn_implementation` (a `@model_validator(mode="before")` on `AxolotlInputConfig`) maps bidirectionally between the new field and the legacy flags, with a priority list for legacy combos (`s2 + flash → s2`), and then computes the three capability flags from frozen sets in `enums.py`.

Adjacent changes: `xformers` and `sage` now register as their own entries in `ALL_ATTENTION_FUNCTIONS` (with FA2 mask behavior) instead of stomping the `flash_attention_2` slot. New `fp8` backend wires `torchao.prototype.attention.apply_low_precision_attention` in `apply_post_model_load_patches`.
## 2. Target design

**`cfg.attn_implementation` is the single source of truth on the validated config.**

- Its type is `str | None`. Accepted values are **canonical names only** — no short-form aliases:
  - HF-native: `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`. (`flash_attention_3` is net-new to axolotl — the current branch only encodes `flash_attention_2` under the short name `flash`.)
  - Axolotl-owned (registered into `ALL_ATTENTION_FUNCTIONS` under exactly these names): `xformers`, `sage`, `s2`, `fp8`.
  - Hub-kernel paths: `kernels-community/sage-attention`, `kernels-community/flash-attn3`, etc. — passthrough. Known-kernel allowlist in `enums.py` classifies the common ones into the capability tables.

  Short forms like `flash`, `fa2`, `fa3`, `sdp`, `flex` are rejected (Pydantic validation error with a pointer to the canonical name).
- `model.py:_set_attention_config` passes `cfg.attn_implementation` to HF verbatim — no `_ATTN_IMPL_TO_HF` translation dict needed.
- Legacy booleans (`flash_attention: true`, `sdp_attention: true`, …) are the **only** input aliases, kept for backwards compatibility. The normalizer maps them to the canonical `attn_implementation` value, emits a one-time `DeprecationWarning` per flag, and removes them from `data` so they're never readable on the validated `cfg`. `deprecated=True` on each Field surfaces this in JSON schema. Mapping is 1:1 with the current legacy-flag semantics (`flash_attention → flash_attention_2`, `sdp_attention → sdpa`, `flex_attention → flex_attention`, `xformers_attention → xformers`, `sage_attention → sage`, `s2_attention → s2`, `eager_attention → eager`).
- Capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) are **`@computed_field` on the model**, not settable inputs. Lookup is keyed by the canonical `attn_implementation` string.
- Unknown / user-supplied strings (custom hub kernels) pass through to HF but get **conservative capability defaults** (packing=False, flash-lib=False, dtype-cast=True). Known hub kernels axolotl can classify live in a small allowlist.
- Downstream consumers read *only* `cfg.attn_implementation` and the capability flags. No `cfg.flash_attention`, `cfg.xformers_attention`, etc. anywhere in `src/`.

This is strictly what the branch is already *trying* to do — the gaps below are places it hasn't landed that goal yet.
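To make the target concrete, here is a round-trip sketch of what validation should produce. This is illustrative of the *target* behaviour, not what the branch currently does, and the config-class import path and `base_model` value are assumptions:

```python
# Illustrative target behaviour only; import path and base_model are assumed.
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig

# Legacy input: only the deprecated boolean is set.
cfg = AxolotlInputConfig(base_model="some/model", flash_attention=True)

assert cfg.attn_implementation == "flash_attention_2"  # canonicalized at validation time
assert cfg.attn_supports_packing                       # computed from the capability table
assert cfg.attn_uses_flash_lib                         # flash_attention_2 is a flash-lib impl
# The legacy key was consumed by the normalizer and is not part of the dump:
assert "flash_attention" not in cfg.model_dump(exclude_none=True)
```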
## 3. Gaps and holes

### A. Legacy flags are still parallel state, not input-only

1. The normalizer *sets* the legacy flags on `data` (`impl_to_flag[attn_impl]` branch). It does not delete them. So `cfg.flash_attention` is still truthy after validation, and downstream code still reads it (see G).
2. Short-form enum values (`flash`, `sdpa`, `fp8`) are persisted as-is on `cfg.attn_implementation`, which is why `model.py` needs `_ATTN_IMPL_TO_HF` to translate before passing to HF. Source-of-truth implies canonicalize at normalize-time, not translate at consume-time.
3. Legacy flag + `attn_implementation` (consistent combo, e.g. `attn_implementation: flash + flash_attention: true`) emits no deprecation warning — only the legacy-only path warns.
4. Legacy Field descriptions (`xformers_attention`, `sdp_attention`, etc.) don't have `deprecated=True`, so JSON schema still advertises them as first-class.

### B. Validators that still only check the legacy flag

5. **`check_ebft_activation_offloading`** (`validation.py:1607-1619`) reads only `data.get("flex_attention")`. Users on `attn_implementation: flex_attention` bypass the incompatibility check.
6. **`check_sample_packing_without_attention`** (`validation.py:188-203`) early-returns when `attn_implementation` is set but never validates that the chosen backend actually supports packing. `attn_implementation: eager + sample_packing: true` silently passes; the old legacy-flag check warned.

### C. Non-enum strings fall through the capability tables

7. **HF-native `"flash_attention_2"`** is neither in `impl_to_flag` nor `FLASH_ATTN_LIB_IMPLS`. A user copy-pasting from HF docs gets `attn_uses_flash_lib=False`, silently disabling FA4 auto-apply, the LLaMA flash hijack, `_patch_attention` (btlm, stablelm_epoch, mistral3, llava), and `_apply_flash_attention_peft_patches`.
8. **Hub kernel strings** (`kernels-community/flash-attn3`, `kernels-community/sage-attention`) default to `attn_supports_packing=True` (silently enters multipack with varlen `position_ids` — correctness depends on the kernel honoring them) and `attn_uses_flash_lib=False` (so `context_parallel_size > 1` raises "requires flash attention" even for FA3 hub kernels).
9. **Conflict trap for hub-kernel + legacy flag** (`config.py:1414-1419`): `attn_implementation: kernels-community/flash-attn3 + flash_attention: true` always raises, because `impl_to_flag.get(custom) is None` and the loop treats `flag != None` as a conflict. A common combo in existing YAMLs breaks hard on upgrade; a minimal repro follows.
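A standalone repro of the item-9 logic. This mirrors the behaviour described above; it is not the branch's literal code, and the abridged mapping is an assumption:

```python
# Minimal repro of the hub-kernel + legacy-flag conflict trap (item 9).
LEGACY_FLAGS = ("flash_attention", "sdp_attention", "xformers_attention")
IMPL_TO_FLAG = {"flash": "flash_attention", "sdpa": "sdp_attention"}  # no hub kernels

def check_conflicts(data: dict) -> None:
    expected_flag = IMPL_TO_FLAG.get(data.get("attn_implementation"))  # None for hub kernels
    for flag in LEGACY_FLAGS:
        if data.get(flag) and flag != expected_flag:
            raise ValueError(f"{flag} conflicts with attn_implementation")

# expected_flag is None for any hub-kernel path, so the legacy flag always "conflicts":
check_conflicts({"attn_implementation": "kernels-community/flash-attn3",
                 "flash_attention": True})  # raises ValueError
```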
### D. Silent behaviour change for xformers

10. Old `_apply_flash_attention_patches` did `self.cfg.flash_attention = True` for `xformers + sample_packing`. The new version doesn't, and xformers is not in `FLASH_ATTN_LIB_IMPLS`. Consumers that keyed off `cfg.flash_attention` now see falsy for xformers, silently dropping `_patch_attention` (btlm / stablelm_epoch+packing / mistral3 / llava model-type FA patches). Some of this is arguably correct cleanup (xformers has its own HF registry entry now), but the btlm/stablelm/mistral3 regression is not called out and not tested. Decide consciously, not by omission.

### E. Capability flags are writable Pydantic fields, not computed

11. `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` are declared `bool | None = Field(default=None)` on `AxolotlInputConfig`. YAML is not rejected — a user can set `attn_uses_flash_lib: true` and override the normalizer.

### F. Validator ordering (not covered by tests)

12. `AttentionValidationMixin.check_attention_fields` (inherited, `mode="before"`) and `normalize_attn_implementation` (subclass, `mode="before"`) both run during the `model_validator` phase. Pydantic MRO may run the inherited one first. For legacy-only `s2_attention: true + flash_attention: true` (the test `test_s2_plus_flash_maps_to_s2` asserts this maps to `s2`), the inherited check may raise "multiple attention implementations set" before the normalizer runs. The test calls the classmethod directly and does not build the model, so this is unverified either way.
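Because the test never builds the model, the ordering question can be settled empirically with a toy hierarchy. This standalone Pydantic v2 probe shows *how* to check it; it deliberately does not assert which validator wins:

```python
from pydantic import BaseModel, model_validator

ORDER: list[str] = []

class AttentionMixin(BaseModel):
    @model_validator(mode="before")
    @classmethod
    def inherited_check(cls, data):
        ORDER.append("inherited")  # stands in for check_attention_fields
        return data

class InputConfig(AttentionMixin):
    @model_validator(mode="before")
    @classmethod
    def normalizer(cls, data):
        ORDER.append("subclass")  # stands in for normalize_attn_implementation
        return data

InputConfig()
print(ORDER)  # whichever runs first is the one that sees raw combos like s2 + flash
```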
### G. Remaining legacy reads in `src/`

13. `src/axolotl/integrations/lm_eval/cli.py:120` reads `cfg.flash_attention`. Works for `attn_implementation=flash` only.
14. `tests/e2e/multigpu/test_llama.py:524-526` writes `cfg.flash_attention = True` / `cfg.flex_attention = True`. Stale pattern.
15. Dual-check idioms in `config.py` (lines 1464, 1478, 1570, 1586, 1774) and `validation.py` (198, 209, 221, 850, 1586, 1611) — `data.get("x_attention") or data.get("attn_implementation") == "x"` — are redundant once legacy flags are input-only; remove them.

### H. fp8 operational risk

16. The `fp8` docstring documents hard requirements (PyTorch ≥ 2.11, SM90+, flash-attn with FA3, torchao ≥ 0.17.0) and a runtime constraint (`config.use_cache = False`). None are validated — misconfig surfaces as a torchao runtime error. `xformers` and `sage` availability/compute-capability guards exist; `fp8` should match.

### I. Test coverage gaps

17. `test_attn_implementation.py` exercises the classmethod in isolation plus the constant sets. It does **not**:
    - Build a full `AxolotlInputConfig(**data)` (so validator ordering — item 12 — is untested).
    - Verify capability flags can't be overridden from YAML (item 11).
    - Cover `check_sample_packing_without_attention` with `attn_implementation: eager` (item 6).
    - Cover `check_ebft_activation_offloading` with `attn_implementation: flex_attention` (item 5).
    - Cover hub-kernel + legacy flag combo (item 9).
    - Cover `flash_attention_2` canonicalization (item 7).
## 4. Fix plan

Four phases, each commit-sized. Phases 1–2 are local and low-risk; phase 3 is the behaviour-changing cleanup; phase 4 is tests + docs.

### Phase 1 — Lock down the data model

1. Drop the `AttnImplementation` enum. `attn_implementation` becomes `str | None`, validated against a canonical allowlist (`eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`, `xformers`, `sage`, `s2`, `fp8`) **or** a hub-kernel path (`startswith("kernels-")` / contains `/`). Reject short-form strings like `flash` / `fa2` / `sdp` / `flex` with an explicit error pointing at the canonical name.
2. Rewrite `normalize_attn_implementation` so its only job is mapping **legacy booleans → canonical `attn_implementation`** (for BC). Mapping is fixed:
   - `flash_attention → flash_attention_2`
   - `sdp_attention → sdpa`
   - `flex_attention → flex_attention`
   - `xformers_attention → xformers`
   - `sage_attention → sage`
   - `s2_attention → s2`
   - `eager_attention → eager`

   Priority for legacy combos stays as in the current branch (`s2 > sage > xformers > flex > flash > sdp > eager`). Emit a one-time `DeprecationWarning` per unique legacy flag seen. After mapping, delete the legacy flag keys from `data` so they never appear on validated `cfg`. If both a canonical `attn_implementation` and any legacy flag are set, raise (no silent precedence). A sketch follows below.

   **Merge `AttentionValidationMixin.check_attention_fields` into this normalizer and delete the mixin method.** Pydantic v2 runs inherited `mode="before"` validators before subclass ones per MRO, so leaving them as siblings causes the inherited check to reject legacy combos like `s2 + flash` before the normalizer can map them. One validator, one source of conflict detection.

   **Fix the gemma4-hybrid path**: change `data["attn_implementation"] = "flash"` to `data["attn_implementation"] = "flash_attention_2"` (the short name no longer validates after step 1).
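A minimal sketch of the target normalizer, assuming the mapping and priority above. This is the reviewer's proposal, not the branch's current code; the Pydantic validator wrapper and the one-time warning de-duplication are elided:

```python
import warnings

LEGACY_TO_CANONICAL = {  # dict order encodes the priority list: s2 > sage > ... > eager
    "s2_attention": "s2",
    "sage_attention": "sage",
    "xformers_attention": "xformers",
    "flex_attention": "flex_attention",
    "flash_attention": "flash_attention_2",
    "sdp_attention": "sdpa",
    "eager_attention": "eager",
}

def normalize_attn_implementation(data: dict) -> dict:
    """Map legacy booleans to the canonical field, then remove them (sketch)."""
    set_flags = [f for f in LEGACY_TO_CANONICAL if data.get(f)]
    if data.get("attn_implementation") and set_flags:
        raise ValueError(
            f"attn_implementation={data['attn_implementation']!r} conflicts with "
            f"legacy flags {set_flags}; set only attn_implementation"
        )
    for flag in set_flags:
        warnings.warn(
            f"`{flag}: true` is deprecated; use attn_implementation: "
            f"{LEGACY_TO_CANONICAL[flag]}",
            DeprecationWarning,
        )
    if set_flags and not data.get("attn_implementation"):
        # first hit wins because the dict iterates in priority order
        data["attn_implementation"] = LEGACY_TO_CANONICAL[set_flags[0]]
    for flag in LEGACY_TO_CANONICAL:
        data.pop(flag, None)  # input-only: never readable on the validated cfg
    return data
```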
3. Convert `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` to `@computed_field`. The three capability tables move to `enums.py` as module constants keyed by the canonical `attn_implementation` string (including `flash_attention_3` — missing from the current branch — and known hub kernels):
   - Packing-capable: `{flash_attention_2, flash_attention_3, flex_attention, xformers, sage, kernels-community/flash-attn3, kernels-community/sage-attention}`.
   - Flash-lib (axolotl's monkeypatch targets): `{flash_attention_2, flash_attention_3, s2, kernels-community/flash-attn3}`.
   - No-dtype-cast: `{eager, sdpa}`.

   Unknown strings: conservative defaults (`packing=False, flash_lib=False, dtype_cast=True`).
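Sketch of the computed-field shape, with the tables above as `enums.py`-style constants. The model below is a stand-in, not the real `AxolotlInputConfig`:

```python
from pydantic import BaseModel, computed_field

PACKING_CAPABLE = frozenset({
    "flash_attention_2", "flash_attention_3", "flex_attention", "xformers", "sage",
    "kernels-community/flash-attn3", "kernels-community/sage-attention",
})
FLASH_LIB_IMPLS = frozenset({
    "flash_attention_2", "flash_attention_3", "s2", "kernels-community/flash-attn3",
})
NO_DTYPE_CAST = frozenset({"eager", "sdpa"})

class AttnCapabilities(BaseModel):  # stand-in for AxolotlInputConfig
    attn_implementation: str | None = None

    @computed_field  # read-only: YAML cannot set or override it
    @property
    def attn_supports_packing(self) -> bool:
        return self.attn_implementation in PACKING_CAPABLE

    @computed_field
    @property
    def attn_uses_flash_lib(self) -> bool:
        return self.attn_implementation in FLASH_LIB_IMPLS

    @computed_field
    @property
    def attn_needs_dtype_cast(self) -> bool:
        # Conservative default for unknown / custom hub-kernel strings.
        return self.attn_implementation not in NO_DTYPE_CAST
```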
4. Delete `_ATTN_IMPL_TO_HF` from `model.py` and pass `cfg.attn_implementation` straight through. The gemma4-hybrid branch continues to override to `flash_attention_2` before passing to HF.
5. `deprecated=True` on each legacy boolean Field so JSON schema + Pydantic surface the deprecation.

### Phase 2 — Fix the validators

6. `check_sample_packing_without_attention`: drop the early-return and gate on `attn_supports_packing` (sketched below). Warn (or raise — pick one and be consistent) if packing is enabled with a non-packing backend.
7. `check_ebft_activation_offloading`: replace `data.get("flex_attention")` with `attn_implementation == "flex_attention"`.
8. Sweep the dual-checks (item 15): remove every `data.get("x_attention") or data.get("attn_implementation") == "x"` idiom — after phase 1 the legacy side is always `None`. Reduces ~10 lines of noise and eliminates the "which side wins" class of bugs.
9. fp8 preflight (item 16): require `env_capabilities.compute_capability ≥ sm_90`, `torch_version ≥ 2.11`, and `torchao_version ≥ 0.17`. Warn if `use_cache` isn't explicitly `False` (also sketched below).
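Sketches for items 6 and 9. The validator names match the review; the helper signatures, the `compute_capability` tuple shape, and the version-string parsing are assumptions:

```python
import logging

LOG = logging.getLogger(__name__)

def check_sample_packing_supported(data: dict, attn_supports_packing: bool) -> None:
    """Item 6: gate packing on the capability flag instead of early-returning."""
    if data.get("sample_packing") and not attn_supports_packing:
        raise ValueError(
            f"sample_packing is not supported with attn_implementation="
            f"{data.get('attn_implementation')!r}"
        )

def check_fp8_preflight(data: dict, compute_capability: tuple[int, int],
                        torch_version: str, torchao_version: str) -> None:
    """Item 9: fail fast instead of surfacing a torchao runtime error."""
    if data.get("attn_implementation") != "fp8":
        return
    if compute_capability < (9, 0):
        raise ValueError("fp8 attention requires SM90+ (Hopper or newer)")
    if tuple(int(x) for x in torch_version.split(".")[:2]) < (2, 11):
        raise ValueError("fp8 attention requires PyTorch >= 2.11")
    if tuple(int(x) for x in torchao_version.split(".")[:2]) < (0, 17):
        raise ValueError("fp8 attention requires torchao >= 0.17.0")
    if data.get("use_cache") is not False:
        LOG.warning("fp8 attention expects use_cache: false; torchao will disable it")
```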
### Phase 3 — Migrate remaining consumers

10. `lm_eval/cli.py:120` → `flash_attention=cfg.attn_uses_flash_lib`.
11. `lm_eval/__init__.py:26` currently reads `(cfg.attn_implementation == "flash")` — after canonicalization `"flash"` is never stored, so this evaluates `False` for every backend. Change to `cfg.attn_uses_flash_lib`.
12. `validation.py:1137-1142` (NPU check) currently iterates `["flash_attention", "sdp_attention", "s2_attention"]` as string keys. Replace with `cfg.attn_implementation in {"flash_attention_2", "flash_attention_3", "sdpa", "s2"}` or the equivalent canonical-string set.
13. `tests/e2e/multigpu/test_llama.py:524-526` → `cfg.attn_implementation = "flash_attention_2"` / `"flex_attention"`.
14. **Xformers decision** (item 10): the old `cfg.flash_attention = True` side-effect activated `_patch_attention` for btlm/stablelm_epoch+packing/mistral3/llava. Two choices:
    - Add `xformers` to the set that gates `_patch_attention` (restore old behaviour, keeps patches live).
    - Document that those patches don't apply to xformers post-refactor and drop the paths if they're dead.

    Pick one explicitly and leave a commit note. Do not leave it as silent breakage.
15. Add a repo-level check (`tests/test_no_legacy_attn_reads.py` or a ruff/grep pre-commit) that fails if anything outside `config.py`'s normalizer reads `cfg.flash_attention` / `cfg.sdp_attention` / etc. Keeps the invariant from rotting. A sketch follows.
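A sketch of that guard as a plain pytest. The `src/axolotl` layout is real; the allowlist and regex are starting points, not a final policy:

```python
import pathlib
import re

LEGACY_READS = re.compile(r"cfg\.(flash|sdp|xformers|flex|sage|s2|eager)_attention\b")
ALLOWED = {"config.py"}  # the normalizer is the only sanctioned reader

def test_no_legacy_attn_reads():
    src = pathlib.Path("src/axolotl")
    offenders = [
        f"{path}:{lineno}"
        for path in src.rglob("*.py")
        if path.name not in ALLOWED
        for lineno, line in enumerate(path.read_text().splitlines(), start=1)
        if LEGACY_READS.search(line)
    ]
    assert not offenders, f"legacy attention flag reads found: {offenders}"
```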
### Phase 4 — Tests + docs

16. Rewrite `test_attn_implementation.py` to build full `AxolotlInputConfig(**data)`, not just the classmethod. Covers validator ordering and the Pydantic-field-override issue.
17. Add one test per gap closed above (one is sketched after this list): `attn_implementation: eager + sample_packing`; `attn_implementation: flex_attention + activation_offloading`; short-form `flash` rejected; `flash_attention_2` passthrough; `kernels-community/flash-attn3` capability lookup; `attn_uses_flash_lib: true` in YAML rejected; legacy boolean emits `DeprecationWarning` and is absent from validated `cfg`; fp8 preflight failures.
18. Update `docs/attention.qmd` for the single `attn_implementation` field + the deprecation table for legacy flags. One-paragraph migration note in the changelog.
19. `examples/` contains ~170 YAML files using legacy flags (`flash_attention: true` etc.). They still validate post-refactor (the normalizer maps them with a deprecation warning), but a follow-up sweep to convert them to `attn_implementation: flash_attention_2` is worth scheduling — call this out in the migration note so users know examples will be migrated on a later pass.
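One of those cases sketched as a pytest, covering the deprecation path end to end (import path and required fields assumed, as above):

```python
import pytest

from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig  # path assumed

def test_legacy_flag_warns_and_is_consumed():
    with pytest.warns(DeprecationWarning, match="flash_attention"):
        cfg = AxolotlInputConfig(base_model="some/model", flash_attention=True)
    assert cfg.attn_implementation == "flash_attention_2"
    assert "flash_attention" not in cfg.model_dump(exclude_none=True)
```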
## 5. Ordering & risk

- Phase 1 is the keystone: it's the largest diff but each step is mechanical once the alias map is in place. No behaviour change for any consumer that was using `attn_implementation` correctly; behaviour change only for consumers that were reading legacy flags (phase 3 step 14 is the explicit decision point).
- Phase 2 is independent of phase 1 and can land first as a quick safety net.
- Phase 3 step 14 is the only judgment call — flag for review before choosing.
- Total: ~10-13 commits beyond what's on the branch, each scoped and individually revertible.
@@ -29,6 +29,9 @@

 ## 🎉 Latest Updates

+- 2026/04:
+  - New model support has been added in Axolotl for [Mistral Medium 3.5](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral-medium-3_5) and [Gemma 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma4).
+  - Axolotl is now [uv-first](https://github.com/axolotl-ai-cloud/axolotl/pull/3545) and has [SonicMoE fused LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3519) support.
 - 2026/03:
   - New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
   - [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
**SETUP_MIAAI.md** (new file, 83 lines)
@@ -0,0 +1,83 @@
# Axolotl Setup — miaai (RTX 5080, CUDA 13.2)

## System Info
- GPU: NVIDIA RTX 5080 (16GB VRAM)
- Driver: 580.126.09 — max CUDA 13.0 (nvcc from conda resolves to 13.2)
- OS: Ubuntu (Python 3.13 system — do NOT use system Python for ML)
- Axolotl branch: `activeblue/main`

## One-time Setup

### 1. Install Miniconda
```bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
bash miniconda.sh -b -p /opt/miniconda3
/opt/miniconda3/bin/conda init bash
source ~/.bashrc
```

### 2. Create Python 3.11 environment
```bash
conda create -n axolotl python=3.11 -y
conda activate axolotl
```

### 3. Clone and sync repo with upstream
```bash
git clone https://git.activeblue.net/tocmo0nlord/axolotl.git
cd axolotl
git remote add upstream https://github.com/axolotl-ai-cloud/axolotl.git
git fetch upstream
git rebase upstream/main # keeps activeblue patches on top
git push origin activeblue/main --force-with-lease
```

### 4. Install CUDA toolkit (needed to compile flash-attn)
```bash
conda install -y -c "nvidia/label/cuda-12.8.0" cuda-toolkit
export CUDA_HOME=$CONDA_PREFIX
export PATH=$CUDA_HOME/bin:$PATH
```

### 5. Install PyTorch — use cu132 (matches nvcc from conda)
> NOTE: torchaudio has no cu132 wheel — skip it, not needed for LLM training
```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu132
python -c "import torch; print('CUDA:', torch.version.cuda); print('GPU:', torch.cuda.get_device_name(0))"
```

### 6. Install Axolotl
```bash
pip install -e "."
```

> **flash-attn compiles CUDA kernels from source — takes 15–25 min on 10 cores of an i7-14700K.**
> Always set `MAX_JOBS` to the number of available CPU cores to parallelize and speed up compilation:
```bash
MAX_JOBS=10 pip install flash-attn --no-build-isolation
```

## Every Session (after first-time setup)
```bash
export PATH="/opt/miniconda3/bin:$PATH"
conda activate axolotl
export CUDA_HOME=$CONDA_PREFIX
export PATH=$CUDA_HOME/bin:$PATH
cd /home/tocmo0nlord/axolotl
```

## Run Training
```bash
axolotl train human_chat_qlora.yml
```

## Common Pitfalls Encountered
| Problem | Cause | Fix |
|---|---|---|
| `externally-managed-environment` | System Python 3.13 blocks pip | Use conda env, never system pip |
| `No module named torch` (flash-attn) | pip builds in isolated env | Use `--no-build-isolation` |
| `CUDA_HOME not set` | CUDA toolkit not installed | `conda install cuda-toolkit` from nvidia channel |
| `CUDA version mismatch 13.2 vs 12.8` | Conda nvcc is 13.2, torch was cu128 | Reinstall torch with `--index-url .../cu132` |
| `torchaudio` not found for cu132 | No cu132 wheel exists | Skip torchaudio — not needed |
| `src refspec main does not match` | Fork default branch is `activeblue/main` | `git push origin activeblue/main` |
| flash-attn compile is slow | Single-threaded by default | Set `MAX_JOBS=<cpu_count>` before pip install |
@@ -311,6 +311,7 @@ website:
       - docs/dataset_loading.qmd
       - docs/qat.qmd
       - docs/quantize.qmd
+      - docs/1_58bit_finetuning.qmd
       - docs/optimizations.qmd

   - section: "Core Concepts"
@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
 RUN pip uninstall -y causal_conv1d
 RUN if [ "$TARGETARCH" = "arm64" ]; then \
-        BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="optimizers,ray"; \
     else \
-        BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="deepspeed,optimizers,ray"; \
     fi && \
     if [ "$AXOLOTL_EXTRAS" != "" ]; then \
         pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
@@ -58,19 +58,3 @@ RUN git lfs install --skip-repo && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10 && \
     pip3 cache purge
-
-# Map Python version (e.g., 3.12 -> cp312)
-RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
-    # Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
-    TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
-    # Map architecture
-    case "$TARGETARCH" in \
-        amd64) ARCH_TAG="x86_64" ;; \
-        arm64) ARCH_TAG="aarch64" ;; \
-        *) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
-    esac && \
-    WHL_VERSION="v0.7.16" && \
-    WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
-    wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
-    pip3 install --no-cache-dir "${WHL_FILE}" && \
-    rm "${WHL_FILE}"
@@ -1,16 +1,15 @@
-ARG CUDA_VERSION="12.8.1"
-ARG CUDNN_VERSION="8"
+ARG CUDA_VERSION="12.8.2"
 ARG UBUNTU_VERSION="22.04"
 ARG MAX_JOBS=4

-FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
+FROM nvidia/cuda:12.8.2-devel-ubuntu22.04 AS base-builder

-ENV PATH="/root/miniconda3/bin:${PATH}"
+ENV PATH="/root/miniforge3/bin:${PATH}"

 ARG PYTHON_VERSION="3.11"
 ARG PYTORCH_VERSION="next"
 ARG CUDA="128"
-ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
+ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0 12.0+PTX"

 ENV PYTHON_VERSION=$PYTHON_VERSION
 ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

@@ -18,13 +17,13 @@ ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

 RUN apt-get update \
     && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
     && wget \
-        https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
+        https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh \
     && mkdir /root/.conda \
-    && bash Miniconda3-latest-Linux-x86_64.sh -b \
-    && rm -f Miniconda3-latest-Linux-x86_64.sh \
-    && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
+    && bash Miniforge3-Linux-x86_64.sh -b \
+    && rm -f Miniforge3-Linux-x86_64.sh \
+    && /root/miniforge3/bin/conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

-ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
+ENV PATH="/root/miniforge3/envs/py${PYTHON_VERSION}/bin:${PATH}"

 WORKDIR /workspace
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \

 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,mamba-ssm] $AXOLOTL_ARGS; \
     fi

 # So we can test the Docker image
@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
 RUN uv pip uninstall causal_conv1d
 RUN if [ "$TARGETARCH" = "arm64" ]; then \
-        BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="optimizers,ray"; \
     else \
-        BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="deepspeed,optimizers,ray"; \
     fi && \
     if [ "$AXOLOTL_EXTRAS" != "" ]; then \
         uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
@@ -38,20 +38,3 @@ RUN uv pip install packaging setuptools wheel psutil \
 RUN if [ "$TARGETARCH" = "amd64" ]; then \
     MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
 fi
-
-# Map Python version (e.g., 3.12 -> cp312)
-RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
-    # Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
-    TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
-    LINUX_TAG="manylinux_" && \
-    # Map architecture
-    case "$TARGETARCH" in \
-        amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
-        arm64) ARCH_TAG="2_34_aarch64" ;; \
-        *) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
-    esac && \
-    WHL_VERSION="v0.7.16" && \
-    WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
-    wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
-    uv pip install --no-cache-dir "${WHL_FILE}" && \
-    rm "${WHL_FILE}"
**docs/1_58bit_finetuning.qmd** (new file, 70 lines)
@@ -0,0 +1,70 @@
---
title: "1.58-bit Finetuning"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---

## Overview

1.58-bit finetuning allows you to finetune BitNet models when their prequantized weights are provided. In theory, it is possible to fine-tune any LLM in the 1.58-bit format, but the performance degradation will be dramatic.

Axolotl supports 1.58-bit finetuning via the [`onebitllms`](https://github.com/tiiuae/onebitllms) library, which replaces standard linear layers with BitNet-compatible counterparts ready for training.

::: {.callout-note}
LoRA is not supported for BitNet models.
:::

## Installation

Install the `onebitllms` package before using this feature:

```bash
uv pip install onebitllms
```

Or from source:

```bash
uv pip install git+https://github.com/tiiuae/onebitllms
```

## Supported models

For now, only the `Falcon-E` series of models is supported. Make sure to use their `-prequantized` versions:

```bash
tiiuae/Falcon-E-3B-Base-prequantized
tiiuae/Falcon-E-1B-Base-prequantized
```

In theory, any other model would 'work', but the performance degradation would be huge. This remains an area of exploration.

## Configuration

To enable 1.58-bit finetuning, set the following in your configuration file:

```yaml
base_model: tiiuae/Falcon-E-3B-Base-prequantized # A BitNet-compatible model

use_onebitllms: true
```

::: {.callout-note}
For BitNet models, it is recommended to use a higher learning rate than for classic models (usually on the order of 10x).
:::

## Considerations after training

Once your model has been trained with 1.58-bit fine-tuning, you can convert it to ternary format using the `onebitllms` CLI:

```bash
onebitllms quantize_to_1bit INPUT_PATH OUTPUT_PATH
```

After that, you can use supported packages such as `llama.cpp` or the Apple MLX package to run the trained model.

## Example Configuration

You can find example configurations in `examples/falcon-e`, which contain one configuration for SFT and one for DPO.
@@ -97,8 +97,10 @@ python setup.py install

 Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
 is true and FA4 is importable.

+FA4 is still a pre-release on PyPI, so `--pre` is required:
+
 ```bash
-pip install flash-attn-4
+pip install --pre flash-attn-4
 ```

 Or from source:
@@ -77,8 +77,9 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us

 ```bash
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv sync --extra flash-attn --extra deepspeed --group dev --group test
+uv venv --no-project --relocatable
+source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
 ```

 #### Remote Hosts
@@ -218,8 +219,9 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --

 You will now be in the container. Next, install Axolotl with dev dependencies:

 ```bash
-uv sync --extra flash-attn --extra deepspeed --group dev --group test
+uv venv --no-project --relocatable
+source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
 ```

 ### Attach To Container
@@ -10,13 +10,16 @@ This section describes the different Docker images that are released by AxolotlA
 [Docker Hub](https://hub.docker.com/u/axolotlai).

 ::: {.callout-important}
 For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
 :::

+### Switch to the `-uv` images
+
 ::: {.callout-tip}
-Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
-a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
-(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
+Each image below ships a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with a relocatable venv
+(`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
+(e.g. `axolotlai/axolotl-uv`, `axolotlai/axolotl-base-uv`, `axolotlai/axolotl-cloud-uv`). Tags follow the
+same format as their non-uv counterparts.
+
+**We recommend switching to the `-uv` images early.** In the near future we will publish the uv-based
+build to the non-uv tags as well. The non-uv names will continue to work, but they will start serving
+the uv image.
 :::

 ## Base
@@ -85,7 +88,7 @@ Tags examples:

 - `main-py3.12-cu130-2.10.0`
 - `main-latest`
 - `main-20260315-py3.11-cu128-2.9.1`
-- `0.12.0`
+- `0.16.1`

 ## Cloud
@@ -57,7 +57,7 @@ description: Frequently asked questions

 **Q: vLLM is not working with Axolotl**

-> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
+> A: We currently recommend torch 2.10 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.12-cu128-2.10.0` tag (note: torch 2.10 images are built with Python 3.12).

 **Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir

 - NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
 - Python ≥3.11
-- PyTorch ≥2.9.0
+- PyTorch ≥2.9.1

 ## Installation {#sec-installation}
@@ -36,9 +36,9 @@ source $HOME/.local/bin/env

 Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:

 ```{.bash}
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv venv --no-project --relocatable
+uv venv
 source .venv/bin/activate
-uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
+uv pip install --no-build-isolation axolotl[deepspeed]
 ```

 ### Edge/Development Build {#sec-edge-build}
@@ -49,12 +49,11 @@ For the latest features between releases:

 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv sync --extra flash-attn --extra deepspeed
+uv venv
+source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]'
 ```

-`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
-
 ### Docker {#sec-docker}

 ```{.bash}
@@ -132,11 +131,11 @@ source $HOME/.local/bin/env

 # Create a fresh venv (recommended for a clean start)
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv venv --no-project --relocatable
+uv venv
 source .venv/bin/activate

 # Reinstall axolotl
-uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
+uv pip install --no-build-isolation axolotl[deepspeed]
 ```
## Using pip (Alternative) {#sec-pip}

@@ -151,13 +150,13 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p

 ```{.bash}
 pip3 install -U packaging setuptools wheel ninja
-pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
+pip3 install --no-build-isolation axolotl[deepspeed]
 ```

 For editable/development installs:

 ```{.bash}
 pip3 install -U packaging setuptools wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
+pip3 install --no-build-isolation -e '.[deepspeed]'
 ```

 ## Troubleshooting {#sec-troubleshooting}
**docs/multimodal_assistant_mask.md** (new file, 84 lines)
@@ -0,0 +1,84 @@
# Multimodal assistant-only loss masking

## Correct placement

```yaml
# Top-level: only train_on_inputs lives here.
train_on_inputs: false

datasets:
  - path: data/train.jsonl
    type: chat_template
    roles_to_train: # per-dataset — this is what the MM scanner reads
      - assistant
    train_on_eos: turn # per-dataset — same

test_datasets:
  - path: data/val.jsonl
    type: chat_template
    split: train
    roles_to_train:
      - assistant
    train_on_eos: turn
```

## How to verify at runtime

`build_collator` logs the resolved knobs at INFO:

```text
MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
```

If `roles_to_train` logs as `None`, the YAML knobs are not reaching the scanner — check that they are under `datasets[0]`, not at the root.

Each verified strategy additionally logs its resolved boundary token ids at strategy init (e.g. `<|turn>model` → `[105, 4368]`, `<turn|>` → `[106]` for Gemma 4). If a strategy emits the "has no built-in role boundaries ... only pad and media tokens are masked" one-shot warning instead, it is on the fallback path — declare per-role markers in YAML via `cfg.role_boundaries` (below) to activate masking. The strategies currently on this path are listed in the audit table above under `fallback + warn`.

## Config-based override: `cfg.role_boundaries`

For the "unverified" strategies above, or for custom chat templates that don't match a built-in strategy's markers, users can declare role boundaries directly in YAML without subclassing:

```yaml
role_boundaries:
  - role: assistant
    start: "<|turn>model"
    end: "<turn|>"
  - role: user
    start: "<|turn>user"
    end: "<turn|>"
  # Optional keys:
  # include_start: false # default False
  # include_end: true # default True, respects cfg.train_on_eos
  # end: eos_token # sentinel: resolves to tokenizer.eos_token_id
  # end: null # span runs to end of sequence
```

Semantics:
- `start` and `end` are literal strings; axolotl encodes them at strategy init via `tokenizer.encode(..., add_special_tokens=False)` and logs the resolved token-id sequences at INFO level (a sketch of this step follows the list).
- The special value `end: eos_token` is the portable way to express "Pixtral-style assistant turns end at EOS" without hard-coding an id.
- `role_boundaries` is an **opt-in override**. A non-empty list **replaces** the strategy's built-in declarations wholesale (partial overlays are intentionally unsupported — they're hard to reason about at review time). Leaving the field unset *or* setting it to an empty list (`[]`) both mean "use the strategy's built-ins." Writing `role_boundaries: []` is almost always a typo or leftover — honoring it literally would produce all-masked labels and zero gradient, so it is treated the same as unset.
- `cfg.roles_to_train` still governs which declared roles contribute to loss. You can declare `user` and `assistant` boundaries and set `roles_to_train: ["assistant"]` to have the scanner correctly identify user spans as masking boundaries without training on their content.
- Invalid specs fail loudly at strategy init (missing `role`/`start`, unencodable markers), not silently at loss-compute time.
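A sketch of the resolution step from the first bullet, assuming the spec shape of the YAML above (the strategy class this would live on is elided, and the helper name is hypothetical):

```python
def resolve_role_boundaries(tokenizer, specs: list[dict]) -> dict[str, dict]:
    """Encode literal start/end markers to token-id sequences at strategy init (sketch)."""
    resolved = {}
    for spec in specs:
        if "role" not in spec or "start" not in spec:
            raise ValueError(f"role_boundaries entry missing role/start: {spec!r}")
        end = spec.get("end")
        if end == "eos_token":          # sentinel: portable EOS reference
            end_ids = [tokenizer.eos_token_id]
        elif end is None:               # span runs to end of sequence
            end_ids = None
        else:
            end_ids = tokenizer.encode(end, add_special_tokens=False)
        start_ids = tokenizer.encode(spec["start"], add_special_tokens=False)
        if not start_ids:
            raise ValueError(f"unencodable start marker: {spec['start']!r}")
        resolved[spec["role"]] = {"start": start_ids, "end": end_ids}
    return resolved

# e.g. resolve_role_boundaries(tokenizer, cfg["role_boundaries"])
```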
@@ -20,6 +20,8 @@ examples:
     title: Arcee AFM

 # MistralAI
+  - name: mistral-medium-3_5
+    title: Mistral Medium 3.5
   - name: ministral3/think
     title: Ministral 3 Thinking
   - name: ministral3/vision
@@ -15,7 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
 Here is an example of how to install from pip:
 ```bash
 # Ensure you have a compatible version of Pytorch installed
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. Run one of the finetuning examples below.
@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations

 Here is an example of how to install from main for pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl

-uv pip install --no-build-isolation -e '.[flash-attn]'
+uv pip install --no-build-isolation -e '.'

 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
 python scripts/cutcrossentropy_install.py | sh
@@ -13,11 +13,11 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A

 Here is an example of how to install from main for pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl

-uv pip install --no-build-isolation -e '.[flash-attn]'
+uv pip install --no-build-isolation -e '.'

 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
 python scripts/cutcrossentropy_install.py | sh
@@ -36,12 +36,7 @@
    "id": "msOCO4NRmRLa"
   },
   "outputs": [],
-  "source": [
-   "%%capture\n",
-   "# This step can take ~5-10 minutes to install dependencies\n",
-   "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-   "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
-  ]
+  "source": "%%capture\n# This step can take ~5-10 minutes to install dependencies\n!pip install --no-build-isolation \"axolotl>=0.16.1\"\n!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
  },
  {
   "cell_type": "markdown",
@@ -15,8 +15,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r

 Here is an example of how to install from pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
@@ -26,7 +26,6 @@ lora_model_dir:

 sequence_len: 2048
 sample_packing: true
-
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0

@@ -52,7 +51,7 @@ gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
 attn_implementation: flash_attention_2
-scaling_softmax: true
+# scaling_softmax: true # needs flex_attention

 loss_watchdog_threshold: 5.0
 loss_watchdog_patience: 3
**examples/falcon-e/falcon-e-3b-dpo.yaml** (new file, 93 lines)
@@ -0,0 +1,93 @@
base_model: axolotl-ai-co/Falcon-E-1.2-3B-Exp-prequantized
output_dir: ./output

plugins:
  - axolotl.integrations.kernels.KernelsPlugin

use_kernels: false
use_scattermoe: false
use_sonicmoe: false
use_onebitllms: true

load_in_8bit: false
load_in_4bit: false

chat_template: tokenizer_default

rl: dpo
datasets:
  - path: allenai/Dolci-Think-DPO-7B
    split: train
    type: chatml.ultra

dataset_prepared_path: ./axolotl_dataset_cache

sequence_len: 8192
trust_remote_code: false

gradient_accumulation_steps: 4 # This can run on 4 GPUs

# Very important to enable gradient accumulation with FSDP
# https://github.com/huggingface/transformers/issues/29425
accelerator_config:
  gradient_accumulation_kwargs:
    sync_each_batch: True

micro_batch_size: 1
num_epochs: 3
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 1.0e-5
# adamw hyperparams
adam_beta1: 0.9
adam_beta2: 0.95

bf16: true
tf32: false

logging_steps: 1

flash_attention: true

loss_watchdog_threshold: 15.0
loss_watchdog_patience: 3

warmup_steps: 128
evals_per_epoch: 0

save_steps: 500
save_strategy: steps

weight_decay: 0.01

shuffle_merged_datasets: true
experimental_skip_move_to_device: true

fsdp_version: 2
fsdp_config:
  offload_params: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: LlamaDecoderLayer
  state_dict_type: FULL_STATE_DICT
  reshard_after_forward: true
  activation_checkpointing: true

# Comment to disable CP
# The number of GPUs to shard the model parameters across (FSDP dimension).
dp_shard_size: 1

# The number of times to replicate the sharded model (DDP dimension).
dp_replicate_size: 1

# Number of GPUs for Tensor Parallelism.
tensor_parallel_size: 1 # (default is 1, no TP)

# Number of GPUs for Context/Sequence Parallelism.
context_parallel_size: 1 # (default is 1, no CP)

special_tokens:
  eos_token: <|end_of_text|>

eot_tokens:
  - <|im_end|>
**examples/falcon-e/falcon-e-3b-ft.yaml** (new file, 100 lines)
@@ -0,0 +1,100 @@
base_model: tiiuae/Falcon-E-3B-Base-prequantized
output_dir: ./output

plugins:
  - axolotl.integrations.kernels.KernelsPlugin

use_kernels: false
use_scattermoe: false
use_sonicmoe: false
use_onebitllms: true

load_in_8bit: false
load_in_4bit: false

chat_template: tokenizer_default

datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

dataset_prepared_path: ./axolotl_dataset_cache

sequence_len: 32768
trust_remote_code: false

gradient_accumulation_steps: 4 # This can run on 4 GPUs

# Very important to enable gradient accumulation with FSDP
# https://github.com/huggingface/transformers/issues/29425
accelerator_config:
  gradient_accumulation_kwargs:
    sync_each_batch: True

micro_batch_size: 1
num_epochs: 3
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5.0e-4
# adamw hyperparams
adam_beta1: 0.9
adam_beta2: 0.95

bf16: true
tf32: false

logging_steps: 1

flash_attention: true

loss_watchdog_threshold: 15.0
loss_watchdog_patience: 3

warmup_steps: 128
evals_per_epoch: 0

save_steps: 500
save_strategy: steps

weight_decay: 0.01

sample_packing: true
pad_to_sequence_len: true

shuffle_merged_datasets: true
experimental_skip_move_to_device: true

fsdp_version: 2
fsdp_config:
  offload_params: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: LlamaDecoderLayer
  state_dict_type: FULL_STATE_DICT
  reshard_after_forward: true
  activation_checkpointing: true
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

# Comment to disable CP
# The number of GPUs to shard the model parameters across (FSDP dimension).
dp_shard_size: 1

# The number of times to replicate the sharded model (DDP dimension).
dp_replicate_size: 1

# Number of GPUs for Tensor Parallelism.
tensor_parallel_size: 1 # (default is 1, no TP)

# Number of GPUs for Context/Sequence Parallelism.
context_parallel_size: 1 # (default is 1, no CP)

special_tokens:
  eos_token: <|end_of_text|>

eot_tokens:
  - <|im_end|>
@@ -9,8 +9,8 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt

 Here is an example of how to install from pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. In addition to Axolotl's requirements, Gemma-3n requires:
@@ -13,8 +13,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations

 Here is an example of how to install from pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations

 Here is an example of how to install from main for pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.7.1 min)
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl

-uv pip install --no-build-isolation -e '.[flash-attn]'
+uv pip install --no-build-isolation -e '.'

 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
 python scripts/cutcrossentropy_install.py | sh
@@ -9,11 +9,11 @@ Tencent released a family of opensource models called HunYuan with varying param

 Here is an example of how to install from main for pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl

-uv pip install --no-build-isolation -e '.[flash-attn]'
+uv pip install --no-build-isolation -e '.'

 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
 python scripts/cutcrossentropy_install.py | sh
@@ -13,8 +13,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for these

 Here is an example of how to install from pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
@@ -59,7 +59,7 @@ gradient_checkpointing: true
 resume_from_checkpoint:
 logging_steps: 1
 attn_implementation: flash_attention_2
-scaling_softmax: true
+# scaling_softmax: true # needs flex_attention

 warmup_ratio: 0.1
 evals_per_epoch: 1
**examples/mistral-medium-3_5/README.md** (new file, 78 lines)
@@ -0,0 +1,78 @@
# Finetune Mistral Medium 3.5 with Axolotl

[Mistral Medium 3.5](https://huggingface.co/mistralai/Mistral-Medium-3.5-128B) is a 128B parameter dense multimodal model from MistralAI that unifies instruct, reasoning, and agentic capabilities into a single model.
It shares the `mistral3` architecture (dense, YaRN RoPE, 256k context) with Ministral 3 and supports the same `reasoning_effort` toggle as Mistral Small 4.

Thanks to the team at MistralAI for giving us early access to prepare for this release.

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.

3. (Text config only) Install [Flash Attention 4](https://docs.axolotl.ai/docs/attention.html#flash-attention-4) on Hopper/Blackwell.

4. Run one of the example configs:

```bash
# text-only
axolotl train examples/mistral-medium-3_5/qlora-text.yml # ~83.1 GiB

# text + vision
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
axolotl train examples/mistral-medium-3_5/qlora-vision.yml # ~80.3 GiB
```

Note: vision training does not currently work with Flash Attention 4.

## Reasoning Effort

The chat template supports a `reasoning_effort` variable to control the model's reasoning depth:

- `"none"` — instruct mode (default)
- `"high"` — reasoning mode with explicit thinking steps

Pass it via `chat_template_kwargs` under your dataset config:

```yaml
datasets:
  - path: your/dataset
    type: chat_template
    chat_template_kwargs:
      reasoning_effort: high
```

## Thinking Support

The chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).

To use thinking datasets, add the `thinking` mapping via `message_property_mappings`:

```yaml
datasets:
  - path: your/thinking-dataset
    type: chat_template
    message_property_mappings:
      role: role
      content: content
      thinking: thinking
    chat_template_kwargs:
      reasoning_effort: high
```

See the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.

## Tips

- For smaller experiments on the same architecture, see [`examples/ministral3`](../ministral3/README.md) (Ministral 3, 3B).
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).

## Related Resources

- [Mistral Medium 3.5 Blog](https://mistral.ai/news/vibe-remote-agents-mistral-medium-3-5)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
56
examples/mistral-medium-3_5/qlora-text.yml
Normal file
@@ -0,0 +1,56 @@
base_model: axolotl-ai-co/Mistral-Medium-3.5-128B-BF16

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_4bit: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir:

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0
# prevents targeting vision layers
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'

lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
max_steps: 10
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
61
examples/mistral-medium-3_5/qlora-vision.yml
Normal file
@@ -0,0 +1,61 @@
base_model: axolotl-ai-co/Mistral-Medium-3.5-128B-BF16
processor_type: AutoProcessor

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_4bit: true

skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

# sample dataset below requires downloading image in advance
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
datasets:
  - path: Nanobit/text-vision-2k-test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir:

adapter: qlora
lora_model_dir:

sequence_len: 2048

lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'

lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
max_steps: 10
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
@@ -11,7 +11,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl>=0.16.1'

# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh

@@ -13,7 +13,7 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```

2. Install an extra dependency:

@@ -11,8 +11,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
Here is an example of how to install from pip:

```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```

2. Please install the below.
92
human_chat_qlora.yml
Normal file
@@ -0,0 +1,92 @@
# Llama 3.1 8B — Human-like LoRA fine-tune (HQQ quantization)
#
# Goal: natural, warm conversation; never corrects user errors; direct responses
# Hardware: single RTX 5080 (16 GB VRAM)
# Method: LoRA on HQQ 4-bit quantized base (bypasses bitsandbytes — RTX 5080 compatible)
#
# Prerequisites:
#   pip install -e '.[flash-attn]'       (inside your axolotl repo)
#   huggingface-cli login                (meta-llama is a gated model)
#
# Run:
#   axolotl train human_chat_qlora.yml
#   axolotl merge-lora human_chat_qlora.yml  # (optional) merge adapter into base

base_model: meta-llama/Meta-Llama-3.1-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

# HQQ quantization — no bitsandbytes required, works on RTX 5080 (sm_120)
quant_method: hqq
strict: false
trust_remote_code: true
torch_dtype: bfloat16

# --- System prompt baked into every conversation ---
# This is the primary lever for "no error correcting, more human-like"
chat_template: llama3
default_system_message: >-
  You are a direct, warm, and genuinely helpful assistant.
  Respond to the user's intent naturally — never comment on typos, grammar,
  or phrasing issues in their message. Just understand what they mean and give
  a clear, useful, conversational answer as if talking to a knowledgeable friend.

# --- Datasets ---
# Both use ShareGPT format: conversations field, from/value keys
datasets:
  - path: Open-Orca/SlimOrca
    type: chat_template
    field_messages: conversations
    message_field_role: from
    message_field_content: value
    split: "train[:3%]"

  - path: teknium/OpenHermes-2.5
    type: chat_template
    field_messages: conversations
    message_field_role: from
    message_field_content: value
    split: "train[:5%]"

dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/llama31-8b-humanchat

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

# --- LoRA adapter (on top of HQQ quantized base) ---
adapter: lora
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true

# --- Training hyperparameters ---
# Effective batch = micro_batch_size x gradient_accumulation = 2 x 4 = 8
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 2
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.05
weight_decay: 0.1

train_on_inputs: false
group_by_length: false
bf16: auto
tf32: false

# --- Memory & speed ---
gradient_checkpointing: true
attn_implementation: flash_attention_2

# --- Logging & checkpointing ---
logging_steps: 10
evals_per_epoch: 2
saves_per_epoch: 1

special_tokens:
  pad_token: "<|eot_id|>"
@@ -12,7 +12,7 @@ requires-python = ">=3.10"

dependencies = [
    # Core ML stack
    "torch>=2.6.0",
    "torch>=2.9.1",
    "packaging==26.0",
    "huggingface_hub>=1.1.7",
    "peft>=0.19.1,<0.20.0",
@@ -79,7 +79,7 @@ dependencies = [
    # Platform-specific (Linux only)
    "bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
    "triton>=3.4.0 ; sys_platform != 'darwin'",
    "xformers>=0.0.23.post1 ; sys_platform != 'darwin'",
    "xformers>=0.0.33.post2 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
    "liger-kernel==0.7.0 ; sys_platform != 'darwin'",
    "torchao==0.17.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
@@ -286,10 +286,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            )

        if self.cfg.relora and self.cfg.jagged_restart_steps:
            if self.cfg.relora_prune_ratio:
            if self.cfg.relora_prune_ratio is not None:
                training_arguments_kwargs["relora_prune_ratio"] = (
                    self.cfg.relora_prune_ratio
                )
            if self.cfg.relora_prune_method:
                training_arguments_kwargs["relora_prune_method"] = (
                    self.cfg.relora_prune_method
                )

        if self.cfg.jagged_restart_steps:
            training_arguments_kwargs["jagged_restart_steps"] = (
@@ -515,12 +519,53 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        else:
            if self.cfg.processor_type and self.processor:
                collator = MultiModalChatDataCollator
                # Mirror ChatTemplateStrategy: per-dataset masking knobs from first MM dataset, else global cfg.
                # NOTE: Multi-dataset configs use the first dataset's masking knobs for all datasets;
                # heterogeneous per-dataset overrides are not supported in the MM path today.
                ds_entries = self.cfg.datasets or []
                ds_cfg = ds_entries[0] if ds_entries else None

                def _ds_get(cfg_obj, key):
                    # Handle DictDefault / dict / pydantic uniformly:
                    # dict-style .get first, then attribute access.
                    if cfg_obj is None:
                        return None
                    if hasattr(cfg_obj, "get"):
                        try:
                            return cfg_obj.get(key)
                        except (AttributeError, KeyError, TypeError):
                            pass
                    return getattr(cfg_obj, key, None)

                roles_to_train = _ds_get(ds_cfg, "roles_to_train")
                train_on_eos = _ds_get(ds_cfg, "train_on_eos")

                # cfg.role_boundaries replaces the strategy's built-in markers.
                role_boundaries_override = None
                if self.cfg.role_boundaries:
                    role_boundaries_override = list(self.cfg.role_boundaries)

                # build() calls build_collator twice (eval + train); log once.
                if not is_eval:
                    LOG.info(
                        "MM collator: train_on_inputs=%s roles_to_train=%s "
                        "train_on_eos=%s role_boundaries_override=%s",
                        bool(self.cfg.train_on_inputs),
                        roles_to_train,
                        train_on_eos,
                        "set" if role_boundaries_override else "none",
                    )

                kwargs["processing_strategy"] = get_processing_strategy(
                    self.processor,
                    training_args.chat_template,
                    self.cfg.chat_template,
                    image_size=training_args.image_size,
                    image_resize_algorithm=training_args.image_resize_algorithm,
                    train_on_inputs=bool(self.cfg.train_on_inputs),
                    roles_to_train=roles_to_train,
                    train_on_eos=train_on_eos,
                    role_boundaries_override=role_boundaries_override,
                )
            elif self.cfg.batch_flattening:
                collator = DataCollatorWithFlattening
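Reviewer note: the hunk above reads `roles_to_train` / `train_on_eos` from the first multimodal dataset entry only. A minimal sketch of the dataset-level config this wiring consumes (the dataset path is a placeholder, not from this branch):

```yaml
# Sketch: per-dataset masking knobs picked up by the MM collator (first entry wins).
datasets:
  - path: your/multimodal-dataset   # placeholder
    type: chat_template
    roles_to_train: ["assistant"]   # only assistant turns contribute to loss
    train_on_eos: turn              # train on the EOS closing each trainable turn
```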
@@ -2,6 +2,7 @@

from __future__ import annotations

import gc
import json
import math
import os
@@ -800,7 +801,14 @@ class AxolotlTrainer(
            with open(tokens_state_path, "w", encoding="utf-8") as f:
                json.dump(tokens_state, f)

        return super()._save_checkpoint(model, trial, **kwargs)
        result = super()._save_checkpoint(model, trial, **kwargs)

        # Reclaim VRAM held by the FSDP full-state-dict gather.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return result

    # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
    def _save(self, output_dir: Optional[str] = None, state_dict=None):

@@ -83,13 +83,18 @@ class AxolotlTrainingMixins:
        default=None,
        metadata={"help": "The number of processes to use for data processing"},
    )
    relora_steps: Optional[int] = field(
        default=None,
        metadata={"help": "how often to reset for ReLoRA"},
    )
    relora_prune_ratio: Optional[float] = field(
        default=0.9,
        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
        default=None,
        metadata={
            "help": (
                "prune ratio for optimizer state pruning; "
                "defaults to 0.999 for reset method, 0.9 for others"
            )
        },
    )
    relora_prune_method: Optional[str] = field(
        default=None,
        metadata={"help": "optimizer state pruning method: magnitude | random | reset"},
    )
    jagged_restart_steps: Optional[int] = field(
        default=None,
@@ -23,9 +23,10 @@ from __future__ import annotations

import collections
import importlib
import traceback
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, OrderedDict, Union

from peft import PeftModel
from peft import PeftConfig, PeftMixedModel, PeftModel
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -41,6 +42,15 @@ if TYPE_CHECKING:
    from axolotl.common.datasets import TrainDatasetMeta


@dataclass(frozen=True)
class AdapterCapabilities:
    """Capabilities for an adapter contributed by a plugin."""

    name: str
    lora_like: bool = False
    relora: bool = False


class BasePlugin:
    """Base class for all plugins. Defines the interface for plugin methods.

@@ -91,6 +101,26 @@ class BasePlugin:
        Returns a dataclass model for the plugin's training arguments.
        """

    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
        """Returns adapter capabilities contributed by the plugin."""
        return []

    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
        """Returns extra PEFT LoraConfig kwargs for plugin LoRA-like adapters."""
        return {}

    def load_adapter(
        self,
        model: PreTrainedModel,
        cfg: DictDefault,
        inference: bool = False,
        config_only: bool = False,
    ) -> (
        tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
        | None
    ):
        """Optionally load a plugin adapter instead of the generic loader."""

    def load_datasets(
        self, cfg: DictDefault, preprocess: bool = False
    ) -> Union["TrainDatasetMeta", None]:
@@ -414,6 +444,58 @@ class PluginManager:
            training_args.append(training_args_from_plugin)
        return training_args

    def adapter_capabilities(self) -> dict[str, AdapterCapabilities]:
        """Returns adapter capabilities by adapter name."""
        capabilities = {}
        for plugin in self.plugins.values():
            for adapter_capability in plugin.get_adapter_capabilities():
                capabilities[adapter_capability.name] = adapter_capability
        return capabilities

    def get_adapter_capability(self, adapter: str) -> AdapterCapabilities | None:
        """Returns capabilities for a registered plugin adapter."""
        return self.adapter_capabilities().get(adapter)

    def supports_adapter(self, adapter: str) -> bool:
        """Returns whether a plugin has registered the adapter name."""
        return adapter in self.adapter_capabilities()

    def adapter_supports_relora(self, adapter: str) -> bool:
        """Returns whether a plugin adapter supports ReLoRA restart semantics."""
        capability = self.get_adapter_capability(adapter)
        return bool(capability and capability.relora)

    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
        """Returns extra LoraConfig kwargs from plugins for the configured adapter."""
        lora_config_kwargs = {}
        for plugin in self.plugins.values():
            plugin_kwargs = plugin.get_lora_config_kwargs(cfg)
            if plugin_kwargs:
                lora_config_kwargs.update(plugin_kwargs)
        return lora_config_kwargs

    def load_adapter(
        self,
        model: PreTrainedModel,
        cfg: DictDefault,
        inference: bool = False,
        config_only: bool = False,
    ) -> (
        tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
        | None
    ):
        """Returns the first plugin adapter loader result, if any."""
        for plugin in self.plugins.values():
            loaded = plugin.load_adapter(
                model,
                cfg,
                inference=inference,
                config_only=config_only,
            )
            if loaded is not None:
                return loaded
        return None

    def load_datasets(
        self, cfg: DictDefault, preprocess: bool = False
    ) -> Union["TrainDatasetMeta", None]:
@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
kd_alpha: 0.9
kd_temperature: 1.0

torch_compile: True  # torch>=2.6.0, recommended to reduce vram
torch_compile: True  # recommended to reduce vram

datasets:
  - path: ...
6
src/axolotl/integrations/mora/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""MoRA / ReMoRA integration for Axolotl."""

from .args import MoraArgs, MoraConfig, MoraType
from .plugin import MoraPlugin

__all__ = ["MoraArgs", "MoraConfig", "MoraPlugin", "MoraType"]
66
src/axolotl/integrations/mora/args.py
Normal file
@@ -0,0 +1,66 @@
"""Config args for MoRA / ReMoRA."""

from __future__ import annotations

from enum import Enum

from pydantic import BaseModel, Field, model_validator


class MoraType(str, Enum):
    """MoRA variants supported by the reference implementation."""

    SHARING = "sharing"
    ROPE = "rope"

    @property
    def peft_value(self) -> int:
        return {
            MoraType.SHARING: 1,
            MoraType.ROPE: 6,
        }[self]


class MoraConfig(BaseModel):
    """Nested MoRA configuration available under the `mora` key."""

    use_mora: bool = Field(
        default=True,
        description=(
            "Enable MoRA adapter construction. Requires a PEFT build with MoRA "
            "support (for example, the MoRA fork)."
        ),
    )
    mora_type: MoraType = Field(
        default=MoraType.ROPE,
        description=(
            "MoRA variant selector. Supported values are `sharing` for type 1 "
            "and `rope` for type 6. Numeric values 1 and 6 are accepted for "
            "backwards compatibility."
        ),
    )

    @model_validator(mode="before")
    @classmethod
    def normalize_mora_type(cls, data):
        if not isinstance(data, dict) or "mora_type" not in data:
            return data
        data = data.copy()
        mora_type = data["mora_type"]
        if mora_type == 1:
            data["mora_type"] = MoraType.SHARING
        elif mora_type == 6:
            data["mora_type"] = MoraType.ROPE
        return data


class MoraArgs(BaseModel):
    """Plugin entry that exposes the nested `mora` block to the core config."""

    mora: MoraConfig = Field(
        default_factory=MoraConfig,
        description=(
            "MoRA / ReMoRA training configuration. Register the "
            "`axolotl.integrations.mora.MoraPlugin` plugin to enable this block."
        ),
    )
97
src/axolotl/integrations/mora/plugin.py
Normal file
@@ -0,0 +1,97 @@
"""MoRA / ReMoRA plugin for Axolotl."""

import inspect

from peft import LoraConfig, PeftModel
from transformers import PreTrainedModel

from axolotl.integrations.base import AdapterCapabilities, BasePlugin
from axolotl.integrations.mora.args import MoraType
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def _peft_supports_mora() -> bool:
    try:
        params = inspect.signature(LoraConfig).parameters
    except (TypeError, ValueError):
        return False
    return "use_mora" in params and "mora_type" in params


def _mora_type_peft_value(mora_type: MoraType | str | int) -> int:
    if isinstance(mora_type, MoraType):
        return mora_type.peft_value
    if mora_type == 1 or mora_type == MoraType.SHARING.value:
        return MoraType.SHARING.peft_value
    if mora_type == 6 or mora_type == MoraType.ROPE.value:
        return MoraType.ROPE.peft_value
    raise ValueError("mora_type must be one of `sharing`, `rope`, 1, or 6")


def _mora_type_label(mora_type: MoraType | str | int) -> str:
    if isinstance(mora_type, MoraType):
        return mora_type.value
    if mora_type == 1:
        return MoraType.SHARING.value
    if mora_type == 6:
        return MoraType.ROPE.value
    return str(mora_type)


class MoraPlugin(BasePlugin):
    """Plugin that exposes MoRA-specific config and validates runtime support."""

    def get_input_args(self) -> str:
        return "axolotl.integrations.mora.MoraArgs"

    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
        return [AdapterCapabilities(name="mora", lora_like=True, relora=True)]

    def _validate_mora_config(self, cfg: DictDefault):
        mora_cfg = getattr(cfg, "mora", None)
        if mora_cfg is None:
            raise ValueError("adapter: mora requires a nested mora configuration block")
        if not getattr(mora_cfg, "use_mora", False):
            raise ValueError("mora.use_mora must be true when adapter: mora is set")
        if cfg.load_in_4bit or cfg.load_in_8bit:
            raise ValueError(
                "adapter: mora currently requires a full-precision base model. "
                "Use adapter: lora or qlora for quantized training."
            )
        if cfg.gptq:
            raise ValueError(
                "adapter: mora is not compatible with GPTQ quantized base models."
            )

    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
        if cfg.adapter != "mora":
            return {}
        self._validate_mora_config(cfg)
        if not _peft_supports_mora():
            raise ImportError(
                "adapter: mora requires a PEFT build with MoRA support "
                "(LoraConfig(use_mora=..., mora_type=...)). "
                "Install the MoRA fork or another PEFT distribution that exposes "
                "those fields."
            )
        mora_cfg = cfg.mora
        return {
            "use_mora": mora_cfg.use_mora,
            "mora_type": _mora_type_peft_value(mora_cfg.mora_type),
        }

    def pre_model_load(self, cfg: DictDefault):
        if cfg.adapter != "mora":
            return
        LOG.info("MoRA plugin enabled for adapter: mora")

    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
        if cfg.adapter == "mora" and getattr(cfg, "mora", None):
            LOG.debug(
                "Loaded MoRA model with mora_type=%s, relora=%s",
                _mora_type_label(cfg.mora.mora_type),
                cfg.relora,
            )
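Reviewer note: for trying the plugin end to end, a minimal config sketch built only from fields this branch adds (the base model and dataset path are placeholders):

```yaml
# Sketch: MoRA via the plugin adapter registry added in this branch.
plugins:
  - axolotl.integrations.mora.MoraPlugin

base_model: meta-llama/Meta-Llama-3.1-8B-Instruct  # placeholder base
adapter: mora          # resolved through PluginManager.supports_adapter
mora:
  use_mora: true
  mora_type: rope      # or `sharing`; numeric 1/6 accepted for back-compat

datasets:
  - path: your/dataset  # placeholder
    type: chat_template
```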
@@ -19,12 +19,14 @@ from peft import (
)
from transformers import PreTrainedModel

from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()


def setup_quantized_meta_for_peft(model: torch.nn.Module):
@@ -124,6 +126,76 @@ def _patch_peft_clippable_linear():
    LoraModel._axolotl_clippable_patched = True


def _get_peft_task_type(model: PreTrainedModel) -> TaskType:
    model_cls = type(model).__name__
    if "SequenceClassification" in model_cls:
        return TaskType.SEQ_CLS
    if "TokenClassification" in model_cls:
        return TaskType.TOKEN_CLS
    return TaskType.CAUSAL_LM


def _build_lora_config_kwargs(cfg: DictDefault) -> dict[str, Any]:
    lora_config_kwargs: dict[str, Any] = {}
    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    if loftq_bits:
        lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
        lora_config_kwargs["init_lora_weights"] = "loftq"
    if cfg.peft_init_lora_weights:
        lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
    if cfg.peft_use_dora:
        lora_config_kwargs["use_dora"] = cfg.peft_use_dora
        LOG.info("Initializing LoRA weights using dora. This might take longer.")
    if cfg.peft_use_rslora:
        lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
    if cfg.peft_layer_replication:
        lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
    if cfg.peft_trainable_token_indices:
        lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
    if cfg.peft_ensure_weight_tying is not None:
        lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying

    return lora_config_kwargs


def _build_peft_lora_config(
    model: PreTrainedModel,
    cfg: DictDefault,
) -> PeftConfig:
    lora_target_modules = cfg.lora_target_modules or []
    lora_target_parameters = cfg.lora_target_parameters or []

    if cfg.lora_target_linear:
        linear_names = find_all_linear_names(model)
        LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
        lora_target_modules_as_list = (
            lora_target_modules
            if isinstance(lora_target_modules, list)
            else [lora_target_modules]
        )
        lora_target_modules = list(set(lora_target_modules_as_list + linear_names))

    lora_config_kwargs = _build_lora_config_kwargs(cfg)
    lora_config_kwargs.update(PLUGIN_MANAGER.get_lora_config_kwargs(cfg))

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=lora_target_modules,
        target_parameters=lora_target_parameters,
        layers_to_transform=cfg.peft_layers_to_transform,
        layers_pattern=cfg.peft_layers_pattern,
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
        exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
        bias="none",
        task_type=_get_peft_task_type(model),
        **lora_config_kwargs,
    )
    return lora_config


def _peft_will_auto_convert_target_params(model, lora_config) -> bool:
    """Check whether PEFT will auto-populate target_parameters for this model.

@@ -226,62 +298,7 @@ def load_lora(
    config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
    _patch_peft_clippable_linear()
    lora_target_modules = cfg.lora_target_modules or []
    lora_target_parameters = cfg.lora_target_parameters or []

    if cfg.lora_target_linear:
        linear_names = find_all_linear_names(model)
        LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
        lora_target_modules_as_list = (
            lora_target_modules
            if isinstance(lora_target_modules, list)
            else [lora_target_modules]
        )
        lora_target_modules = list(set(lora_target_modules_as_list + linear_names))

    lora_config_kwargs = {}
    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    if loftq_bits:
        lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
        lora_config_kwargs["init_lora_weights"] = "loftq"
    if cfg.peft_init_lora_weights:
        lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
    if cfg.peft_use_dora:
        lora_config_kwargs["use_dora"] = cfg.peft_use_dora
        LOG.info("Initializing LoRA weights using dora. This might take longer.")
    if cfg.peft_use_rslora:
        lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
    if cfg.peft_layer_replication:
        lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
    if cfg.peft_trainable_token_indices:
        lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
    if cfg.peft_ensure_weight_tying is not None:
        lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying

    # Determine the correct PEFT task type
    model_cls = type(model).__name__
    if "SequenceClassification" in model_cls:
        task_type = TaskType.SEQ_CLS
    elif "TokenClassification" in model_cls:
        task_type = TaskType.TOKEN_CLS
    else:
        task_type = TaskType.CAUSAL_LM

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=lora_target_modules,
        target_parameters=lora_target_parameters,
        layers_to_transform=cfg.peft_layers_to_transform,
        layers_pattern=cfg.peft_layers_pattern,
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
        exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
        bias="none",
        task_type=task_type,
        **lora_config_kwargs,
    )
    lora_config = _build_peft_lora_config(model, cfg)

    if config_only:
        return None, lora_config
@@ -315,7 +332,7 @@ def load_lora(
        model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype

    if cfg.lora_model_dir:
        LOG.debug("Loading pretrained PEFT - LoRA")
        LOG.debug("Loading pretrained PEFT adapter")
        if cfg.lora_on_cpu:
            model_kwargs["max_memory"] = {"cpu": "256GiB"}
            model_kwargs["device_map"] = {"": "cpu"}
@@ -364,30 +381,60 @@ def load_adapter(
    cfg: DictDefault,
    adapter: str | None,
    inference: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:
    config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
    if adapter is None:
        return model, None
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    if adapter in ["lora", "qlora"]:
        peft_model, lora_config = load_lora(model, cfg, inference=inference)
        peft_model, lora_config = load_lora(
            model, cfg, inference=inference, config_only=config_only
        )
        return peft_model, lora_config
    if adapter == "llama-adapter":
        if config_only:
            _, lora_config = load_llama_adapter(model, cfg, config_only=True)
            return None, lora_config
        peft_model, lora_config = load_llama_adapter(model, cfg)
        return peft_model, lora_config

    raise NotImplementedError(f"{adapter} PEFT adapter not available")
    plugin_loaded = PLUGIN_MANAGER.load_adapter(
        model,
        cfg,
        inference=inference,
        config_only=config_only,
    )
    if plugin_loaded is not None:
        return plugin_loaded

    adapter_capability = PLUGIN_MANAGER.get_adapter_capability(adapter)
    if adapter_capability and adapter_capability.lora_like:
        peft_model, lora_config = load_lora(
            model, cfg, inference=inference, config_only=config_only
        )
        return peft_model, lora_config

    registered = sorted(PLUGIN_MANAGER.adapter_capabilities())
    registered_msg = ", ".join(registered) if registered else "none"
    raise NotImplementedError(
        f"Adapter '{adapter}' is not built in and was not registered by a plugin "
        f"with loader support. Registered plugin adapters: {registered_msg}"
    )


def load_llama_adapter(
    model: PreTrainedModel, cfg: DictDefault
) -> tuple[PeftModel | PeftMixedModel, PeftConfig]:
    model: PreTrainedModel, cfg: DictDefault, config_only: bool = False
) -> tuple[PeftModel | PeftMixedModel | None, PeftConfig]:
    peft_config = AdaptionPromptConfig(
        adapter_layers=cfg.peft_adapter.layers,  # layers (L)
        adapter_len=cfg.peft_adapter.len,  # prompt length (K)
        task_type="CAUSAL_LM",
    )

    if config_only:
        return None, peft_config

    if cfg.lora_model_dir:
        LOG.debug("Loading pretrained PEFT - llama_adapter")
        peft_model = PeftModel.from_pretrained(

@@ -39,7 +39,7 @@ from transformers.integrations.deepspeed import (

from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.base import PluginManager
from axolotl.loaders.adapter import load_adapter, load_lora
from axolotl.loaders.adapter import load_adapter
from axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING
from axolotl.loaders.patch_manager import PatchManager
from axolotl.loaders.utils import (
@@ -386,8 +386,12 @@ class ModelLoader:
            and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
            and not self.cfg.merge_lora
        ):
            _, lora_config = load_lora(
                self.model, self.cfg, inference=False, config_only=True
            _, lora_config = load_adapter(
                self.model,
                self.cfg,
                self.cfg.adapter,
                inference=False,
                config_only=True,
            )
        else:
            self.model, lora_config = load_adapter(
@@ -628,13 +632,7 @@ class ModelLoader:
        )

    def _set_attention_config(self):
        # s2 and fp8 need a different HF backend at load time than their
        # canonical name: s2 patches FA2 internals, so load under FA2; fp8
        # replaces F.scaled_dot_product_attention post-load, so load under sdpa.
        # Every other canonical name (and hub-kernel paths) is passed through
        # verbatim — xformers/sage/flash_attention_* are registered under their
        # own names in ALL_ATTENTION_FUNCTIONS before model load. gemma4_hybrid
        # is already pinned to flash_attention_2 by normalize_attn_implementation.
        # s2 patches FA2 internals (load as FA2); fp8 replaces sdpa post-load (load as sdpa).
        _LOAD_TIME_OVERRIDE = {"s2": "flash_attention_2", "fp8": "sdpa"}
        if self.cfg.attn_implementation:
            hf_impl = _LOAD_TIME_OVERRIDE.get(
@@ -826,6 +824,17 @@ class ModelLoader:
        else:
            self.model = self._load_model_from_pretrained(model_loader_class)

        if self.cfg.use_onebitllms:
            try:
                from onebitllms import replace_linear_with_bitnet_linear
            except ImportError as exc:
                raise ImportError(
                    "The 'onebitllms' package is required for use_onebitllms. "
                    "Install it with: `uv pip install onebitllms`"
                ) from exc

            self.model = replace_linear_with_bitnet_linear(self.model)

        if is_deepspeed_zero3_enabled():
            skip_move_to_device = True

@@ -4,6 +4,7 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interatio

import copy
import functools
import gc
import os
import sys

@@ -161,6 +162,7 @@ def get_state_dict(self, model, unwrap=True):

        state_dict = {}
        sharded_state_dict = model.state_dict()
        is_rank_zero = torch.distributed.get_rank() == 0
        for param_name, param in sharded_state_dict.items():
            if param.is_cpu:
                param = param.to(torch.device("cuda"))
@@ -168,9 +170,20 @@ def get_state_dict(self, model, unwrap=True):
            if isinstance(param, DTensor):
                param = param.full_tensor()

            if torch.distributed.get_rank() == 0:
            if is_rank_zero:
                state_dict[param_name] = param.cpu()
            # Drop the GPU-resident gathered tensor before the next iteration
            # allocates the next one; otherwise the caching allocator holds
            # both reservations and we accumulate ~model-size of VRAM.
            del param
            torch.distributed.barrier()

        # Release the sharded view and force the allocator to give back the
        # gather buffers.
        del sharded_state_dict
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    elif self.distributed_type == DistributedType.FSDP:
        from torch.distributed.fsdp import (
            FullStateDictConfig,
@@ -6,9 +6,8 @@ import os.path
import shutil
from functools import partial
from pathlib import Path
from typing import Dict, List, Union
from typing import Dict, List, Literal, Union

import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
@@ -28,9 +27,15 @@ from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

try:
    import bitsandbytes as bnb
except ImportError:  # pragma: no cover - optional dependency for 8-bit merge paths
    bnb = None


@torch.no_grad()
def magnitude_pruning_(tensor, prune_ratio):
    """Zero the lowest ``prune_ratio`` fraction of values by absolute magnitude, in place."""
    tensor_magnitude = torch.abs(tensor)
    threshold = torch.quantile(
        tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
@@ -40,15 +45,43 @@ def magnitude_pruning_(tensor, prune_ratio):
    tensor.mul_(mask.to(dtype=tensor.dtype))


@torch.no_grad()
def random_pruning_(tensor, prune_ratio):
    """Zero a random ``prune_ratio`` fraction of values, in place."""
    mask = (
        torch.rand(tensor.shape, dtype=torch.float32, device=tensor.device)
        > prune_ratio
    )
    tensor.mul_(mask.to(dtype=tensor.dtype))


# 0.999 mirrors the reference implementation. True zeroing breaks
# ZeroRedundancyOptimizer.consolidate_state_dict; see Guitaricet/relora's
# peft_pretraining/training_utils.py for the original note on this.
_FULL_RESET_RATIO = 0.999


def reset_optimizer(
    optimizer: torch.optim.Optimizer,
    *,
    reset_params: List[str],  # where str is the key to a torch.nn.Parameter
    reset_params: List[torch.nn.Parameter],
    optimizer_state_keys: List[str],
    optimizer_magnitude_pruning: float = 0.9,
    prune_method: Literal["magnitude", "random", "reset"] = "magnitude",
    prune_ratio: float = 0.9,
):
    # pylint:disable=unused-argument
    pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
    """Prune optimizer state for ``reset_params`` only."""
    if prune_method == "magnitude":
        pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
    elif prune_method in ("random", "reset"):
        # "reset" is random pruning at a near-full ratio; the caller is responsible
        # for supplying the appropriate prune_ratio (see ReLoRACallback.on_step_begin).
        pruning_fn = partial(random_pruning_, prune_ratio=prune_ratio)
    else:
        raise ValueError(
            f"Unknown prune_method {prune_method!r}; expected one of "
            "'magnitude', 'random', 'reset'"
        )

    n_zeros = 0
    n_total = 0

@@ -56,22 +89,22 @@ def reset_optimizer(
    if isinstance(optimizer, ZeroRedundancyOptimizer):
        optimizer_state = optimizer.optim.state

    for group in optimizer.param_groups:
        for param in group["params"]:
            state = optimizer_state[param]
            for key, value in state.items():
                if key not in optimizer_state_keys:
    for param in reset_params:
        state = optimizer_state.get(param, {})
        if not state:
            continue
        for key in optimizer_state_keys:
            value = state.get(key)
            if value is None or not torch.is_tensor(value):
                continue
            try:
                pruning_fn(value)
                n_total += value.numel()
                n_zeros += torch.sum(value == 0).item()
            except RuntimeError as exc:
                if "quantile() input tensor is too large" in str(exc):
                    continue
                if torch.is_tensor(value):
                    try:
                        pruning_fn(value)
                        n_total += value.numel()
                        n_zeros += torch.sum(value == 0).item()
                    except RuntimeError as exc:
                        if "quantile() input tensor is too large" in str(exc):
                            pass
                        else:
                            raise exc
                raise

    _zeroed = n_zeros / (1e-7 + n_total) * 100
    LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
@@ -82,11 +115,12 @@ class ReLoRACallback(TrainerCallback):
    """Callback to merge LoRA weights into the base model and save full-weight checkpoints"""

    def __init__(self, cfg: DictDefault):
        self.relora_steps = cfg.jagged_restart_steps
        self.jagged_restart_steps = cfg.jagged_restart_steps
        self.cpu_offload = cfg.relora_cpu_offload
        self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
        self.last_full_model = cfg.base_model
        self.resume_from_checkpoint = cfg.resume_from_checkpoint
        self.prune_method = cfg.relora_prune_method or "magnitude"

        if not os.path.exists(self.last_full_model):
            self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
@@ -128,7 +162,7 @@ class ReLoRACallback(TrainerCallback):
    ):
        if not optimizer:
            optimizer = state.optimizer
        if state.global_step > 0 and state.global_step % self.relora_steps == 0:
        if state.global_step > 0 and state.global_step % self.jagged_restart_steps == 0:
            checkpoint_folder = os.path.join(
                args.output_dir,
                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
@@ -144,7 +178,7 @@ class ReLoRACallback(TrainerCallback):
                raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")

            lora_params = [
                n
                p
                for n, p in model.named_parameters()
                if p.requires_grad and "lora_" in n
            ]
@@ -166,11 +200,19 @@ class ReLoRACallback(TrainerCallback):
                actually_save=is_main_process(),
                cpu_offload=self.cpu_offload,
            )
            # When relora_prune_ratio is not set, use _FULL_RESET_RATIO for
            # "reset" (paper-style near-full reset) and 0.9 for other methods.
            prune_ratio = args.relora_prune_ratio
            if prune_ratio is None:
                prune_ratio = (
                    _FULL_RESET_RATIO if self.prune_method == "reset" else 0.9
                )
            reset_optimizer(
                optimizer,
                reset_params=lora_params,
                optimizer_state_keys=optimizer_state_keys,
                optimizer_magnitude_pruning=args.relora_prune_ratio,
                prune_method=self.prune_method,
                prune_ratio=prune_ratio,
            )

            if self.quantized:
@@ -191,8 +233,8 @@ class ReLoRACallback(TrainerCallback):
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
        )
        if (
            state.global_step >= self.relora_steps
            and state.global_step % self.relora_steps != 0
            state.global_step >= self.jagged_restart_steps
            and state.global_step % self.jagged_restart_steps != 0
        ):
            if self.quantized:
                if is_main_process() and self.last_full_model != checkpoint_folder:
@@ -320,6 +362,8 @@ def update_weights(
        target.weight.data = new_weight.cpu()
        target.to(device)
    elif isinstance(target, peft.tuners.lora.Linear8bitLt):
        if bnb is None:
            raise ImportError("bitsandbytes is required to merge 8-bit LoRA weights")
        target.weight.data = (
            bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
        )

File diff suppressed because it is too large
@@ -285,7 +285,9 @@ def save_trained_model(
    )
    # Handle ReLoRA early return case
    if cfg.relora:
        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
        if hasattr(model, "merge_and_unload") and not (
            cfg.load_in_4bit or cfg.load_in_8bit
        ):
            model = model.merge_and_unload()
        else:
            # final model weights have already been saved by `ReLoRACallback.on_train_end`

@@ -43,14 +43,16 @@ class MultiModalChatDataCollator(DataCollatorMixin):
        # Initialize batch
        messages = [ex["messages"] for ex in examples]

        batch = self.processing_strategy.processor.apply_chat_template(
            messages,
            add_generation_prompt=False,
            tokenize=True,
            return_tensors="pt",
            padding=True,
            return_dict=True,
            chat_template=self.processing_strategy.chat_template,
        batch = dict(
            self.processing_strategy.processor.apply_chat_template(
                messages,
                add_generation_prompt=False,
                tokenize=True,
                return_tensors="pt",
                padding=True,
                return_dict=True,
                chat_template=self.processing_strategy.chat_template,
            )
        )

        # Process the labels
@@ -1044,7 +1044,7 @@ class AxolotlInputConfig(
    torch_compile: Literal["auto"] | bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.6.0"
            "description": "Whether to use torch.compile and which backend to use."
        },
    )
    torch_compile_backend: str | None = Field(
@@ -1397,16 +1397,7 @@ class AxolotlInputConfig(
    @model_validator(mode="before")
    @classmethod
    def normalize_attn_implementation(cls, data):
        """Map legacy boolean attention flags to the canonical `attn_implementation`.

        `attn_implementation` is the single source of truth on the validated
        config. Legacy booleans (`flash_attention: true`, …) are input-only
        aliases; this validator warns, maps them to their canonical value, and
        strips them from `data` so they cannot be read downstream.

        Raises if a canonical `attn_implementation` is set alongside any legacy
        boolean — users must pick one.
        """
        """Map legacy boolean attention flags to canonical attn_implementation, warn, then strip."""
        if not isinstance(data, dict):
            return data
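Reviewer note: the before/after the normalizer implements, as a user-facing config sketch (the `flash_attention` to `flash_attention_2` pairing is inferred from the removed docstring and the legacy mapping table later in this diff):

```yaml
# Legacy input, still accepted: the validator warns, maps, then strips it.
flash_attention: true

# Canonical equivalent after normalization:
attn_implementation: flash_attention_2

# Setting both forms at once raises a validation error; pick one.
```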
@@ -166,10 +166,10 @@ class SFTDataset(BaseModel):
            "description": "Roles to train on. The tokens from these roles will be considered for the loss."
        },
    )
    train_on_eos: Literal["all", "turn", "last"] | None = Field(
    train_on_eos: Literal["all", "turn", "last", "none"] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation"
            "description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation, none: never train on EOS tokens"
        },
    )
    roles: dict[str, list[str]] | None = Field(
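A minimal dataset-level sketch of the new literal (the path is a placeholder):

```yaml
datasets:
  - path: your/dataset   # placeholder
    type: chat_template
    train_on_eos: none   # new value: never compute loss on EOS tokens
```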
@@ -97,12 +97,7 @@ class CustomSupportedOptimizers(str, Enum):
    flash_lion = "flash_lion"


# Canonical values accepted for `attn_implementation`. These are passed to HF
# verbatim via `model.config._attn_implementation`. HF-native backends use HF's
# own names (`flash_attention_2`, `flex_attention`, ...); axolotl-owned backends
# (`xformers`, `sage`, `s2`, `fp8`) register into `ALL_ATTENTION_FUNCTIONS` under
# these exact names. Hub-kernel paths (e.g. `kernels-community/flash-attn3`) are
# not in this set — they pass through the validator via the "/" check.
# Accepted canonical names; hub-kernel paths (containing "/") bypass this set.
CANONICAL_ATTN_IMPLS = frozenset(
    {
        "eager",
@@ -117,10 +112,7 @@ CANONICAL_ATTN_IMPLS = frozenset(
    }
)

# Legacy boolean attention flags → canonical `attn_implementation`. Kept for
# backwards compatibility; the normalizer warns and strips these from the
# validated config. Priority order (first match wins) matches the old priority:
# specific backends beat the generic flash/sdp/eager fallbacks.
# Legacy boolean flags → canonical attn_implementation. Priority: specific before generic.
LEGACY_ATTN_FLAG_TO_IMPL = {
    "xformers_attention": "xformers",
    "s2_attention": "s2",
@@ -131,9 +123,7 @@ LEGACY_ATTN_FLAG_TO_IMPL = {
    "eager_attention": "eager",
}

# Short-form aliases that were accepted by the in-progress branch but are
# rejected going forward. Mapped to canonical names only to produce a helpful
# error message pointing users at the right value.
# Short-form aliases rejected at validation; mapped to canonical names for error messages.
SHORT_FORM_ALIAS_TO_CANONICAL = {
    "flash": "flash_attention_2",
    "flex": "flex_attention",
@@ -148,18 +138,19 @@ ATTN_IMPLS_SUPPORTING_PACKING = frozenset(
        "flex_attention",
        "xformers",
        "sage",
        "kernels-community/flash-attn2",
        "kernels-community/flash-attn3",
        "kernels-community/sage-attention",
    }
)

# Backends that require the flash_attn library (Dao-AILab/flash-attention) for
# axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, ring-FA, ...).
# Backends that require the flash_attn library for axolotl's own monkeypatches.
ATTN_IMPLS_USING_FLASH_LIB = frozenset(
    {
        "flash_attention_2",
        "flash_attention_3",
        "s2",
        "kernels-community/flash-attn2",
        "kernels-community/flash-attn3",
    }
)
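Since hub-kernel paths bypass the canonical set via the "/" check, a config sketch of the two accepted shapes (both names appear in the constants above):

```yaml
# Canonical name; flex_attention is listed in ATTN_IMPLS_SUPPORTING_PACKING:
attn_implementation: flex_attention
sample_packing: true

# Hub-kernel path: contains "/", so it skips the CANONICAL_ATTN_IMPLS check:
# attn_implementation: kernels-community/flash-attn3
```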
@@ -103,6 +103,12 @@ class ModelInputConfig(BaseModel):
        default=None,
        json_schema_extra={"description": "kwargs for model quantization config"},
    )
    use_onebitllms: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use `onebitllms` for 1.58bit training (only for bitnet models)."
        },
    )

    @field_validator("trust_remote_code")
    @classmethod
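A minimal sketch of the new flag (the base model is a placeholder; per the loader hunk earlier, an ImportError is raised if the `onebitllms` package is missing):

```yaml
base_model: your/bitnet-model   # placeholder; the flag is only meant for bitnet models
use_onebitllms: true            # swaps Linear layers for bitnet linears after model load
```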
@@ -6,6 +6,57 @@ from PIL.Image import Resampling
from pydantic import BaseModel, Field, field_validator


class RoleBoundarySpec(BaseModel):
    """One ``cfg.role_boundaries`` row; see docs/multimodal_assistant_mask.md."""

    role: str = Field(
        json_schema_extra={
            "description": (
                "Role name as it appears in cfg.roles_to_train (e.g. "
                "'assistant', 'user', 'system', 'tool', 'ipython')."
            )
        },
    )
    start: str = Field(
        json_schema_extra={
            "description": (
                "Literal string that marks the start of this role's span in "
                "the rendered chat template. Tokenized via "
                "``tokenizer.encode(..., add_special_tokens=False)`` at "
                "strategy init."
            )
        },
    )
    end: str | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Literal string that marks the end of this role's span. "
                "Set to ``eos_token`` to terminate at the tokenizer's EOS. "
                "Leave unset / null to terminate at end-of-sequence."
            )
        },
    )
    include_start: bool = Field(
        default=False,
        json_schema_extra={
            "description": (
                "Whether the start marker tokens contribute to loss on "
                "trainable turns. Default False."
            )
        },
    )
    include_end: bool = Field(
        default=True,
        json_schema_extra={
            "description": (
                "Whether the end marker tokens contribute to loss on "
                "trainable turns (honoring cfg.train_on_eos). Default True."
            )
        },
    )


class MultiModalConfig(BaseModel):
    """Multi-modal configuration subset"""

@@ -26,6 +77,17 @@ class MultiModalConfig(BaseModel):
            "description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details."
        },
    )
    role_boundaries: list[RoleBoundarySpec] | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "Opt-in override for the MM mask scanner's per-role boundary "
                "markers. Non-empty list replaces built-ins wholesale; unset "
                "or empty falls back to built-ins. See "
                "docs/multimodal_assistant_mask.md."
            )
        },
    )

    @field_validator("image_resize_algorithm", mode="before")
    @classmethod

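Reviewer note: for orientation, here is what an override row built from the new `RoleBoundarySpec` might look like. The marker strings are ChatML-style examples of my own, not the built-in boundary table this field replaces:

```python
# Illustrative only: the marker strings are ChatML-style examples, and
# RoleBoundarySpec is the model defined in the hunk above.
boundaries = [
    RoleBoundarySpec(
        role="assistant",
        start="<|im_start|>assistant\n",
        end="<|im_end|>",
        include_start=False,  # start marker stays masked out of the loss
        include_end=True,     # end marker trains, subject to cfg.train_on_eos
    ),
]
```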
@@ -38,10 +38,10 @@ class LoraConfig(BaseModel):
        default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
    )

    adapter: Literal["lora", "qlora", "llama-adapter"] | None = Field(
    adapter: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "If you want to use 'lora', 'qlora', or 'llama-adapter', or leave blank to train all parameters in original model"
            "description": "If you want to use a built-in or plugin adapter, or leave blank to train all parameters in original model"
        },
    )
    lora_model_dir: str | None = Field(
@@ -174,6 +174,16 @@ class LoraConfig(BaseModel):
                "load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
                "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
            )
        adapter = data.get("adapter")
        if adapter and adapter not in ("lora", "qlora", "llama-adapter"):
            from axolotl.integrations.base import PluginManager

            plugin_manager = PluginManager.get_instance()
            if not plugin_manager.supports_adapter(adapter):
                raise ValueError(
                    f"Adapter '{adapter}' is not built in and was not registered by "
                    "a plugin. Add the plugin that provides this adapter to `plugins:`."
                )
        return data

    @model_validator(mode="after")

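Reviewer note: `PluginManager.supports_adapter` here (and `adapter_supports_relora` in the ReLoRA validation further down) imply plugin-side capability hooks that this diff does not show. A sketch of what a plugin might advertise; the method names mirror the manager calls, but the plugin API itself is an assumption:

```python
# Hypothetical plugin-side hooks mirroring the PluginManager calls in the
# diff; these names are assumptions, not a confirmed axolotl plugin API.
class MyAdapterPlugin:
    def supports_adapter(self, adapter: str) -> bool:
        # claim ownership of a custom adapter name used in `adapter:`
        return adapter == "my-adapter"

    def adapter_supports_relora(self, adapter: str) -> bool:
        # opt in (or not) to ReLoRA restart semantics for that adapter
        return False
```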
@@ -240,8 +250,28 @@ class ReLoRAConfig(BaseModel):
    )
    relora_prune_ratio: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        json_schema_extra={
            "description": "threshold for optimizer magnitude when pruning"
            "description": (
                "Fraction of optimizer state values to zero on each ReLoRA restart. "
                "When relora_prune_method='reset' and this is omitted, defaults to "
                "0.999 (paper-style near-full reset). For other methods, defaults to 0.9."
            )
        },
    )
    relora_prune_method: Literal["magnitude", "random", "reset"] | None = Field(
        default="magnitude",
        json_schema_extra={
            "description": (
                "Optimizer state pruning method on each ReLoRA restart. "
                "'magnitude' (default) keeps top-k by absolute value; "
                "'random' keeps a random subset at relora_prune_ratio; "
                "'reset' uses near-full random pruning (default ratio 0.999, "
                "honoring relora_prune_ratio when explicitly set). "
                "Paper-style recipe: relora_prune_method='reset' with no "
                "relora_prune_ratio, equivalent to 'random' with ratio=0.999."
            )
        },
    )
    relora_cpu_offload: bool | None = Field(

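Reviewer note: the new field descriptions pin the pruning semantics down well enough to restate them in code. This is a reading of the documented behavior, not the branch's implementation:

```python
import torch


def prune_optimizer_state(
    state: torch.Tensor, ratio: float | None, method: str
) -> torch.Tensor:
    """Zero a fraction of optimizer state on a ReLoRA restart (sketch)."""
    if method == "reset":
        # 'reset' is near-full random pruning; an explicit ratio is honored
        ratio = 0.999 if ratio is None else ratio
        method = "random"
    elif ratio is None:
        ratio = 0.9
    flat = state.flatten().clone()
    n_prune = int(flat.numel() * ratio)
    if method == "magnitude":
        # zero the smallest |values|, i.e. keep top-k by absolute value
        idx = flat.abs().argsort()[:n_prune]
    else:  # "random"
        idx = torch.randperm(flat.numel(), device=flat.device)[:n_prune]
    flat[idx] = 0.0
    return flat.view_as(state)
```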
@@ -183,9 +183,6 @@ class DatasetValidationMixin:
class AttentionValidationMixin:
    """Validation methods related to attention mechanisms."""

    # `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation`
    # is now the single entry point for attention-input mapping and conflict detection.

    @model_validator(mode="after")
    def check_sample_packing_without_attention(self):
        if self.sample_packing and not self.attn_supports_packing:
@@ -629,6 +626,12 @@ class LoRAValidationMixin:
            raise ValueError("Fused modules are not supported with LoRA/QLoRA")
        return self

    @model_validator(mode="after")
    def check_onebitllms_lora(self):
        if self.use_onebitllms and self.adapter in ["lora", "qlora"]:
            raise ValueError("LoRA/QLoRA is not supported with use_onebitllms")
        return self

    @model_validator(mode="before")
    @classmethod
    def warn_qlora_zero3_w_use_reentrant(cls, data):
@@ -1475,8 +1478,19 @@ class ComplexValidationMixin:
        if self.relora:
            if not self.jagged_restart_steps:
                raise ValueError("jagged_restart_steps must be set to use ReLoRA")
            if self.adapter not in ("lora", "qlora"):
                raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")

            adapter_supports_relora = self.adapter in ("lora", "qlora")
            if self.adapter and not adapter_supports_relora:
                from axolotl.integrations.base import PluginManager

                plugin_manager = PluginManager.get_instance()
                adapter_supports_relora = plugin_manager.adapter_supports_relora(
                    self.adapter
                )
            if not adapter_supports_relora:
                raise ValueError(
                    "cfg.adapter must support ReLoRA to use ReLoRA restart semantics"
                )

            if self.fsdp or self.fsdp_config:
                raise ValueError("fsdp not supported with ReLoRA")

@@ -119,15 +119,49 @@ def download_smollm2_135m_gptq_model():


@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
    # download the model
    snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
def download_qwen3_half_billion_model():
    # download the model (still used as the KD teacher in tests/e2e/integrations/test_kd.py)
    snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_qwen3_half_billion_model():
    # download the model
    snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")
def download_tiny_llama_model():
    snapshot_download_w_retry("axolotl-ai-co/tiny-llama-50m", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_tiny_mistral_model():
    snapshot_download_w_retry("axolotl-ai-co/tiny-mistral-25m", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_tiny_mixtral_model():
    snapshot_download_w_retry("axolotl-ai-co/tiny-mixtral-30m", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_tiny_phi_model():
    snapshot_download_w_retry("axolotl-ai-co/tiny-phi-64m", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_tiny_falcon_model():
    snapshot_download_w_retry("axolotl-ai-co/tiny-falcon-42m", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_tiny_qwen2_model():
    snapshot_download_w_retry("axolotl-ai-co/tiny-qwen2-129m", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_tiny_qwen3_model():
    snapshot_download_w_retry("axolotl-ai-co/tiny-qwen3-129m", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
def download_tiny_gemma2_model():
    snapshot_download_w_retry("axolotl-ai-co/tiny-gemma2-137m", repo_type="model")


@pytest.fixture(scope="session", autouse=True)
@@ -620,7 +654,15 @@ def fixture_min_base_cfg():
    )
    def test_load_fixtures(
        download_smollm2_135m_model,
        download_qwen_2_5_half_billion_model,
        download_qwen3_half_billion_model,
        download_tiny_llama_model,
        download_tiny_mistral_model,
        download_tiny_mixtral_model,
        download_tiny_phi_model,
        download_tiny_falcon_model,
        download_tiny_qwen2_model,
        download_tiny_qwen3_model,
        download_tiny_gemma2_model,
        download_tatsu_lab_alpaca_dataset,
        download_mhenrichsen_alpaca_2k_dataset,
        download_mhenrichsen_alpaca_2k_w_revision_dataset,

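Reviewer note: all of the fixtures route through `snapshot_download_w_retry`, whose implementation is not in this diff. A plausible shape, under the assumption that it wraps `huggingface_hub.snapshot_download` with simple backoff; the real helper in the test utils may differ:

```python
# Assumed shape only; the retry count and backoff policy are guesses.
import time

from huggingface_hub import snapshot_download


def snapshot_download_w_retry(repo_id: str, repo_type: str = "model", retries: int = 3):
    for attempt in range(retries):
        try:
            return snapshot_download(repo_id=repo_id, repo_type=repo_type)
        except Exception:
            if attempt == retries - 1:
                raise
            time.sleep(2**attempt)  # simple exponential backoff
```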
@@ -10,7 +10,10 @@ from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

from tests.e2e.utils import check_model_output_exists
from tests.e2e.utils import (
    check_model_output_exists,
    check_tensorboard_loss_decreased,
)


@pytest.fixture()
@@ -35,13 +38,16 @@ def min_cfg(temp_dir):
        "num_epochs": 1,
        "micro_batch_size": 8,
        "gradient_accumulation_steps": 1,
        "learning_rate": 0.00001,
        "learning_rate": 5e-4,
        "optimizer": "adamw_torch_fused",
        "output_dir": temp_dir,
        "lr_scheduler": "cosine",
        "max_steps": 10,
        "max_steps": 40,
        "warmup_steps": 5,
        "bf16": "auto",
        "save_first_step": False,
        "use_tensorboard": True,
        "seed": 42,
    }

@@ -64,11 +70,18 @@ class TestCutCrossEntropyIntegration:
        else:
            train(cfg=cfg, dataset_meta=dataset_meta)
            check_model_output_exists(temp_dir, cfg)
            check_tensorboard_loss_decreased(
                temp_dir + "/runs",
                initial_window=5,
                final_window=5,
                max_initial=2.2,
                max_final=2.0,
            )

    def test_qwen2_w_cce(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "plugins": [
                    "axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
                ],
@@ -87,13 +100,15 @@ class TestCutCrossEntropyIntegration:
                "num_epochs": 1,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "output_dir": temp_dir,
                "lr_scheduler": "cosine",
                "max_steps": 10,
                "max_steps": 50,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )
        cfg = validate_config(cfg)
@@ -108,6 +123,13 @@ class TestCutCrossEntropyIntegration:
        else:
            train(cfg=cfg, dataset_meta=dataset_meta)
            check_model_output_exists(temp_dir, cfg)
            check_tensorboard_loss_decreased(
                temp_dir + "/runs",
                initial_window=5,
                final_window=5,
                max_initial=5.0,
                max_final=4.7,
            )

    @pytest.mark.parametrize(
        "attention_type",
@@ -136,3 +158,10 @@ class TestCutCrossEntropyIntegration:
        else:
            train(cfg=cfg, dataset_meta=dataset_meta)
            check_model_output_exists(temp_dir, cfg)
            check_tensorboard_loss_decreased(
                temp_dir + "/runs",
                initial_window=5,
                final_window=5,
                max_initial=2.2,
                max_final=2.0,
            )

@@ -24,7 +24,7 @@ from axolotl.monkeypatch.lora_kernels import (
)
from axolotl.utils.dict import DictDefault

MODEL_NAME = "Qwen/Qwen3-0.6B"
MODEL_NAME = "axolotl-ai-co/tiny-qwen3-129m"
DEVICE = "cuda"
DTYPE = torch.bfloat16


@@ -1,23 +1,22 @@
"""Test module for DistMuon optimizer with FSDP2 multi-GPU functionality."""

import os
from pathlib import Path

import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault

from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0

AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent


def verify_training_success(temp_dir):
    """Verify that training completed successfully by checking artifacts and loss."""
    """Verify that training completed successfully — artifacts, no-NaN, loss
    stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
    """
    output_path = Path(temp_dir)

    model_files = list(output_path.glob("*.bin")) + list(
@@ -30,19 +29,13 @@ def verify_training_success(temp_dir):
            "No checkpoint files found - training may have failed"
        )

    tb_log_path = most_recent_subdir(temp_dir + "/runs")
    if tb_log_path:
        event_files = sorted(os.listdir(tb_log_path))
        if event_files:
            event_file = os.path.join(tb_log_path, event_files[0])
            reader = SummaryReader(event_file)
            df = reader.scalars
            train_loss_df = df[df.tag == "train/train_loss"]
            if len(train_loss_df) > 0:
                final_loss = train_loss_df.value.values[-1]
                assert not torch.isnan(torch.tensor(final_loss)), (
                    f"Training loss is NaN: {final_loss}"
                )
    check_tensorboard_loss_decreased(
        temp_dir + "/runs",
        initial_window=10,
        final_window=10,
        max_initial=5.0,
        max_final=4.7,
    )


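Reviewer note: the repeated inline tbparse plumbing is replaced across these test modules by a shared `check_tensorboard_loss_decreased` helper in `tests/e2e/utils`. Its body is not in this diff; a plausible shape, inferred from the call sites (window sizes plus the `max_initial`/`max_final` ceilings), assuming tbparse:

```python
# Inferred from the call sites; the real helper may differ.
from tbparse import SummaryReader


def check_tensorboard_loss_decreased(
    runs_dir: str,
    tag: str = "train/train_loss",
    initial_window: int = 5,
    final_window: int = 5,
    max_initial: float | None = None,
    max_final: float | None = None,
) -> None:
    """Assert the logged train loss fell from its early to its late window."""
    df = SummaryReader(runs_dir).scalars
    losses = df[df.tag == tag].value.values
    initial = losses[:initial_window].mean()
    final = losses[-final_window:].mean()
    assert final < initial, f"loss did not decrease: {initial:.3f} -> {final:.3f}"
    if max_initial is not None:
        assert initial <= max_initial, f"initial loss too high: {initial:.3f}"
    if max_final is not None:
        assert final <= max_final, f"final loss too high: {final:.3f}"
```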
class TestDistMuon:
@@ -52,7 +45,7 @@ class TestDistMuon:
    def test_fft_sft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [
@@ -63,11 +56,12 @@ class TestDistMuon:
                    },
                ],
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 80,
                "warmup_steps": 5,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.02,
                "learning_rate": 2e-3,
                "optimizer": "muon",
                "weight_decay": 0.01,
                "lr_scheduler": "cosine",
@@ -82,6 +76,9 @@ class TestDistMuon:
                    "reshard_after_forward": True,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "bf16": True,
            }
        )
@@ -109,7 +106,7 @@ class TestDistMuon:
    def test_lora_sft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [
@@ -122,14 +119,15 @@ class TestDistMuon:
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_dropout": 0.0,
                "lora_target_linear": True,
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 80,
                "warmup_steps": 5,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.02,
                "learning_rate": 2e-3,
                "optimizer": "muon",
                "weight_decay": 0.01,
                "lr_scheduler": "cosine",
@@ -144,6 +142,9 @@ class TestDistMuon:
                    "reshard_after_forward": True,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "bf16": True,
            }
        )

@@ -1,24 +1,23 @@
"""Test module for FSDP1 multi-GPU functionality."""

import os
from pathlib import Path

import pytest
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault

from tests.e2e.utils import most_recent_subdir
from tests.e2e.utils import check_tensorboard_loss_decreased

AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent


def verify_training_success(temp_dir):
    """Verify that training completed successfully by checking artifacts and loss."""
    """Verify that training completed successfully — artifacts, no-NaN, loss
    stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
    """
    output_path = Path(temp_dir)

    model_files = list(output_path.glob("*.bin")) + list(
@@ -31,19 +30,13 @@ def verify_training_success(temp_dir):
            "No checkpoint files found - training may have failed"
        )

    tb_log_path = most_recent_subdir(temp_dir + "/runs")
    if tb_log_path:
        event_files = sorted(os.listdir(tb_log_path))
        if event_files:
            event_file = os.path.join(tb_log_path, event_files[0])
            reader = SummaryReader(event_file)
            df = reader.scalars
            train_loss_df = df[df.tag == "train/train_loss"]
            if len(train_loss_df) > 0:
                final_loss = train_loss_df.value.values[-1]
                assert not torch.isnan(torch.tensor(final_loss)), (
                    f"Training loss is NaN: {final_loss}"
                )
    check_tensorboard_loss_decreased(
        temp_dir + "/runs",
        initial_window=10,
        final_window=10,
        max_initial=5.0,
        max_final=4.7,
    )


class TestFSDP1:
@@ -56,7 +49,7 @@ class TestFSDP1:
    def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [
@@ -67,11 +60,12 @@ class TestFSDP1:
                    },
                ],
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 80,
                "warmup_steps": 5,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
@@ -87,6 +81,9 @@ class TestFSDP1:
                    "fsdp_use_orig_params": False,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "bf16": True,
            }
        )
@@ -126,7 +123,7 @@ class TestFSDP1:
    def test_lora_sft(self, temp_dir, adapter_config):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [
@@ -140,14 +137,15 @@ class TestFSDP1:
                "load_in_4bit": adapter_config["load_in_4bit"],
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_dropout": 0.0,
                "lora_target_linear": True,
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 80,
                "warmup_steps": 5,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 1e-3,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
@@ -163,6 +161,9 @@ class TestFSDP1:
                    "fsdp_use_orig_params": False,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "bf16": True,
            }
        )
@@ -190,7 +191,7 @@ class TestFSDP1:
    def test_dpo_fft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "rl": "dpo",
@@ -203,11 +204,11 @@ class TestFSDP1:
                    },
                ],
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 20,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
@@ -223,6 +224,9 @@ class TestFSDP1:
                    "fsdp_use_orig_params": False,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
            }
        )

@@ -262,7 +266,7 @@ class TestFSDP1:
    def test_dpo_lora(self, temp_dir, adapter_config):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "load_in_4bit": adapter_config["load_in_4bit"],
                "rl": "dpo",
                "chat_template": "chatml",
@@ -281,11 +285,11 @@ class TestFSDP1:
                    },
                ],
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 20,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 1e-3,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
@@ -301,6 +305,9 @@ class TestFSDP1:
                    "fsdp_use_orig_params": False,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "bf16": "auto",
                "tf32": True,
            }

@@ -1,24 +1,23 @@
"""Test module for FSDP2 multi-GPU functionality."""

import os
from pathlib import Path

import pytest
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault

from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0

AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent


def verify_training_success(temp_dir):
    """Verify that training completed successfully by checking artifacts and loss."""
    """Verify that training completed successfully — artifacts, no-NaN, loss
    stayed in qwen2-pretraining scale (tiny-qwen2-129m final pretrain CE ~3.92).
    """
    output_path = Path(temp_dir)

    model_files = list(output_path.glob("*.bin")) + list(
@@ -31,19 +30,13 @@ def verify_training_success(temp_dir):
            "No checkpoint files found - training may have failed"
        )

    tb_log_path = most_recent_subdir(temp_dir + "/runs")
    if tb_log_path:
        event_files = sorted(os.listdir(tb_log_path))
        if event_files:
            event_file = os.path.join(tb_log_path, event_files[0])
            reader = SummaryReader(event_file)
            df = reader.scalars
            train_loss_df = df[df.tag == "train/train_loss"]
            if len(train_loss_df) > 0:
                final_loss = train_loss_df.value.values[-1]
                assert not torch.isnan(torch.tensor(final_loss)), (
                    f"Training loss is NaN: {final_loss}"
                )
    check_tensorboard_loss_decreased(
        temp_dir + "/runs",
        initial_window=10,
        final_window=10,
        max_initial=5.0,
        max_final=4.7,
    )


class TestFSDP2:
@@ -57,7 +50,7 @@ class TestFSDP2:
    def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [
@@ -68,11 +61,12 @@ class TestFSDP2:
                    },
                ],
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 80,
                "warmup_steps": 5,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
@@ -86,6 +80,9 @@ class TestFSDP2:
                    "reshard_after_forward": True,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "bf16": True,
            }
        )
@@ -114,7 +111,7 @@ class TestFSDP2:
    def test_lora_sft(self, temp_dir, peft_use_dora):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [
@@ -128,14 +125,15 @@ class TestFSDP2:
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_dropout": 0.0,
                "lora_target_linear": True,
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 80,
                "warmup_steps": 5,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 1e-3,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
@@ -149,6 +147,9 @@ class TestFSDP2:
                    "reshard_after_forward": True,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "bf16": True,
                # explicitly disable LORA kernels, as they may be auto-enabled
                "lora_mlp_kernel": False,
@@ -180,7 +181,7 @@ class TestFSDP2:
    def test_lora_sft_kernels(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [
@@ -195,11 +196,12 @@ class TestFSDP2:
                "lora_alpha": 16,
                "lora_target_linear": True,
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 80,
                "warmup_steps": 5,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 1e-3,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
@@ -213,6 +215,9 @@ class TestFSDP2:
                    "reshard_after_forward": True,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "bf16": True,
                "lora_mlp_kernel": True,
                "lora_qkv_kernel": True,
@@ -243,7 +248,7 @@ class TestFSDP2:
    def test_qlora_sft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [
@@ -257,14 +262,15 @@ class TestFSDP2:
                "adapter": "qlora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_dropout": 0.0,
                "lora_target_linear": True,
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 80,
                "warmup_steps": 5,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 1e-3,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
@@ -278,6 +284,9 @@ class TestFSDP2:
                    "reshard_after_forward": True,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "bf16": True,
            }
        )
@@ -305,7 +314,7 @@ class TestFSDP2:
    def test_qlora_sft_kernels(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [
@@ -321,11 +330,12 @@ class TestFSDP2:
                "lora_alpha": 16,
                "lora_target_linear": True,
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 80,
                "warmup_steps": 5,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 1e-3,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
@@ -339,6 +349,9 @@ class TestFSDP2:
                    "reshard_after_forward": True,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "bf16": True,
                "lora_mlp_kernel": True,
                "lora_qkv_kernel": True,
@@ -370,7 +383,7 @@ class TestFSDP2:
    def test_dpo_fft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "rl": "dpo",
@@ -383,11 +396,11 @@ class TestFSDP2:
                    },
                ],
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 20,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
@@ -401,6 +414,9 @@ class TestFSDP2:
                    "reshard_after_forward": True,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
            }
        )

@@ -428,7 +444,7 @@ class TestFSDP2:
    def test_dpo_lora(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "rl": "dpo",
                "chat_template": "chatml",
@@ -445,11 +461,11 @@ class TestFSDP2:
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "num_epochs": 1,
                "max_steps": 2,
                "max_steps": 20,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 1e-3,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
@@ -463,6 +479,9 @@ class TestFSDP2:
                    "reshard_after_forward": True,
                },
                "use_tensorboard": True,
                "seed": 42,
                "sample_packing": True,
                "pad_to_sequence_len": True,
            }
        )


@@ -40,7 +40,7 @@ def _run_training(temp_dir, cfg):
def _base_lora_fsdp2_config(temp_dir, **overrides):
    """Base config for LoRA + FSDP2 + kernel tests."""
    cfg = {
        "base_model": "Qwen/Qwen3-0.6B",
        "base_model": "axolotl-ai-co/tiny-qwen3-129m",
        "sequence_len": 512,
        "val_set_size": 0.0,
        "datasets": [

@@ -8,7 +8,7 @@ from accelerate.test_utils import execute_subprocess_async, get_torch_dist_uniqu

from axolotl.utils.dict import DictDefault

from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
from tests.e2e.utils import check_tensorboard_loss_decreased, require_torch_2_7_0


class TestTensorParallel:
@@ -21,7 +21,7 @@ class TestTensorParallel:
    def test_fft_sft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [
@@ -63,6 +63,6 @@ class TestTensorParallel:
            ]
        )

        check_tensorboard(
            temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
        check_tensorboard_loss_decreased(
            temp_dir + "/runs", max_initial=5.0, max_final=4.7
        )

@@ -32,12 +32,12 @@ from axolotl.utils.dict import DictDefault

MODEL_CONFIGS = [
    {
        "name": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        "name": "axolotl-ai-co/tiny-mistral-25m",
        "expected_activation": apply_lora_mlp_swiglu,
        "dtype": torch.float16,
    },
    {
        "name": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        "name": "axolotl-ai-co/tiny-qwen2-129m",
        "expected_activation": apply_lora_mlp_swiglu,
        "dtype": torch.float16,
    },
@@ -47,7 +47,7 @@ MODEL_CONFIGS = [
        "dtype": torch.float32,
    },
    {
        "name": "trl-internal-testing/tiny-Gemma2ForCausalLM",
        "name": "axolotl-ai-co/tiny-gemma2-137m",
        "expected_activation": apply_lora_mlp_geglu,
        "dtype": torch.float16,
    },
@@ -159,7 +159,7 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
    """Test GeGLU activation with Gemma model."""
    model = AutoModelForCausalLM.from_pretrained(
        "trl-internal-testing/tiny-Gemma2ForCausalLM",
        "axolotl-ai-co/tiny-gemma2-137m",
        dtype=torch.float16,
        device_map="cuda:0",
    )

@@ -4,14 +4,16 @@ E2E tests for falcon

import unittest

import pytest

from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, with_temp_dir
from ..utils import (
    check_model_output_exists,
    check_tensorboard_loss_decreased,
    with_temp_dir,
)


class TestFalconPatched(unittest.TestCase):
@@ -19,13 +21,12 @@ class TestFalconPatched(unittest.TestCase):
    Test case for Falcon models
    """

    @pytest.mark.skip(reason="no tiny models for testing with safetensors")
    @with_temp_dir
    def test_qlora(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "illuin/tiny-random-FalconForCausalLM",
                "flash_attention": True,
                "base_model": "axolotl-ai-co/tiny-falcon-42m",
                "flash_attention": False,
                "sample_packing": True,
                "sequence_len": 2048,
                "load_in_4bit": True,
@@ -47,17 +48,20 @@ class TestFalconPatched(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )
        cfg = validate_config(cfg)
@@ -66,14 +70,20 @@ class TestFalconPatched(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=6.0,
            max_final=4.7,
        )

    @pytest.mark.skip(reason="no tiny models for testing with safetensors")
    @with_temp_dir
    def test_ft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "illuin/tiny-random-FalconForCausalLM",
                "flash_attention": True,
                "base_model": "axolotl-ai-co/tiny-falcon-42m",
                "flash_attention": False,
                "sample_packing": True,
                "sequence_len": 2048,
                "val_set_size": 0.05,
@@ -88,17 +98,20 @@ class TestFalconPatched(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )
        cfg = validate_config(cfg)
@@ -107,3 +120,10 @@ class TestFalconPatched(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=6.0,
            max_final=4.7,
        )

@@ -9,7 +9,12 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, require_torch_2_6_0, with_temp_dir
from ..utils import (
    check_model_output_exists,
    check_tensorboard_loss_decreased,
    require_torch_2_6_0,
    with_temp_dir,
)


class TestMistral(unittest.TestCase):
@@ -22,7 +27,7 @@ class TestMistral(unittest.TestCase):
    def test_lora_packing(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
                "base_model": "axolotl-ai-co/tiny-mistral-25m",
                "flash_attention": True,
                "sample_packing": True,
                "sequence_len": 1024,
@@ -45,17 +50,20 @@ class TestMistral(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 3,
                "eval_steps": 4,
                "max_steps": 50,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )
        cfg = validate_config(cfg)
@@ -64,12 +72,19 @@ class TestMistral(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=5.5,
            max_final=4.3,
        )

    @with_temp_dir
    def test_ft_packing(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
                "base_model": "axolotl-ai-co/tiny-mistral-25m",
                "flash_attention": True,
                "sample_packing": True,
                "sequence_len": 1024,
@@ -86,17 +101,20 @@ class TestMistral(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 3,
                "eval_steps": 4,
                "max_steps": 50,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )
        cfg = validate_config(cfg)
@@ -105,3 +123,10 @@ class TestMistral(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=5.5,
            max_final=4.3,
        )

@@ -9,7 +9,11 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, with_temp_dir
from ..utils import (
    check_model_output_exists,
    check_tensorboard_loss_decreased,
    with_temp_dir,
)


class TestMixtral(unittest.TestCase):
@@ -21,8 +25,7 @@ class TestMixtral(unittest.TestCase):
    def test_qlora(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                "base_model": "axolotl-ai-co/tiny-mixtral-30m",
                "flash_attention": True,
                "sample_packing": True,
                "sequence_len": 2048,
@@ -30,7 +33,7 @@ class TestMixtral(unittest.TestCase):
                "adapter": "qlora",
                "lora_r": 16,
                "lora_alpha": 32,
                "lora_dropout": 0.1,
                "lora_dropout": 0.0,
                "lora_target_linear": True,
                "val_set_size": 0.05,
                "special_tokens": {},
@@ -41,17 +44,21 @@ class TestMixtral(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 3e-3,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 3,
                "eval_steps": 4,
                "max_steps": 80,
                "warmup_steps": 5,
                "logging_steps": 1,
                "save_steps": 80,
                "eval_steps": 80,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )
        cfg = validate_config(cfg)
@@ -60,13 +67,19 @@ class TestMixtral(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=10,
            final_window=10,
            max_initial=6.0,
            max_final=4.7,
        )

    @with_temp_dir
    def test_ft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                "base_model": "axolotl-ai-co/tiny-mixtral-30m",
                "flash_attention": True,
                "sample_packing": True,
                "sequence_len": 2048,
@@ -79,17 +92,21 @@ class TestMixtral(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "learning_rate": 5e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 3,
                "eval_steps": 4,
                "max_steps": 80,
                "warmup_steps": 5,
                "logging_steps": 1,
                "save_steps": 80,
                "eval_steps": 80,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )
        cfg = validate_config(cfg)
@@ -98,3 +115,10 @@ class TestMixtral(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=6.0,
            max_final=4.7,
        )

@@ -22,8 +22,7 @@ class TestModelPatches(unittest.TestCase):
    def test_mixtral_multipack(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                "base_model": "axolotl-ai-co/tiny-mixtral-30m",
                "flash_attention": True,
                "sample_packing": True,
                "sequence_len": 2048,
@@ -57,7 +56,7 @@ class TestModelPatches(unittest.TestCase):
    def test_mistral_multipack(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
                "base_model": "axolotl-ai-co/tiny-mistral-25m",
                "flash_attention": True,
                "sample_packing": True,
                "sequence_len": 2048,

@@ -9,7 +9,11 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, with_temp_dir
from ..utils import (
    check_model_output_exists,
    check_tensorboard_loss_decreased,
    with_temp_dir,
)


class TestPhiMultipack(unittest.TestCase):
@@ -21,7 +25,7 @@ class TestPhiMultipack(unittest.TestCase):
    def test_ft_packed(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "microsoft/phi-1_5",
                "base_model": "axolotl-ai-co/tiny-phi-64m",
                "model_type": "PhiForCausalLM",
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 1024,
@@ -43,17 +47,20 @@ class TestPhiMultipack(unittest.TestCase):
                "dataset_shard_num": 10,
                "dataset_shard_idx": 0,
                "num_epochs": 1,
                "micro_batch_size": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "eval_steps": 3,
                "save_steps": 4,
                "max_steps": 50,
                "logging_steps": 1,
                "eval_steps": 50,
                "save_steps": 50,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )

@@ -63,12 +70,19 @@ class TestPhiMultipack(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=6.0,
            max_final=4.7,
        )

    @with_temp_dir
    def test_qlora_packed(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "microsoft/phi-1_5",
                "base_model": "axolotl-ai-co/tiny-phi-64m",
                "model_type": "PhiForCausalLM",
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 1024,
@@ -94,17 +108,20 @@ class TestPhiMultipack(unittest.TestCase):
                "dataset_shard_num": 10,
                "dataset_shard_idx": 0,
                "num_epochs": 1,
                "micro_batch_size": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "eval_steps": 3,
                "save_steps": 4,
                "max_steps": 50,
                "logging_steps": 1,
                "eval_steps": 50,
                "save_steps": 50,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )

@@ -114,3 +131,10 @@ class TestPhiMultipack(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=6.0,
            max_final=4.7,
        )

@@ -18,7 +18,7 @@ from transformers import AutoModelForCausalLM
# Import the actual trainer methods we want to test
from axolotl.core.trainers.grpo.async_trainer import AsyncGRPOTrainer

MODEL_NAME = "Qwen/Qwen3-0.6B"
MODEL_NAME = "axolotl-ai-co/tiny-qwen3-129m"


def _fix_patched_attention(model):

@@ -56,7 +56,72 @@ class TestReLoraLlama(unittest.TestCase):
                ],
                "warmup_steps": 10,
                "num_epochs": 2,
                "max_steps": 105,  # at least 2x relora_steps
                "max_steps": 105,  # at least 2x restart cadence
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_8bit",
                "lr_scheduler": "cosine",
                "use_tensorboard": True,
                "save_first_step": False,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
        assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists(), (
            "Relora model checkpoint not found"
        )

        check_tensorboard(
            temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
        )

    @with_temp_dir
    def test_relora_reset_method(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "sequence_len": 2048,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "flash_attention": True,
                "load_in_8bit": True,
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_modules": ["q_proj", "v_proj"],
                "relora": True,
                "jagged_restart_steps": 50,
                "jagged_restart_warmup_steps": 10,
                "jagged_restart_anneal_steps": 10,
                "relora_prune_ratio": 0.5,  # explicitly honored by reset (not ignored)
                "relora_prune_method": "reset",
                "relora_cpu_offload": True,
                "val_set_size": 0.0,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "chat_template": "chatml",
                "datasets": [
                    {
                        "path": "mlabonne/FineTome-100k",
                        "type": "chat_template",
                        "split": "train[:10%]",
                        "field_messages": "conversations",
                        "message_field_role": "from",
                        "message_field_content": "value",
                    },
                ],
                "warmup_steps": 10,
                "num_epochs": 2,
                "max_steps": 105,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,

@@ -4,14 +4,16 @@ E2E tests for falcon

import unittest

import pytest

from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
from .utils import (
    check_model_output_exists,
    check_tensorboard_loss_decreased,
    with_temp_dir,
)


class TestFalcon(unittest.TestCase):
@@ -19,13 +21,12 @@ class TestFalcon(unittest.TestCase):
    Test case for falcon
    """

    @pytest.mark.skip(reason="no tiny models for testing with safetensors")
    @with_temp_dir
    def test_lora(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "illuin/tiny-random-FalconForCausalLM",
                "flash_attention": True,
                "base_model": "axolotl-ai-co/tiny-falcon-42m",
                "flash_attention": False,
                "sequence_len": 1024,
                "load_in_8bit": True,
                "adapter": "lora",
@@ -49,17 +50,21 @@ class TestFalcon(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "warmup_steps": 5,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )

@@ -69,14 +74,20 @@ class TestFalcon(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=5.0,
            max_final=4.7,
        )

    @pytest.mark.skip(reason="no tiny models for testing with safetensors")
    @with_temp_dir
    def test_lora_added_vocab(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "illuin/tiny-random-FalconForCausalLM",
                "flash_attention": True,
                "base_model": "axolotl-ai-co/tiny-falcon-42m",
                "flash_attention": False,
                "sequence_len": 1024,
                "load_in_8bit": True,
                "adapter": "lora",
@@ -104,17 +115,21 @@ class TestFalcon(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "warmup_steps": 5,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )

@@ -124,14 +139,20 @@ class TestFalcon(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=5.0,
            max_final=4.7,
        )

    @pytest.mark.skip(reason="no tiny models for testing with safetensors")
    @with_temp_dir
    def test_ft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "illuin/tiny-random-FalconForCausalLM",
                "flash_attention": True,
                "base_model": "axolotl-ai-co/tiny-falcon-42m",
                "flash_attention": False,
                "sequence_len": 1024,
                "val_set_size": 0.02,
                "special_tokens": {
@@ -145,17 +166,23 @@ class TestFalcon(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 5e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 80,
                "warmup_steps": 5,
                "logging_steps": 1,
                "save_steps": 80,
                "eval_steps": 80,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )

@@ -165,3 +192,10 @@ class TestFalcon(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=10,
            final_window=10,
            max_initial=5.0,
            max_final=4.7,
        )

@@ -11,7 +11,11 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
from .utils import (
    check_model_output_exists,
    check_tensorboard_loss_decreased,
    with_temp_dir,
)


class TestMistral(unittest.TestCase):
@@ -23,7 +27,7 @@ class TestMistral(unittest.TestCase):
    def test_lora(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
                "base_model": "axolotl-ai-co/tiny-mistral-25m",
                "flash_attention": True,
                "sequence_len": 1024,
                "load_in_8bit": True,
@@ -45,16 +49,18 @@ class TestMistral(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "save_first_step": False,
                "use_tensorboard": True,
            }
        )

@@ -64,12 +70,19 @@ class TestMistral(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=4.5,
            max_final=4.3,
        )

    @with_temp_dir
    def test_ft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
                "base_model": "axolotl-ai-co/tiny-mistral-25m",
                "flash_attention": True,
                "sequence_len": 1024,
                "val_set_size": 0.02,
@@ -85,16 +98,18 @@ class TestMistral(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "save_first_step": False,
                "use_tensorboard": True,
            }
        )
        if is_torch_bf16_gpu_available():
@@ -108,3 +123,10 @@ class TestMistral(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=4.5,
            max_final=4.3,
        )

@@ -12,7 +12,11 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
from .utils import (
    check_model_output_exists,
    check_tensorboard_loss_decreased,
    with_temp_dir,
)


class TestMixtral(unittest.TestCase):
@@ -24,8 +28,7 @@ class TestMixtral(unittest.TestCase):
    def test_qlora_w_fa2(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                "base_model": "axolotl-ai-co/tiny-mixtral-30m",
                "flash_attention": True,
                "sequence_len": 1024,
                "load_in_4bit": True,
@@ -51,16 +54,18 @@ class TestMixtral(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "save_first_step": False,
                "use_tensorboard": True,
            }
        )

@@ -74,13 +79,19 @@ class TestMixtral(unittest.TestCase):
            == torch.float32
        )
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=5.0,
            max_final=4.7,
        )

    @with_temp_dir
    def test_qlora_wo_fa2(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                "base_model": "axolotl-ai-co/tiny-mixtral-30m",
                "flash_attention": False,
                "sequence_len": 1024,
                "load_in_4bit": True,
@@ -106,16 +117,18 @@ class TestMixtral(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "save_first_step": False,
                "use_tensorboard": True,
            }
        )

@@ -129,13 +142,19 @@ class TestMixtral(unittest.TestCase):
            == torch.float32
        )
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=5.0,
            max_final=4.7,
        )

    @with_temp_dir
    def test_16bit_lora_w_fa2(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                "base_model": "axolotl-ai-co/tiny-mixtral-30m",
                "flash_attention": True,
                "sequence_len": 1024,
                "adapter": "lora",
@@ -160,16 +179,18 @@ class TestMixtral(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "save_first_step": False,
                "use_tensorboard": True,
            }
        )
        if is_torch_bf16_gpu_available():
@@ -187,13 +208,19 @@ class TestMixtral(unittest.TestCase):
            == torch.float32
        )
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=5.0,
            max_final=4.7,
        )

    @with_temp_dir
    def test_16bit_lora_wo_fa2(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                "base_model": "axolotl-ai-co/tiny-mixtral-30m",
                "flash_attention": False,
                "sequence_len": 1024,
                "adapter": "lora",
@@ -218,16 +245,18 @@ class TestMixtral(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "save_first_step": False,
                "use_tensorboard": True,
            }
        )

@@ -245,13 +274,19 @@ class TestMixtral(unittest.TestCase):
            == torch.float32
        )
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=5.0,
            max_final=4.7,
        )

    @with_temp_dir
    def test_ft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                "base_model": "axolotl-ai-co/tiny-mixtral-30m",
                "flash_attention": True,
                "sequence_len": 1024,
                "val_set_size": 0.02,
@@ -263,16 +298,18 @@ class TestMixtral(unittest.TestCase):
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "save_first_step": False,
                "use_tensorboard": True,
            }
        )
        if is_torch_bf16_gpu_available():
@@ -286,3 +323,10 @@ class TestMixtral(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=5.0,
            max_final=4.7,
        )
@@ -13,6 +13,7 @@ from axolotl.utils.dict import DictDefault

from .utils import (
    check_model_output_exists,
    check_tensorboard_loss_decreased,
    require_torch_2_5_1,
    require_torch_2_6_0,
    require_torch_2_7_0,
@@ -243,20 +244,18 @@ class TestCustomOptimizers(unittest.TestCase):
    def test_came_pytorch(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "base_model": "axolotl-ai-co/tiny-llama-50m",
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 1024,
                "load_in_8bit": True,
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_dropout": 0.0,
                "lora_target_linear": True,
                "val_set_size": 0.1,
                "special_tokens": {
                    "unk_token": "<unk>",
                    "bos_token": "<s>",
                    "eos_token": "</s>",
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
@@ -265,16 +264,22 @@ class TestCustomOptimizers(unittest.TestCase):
                    },
                ],
                "num_epochs": 1,
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 1e-4,
                "optimizer": "came_pytorch",
                "adam_beta3": 0.9999,
                "adam_epsilon2": 1e-16,
                "max_steps": 5,
                "max_steps": 80,
                "warmup_steps": 5,
                "logging_steps": 1,
                "lr_scheduler": "cosine",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )

@@ -284,6 +289,13 @@ class TestCustomOptimizers(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=10,
            final_window=10,
            max_initial=4.0,
            max_final=3.0,
        )


@require_torch_2_7_0
@@ -9,7 +9,11 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
from .utils import (
    check_model_output_exists,
    check_tensorboard_loss_decreased,
    with_temp_dir,
)


class TestPhi(unittest.TestCase):
@@ -21,7 +25,7 @@ class TestPhi(unittest.TestCase):
    def test_phi_ft(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "microsoft/phi-1_5",
                "base_model": "axolotl-ai-co/tiny-phi-64m",
                "model_type": "AutoModelForCausalLM",
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 2048,
@@ -41,18 +45,22 @@ class TestPhi(unittest.TestCase):
                "dataset_shard_num": 10,
                "dataset_shard_idx": 0,
                "num_epochs": 1,
                "micro_batch_size": 1,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "paged_adamw_8bit",
                "learning_rate": 2e-4,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "max_steps": 10,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "warmup_steps": 5,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )
        cfg = validate_config(cfg)
@@ -61,12 +69,19 @@ class TestPhi(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=5.0,
            max_final=4.7,
        )

    @with_temp_dir
    def test_phi_qlora(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "microsoft/phi-1_5",
                "base_model": "axolotl-ai-co/tiny-phi-64m",
                "model_type": "AutoModelForCausalLM",
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 2048,
@@ -90,18 +105,22 @@ class TestPhi(unittest.TestCase):
                "dataset_shard_num": 10,
                "dataset_shard_idx": 0,
                "num_epochs": 1,
                "micro_batch_size": 1,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "learning_rate": 2e-4,
                "optimizer": "paged_adamw_8bit",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "max_steps": 10,
                "save_steps": 10,
                "eval_steps": 10,
                "max_steps": 50,
                "warmup_steps": 5,
                "logging_steps": 1,
                "save_steps": 50,
                "eval_steps": 50,
                "bf16": "auto",
                "save_first_step": False,
                "use_tensorboard": True,
                "seed": 42,
            }
        )
        cfg = validate_config(cfg)
@@ -110,3 +129,10 @@ class TestPhi(unittest.TestCase):

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
        check_tensorboard_loss_decreased(
            temp_dir + "/runs",
            initial_window=5,
            final_window=5,
            max_initial=5.0,
            max_final=4.7,
        )
@@ -18,7 +18,7 @@ class TestPreprocess:

        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "base_model": "axolotl-ai-co/tiny-qwen2-129m",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [

@@ -45,7 +45,7 @@ def _get_fake_quant_config_dtype(config):
@pytest.fixture()
def model():
    dummy_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-0.5B",
        "axolotl-ai-co/tiny-qwen2-129m",
        device_map="auto",
        dtype=torch.bfloat16,
    )

@@ -17,7 +17,7 @@ class TestE2eQwen:
    Test cases for qwen models
    """

    @pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
    @pytest.mark.parametrize("base_model", ["axolotl-ai-co/tiny-qwen2-129m"])
    def test_dpo(self, base_model, temp_dir):
        cfg = DictDefault(
            {
@@ -199,6 +199,106 @@ def check_tensorboard(
    assert df.value.values[-1] > 1e-5, "Expected loss to be greater than zero"


def check_tensorboard_loss_decreased(
    temp_run_dir: str,
    tag: str | None = None,
    initial_window: int = 1,
    final_window: int = 1,
    min_delta: float | None = None,
    max_initial: float | None = None,
    max_final: float | None = None,
    max_loss_ratio: float = 0.95,
) -> None:
    """Check that training actually learned — loss went down and stayed in
    a sensible range.

    Used with the tiny ``axolotl-ai-co/tiny-*`` CI models, where pretraining
    was brief enough that final loss won't clear the absolute thresholds used
    for 135M+ models — but the training pipeline should still behave.

    ``train/train_loss`` is only logged once (end-of-training aggregate). The
    per-step tag is ``train/loss`` for SFT/LM trainers and may vary across
    trainers (e.g. DPO). When ``tag`` is None we try common per-step tags in
    order and use the first with enough samples.

    Two kinds of regression we guard against:

    1. **Loss blew up.** A silent bug (e.g. broken label masking) can start
       training at an absurdly high loss. ``max_initial`` / ``max_final``
       assert the measured means stay at-or-below bounds measured from a
       known-good run. Both are optional but strongly encouraged — loss
       going *down* from a bad starting scale still looks like "learning."

    2. **Loss didn't go down enough.** ``max_loss_ratio`` (default 0.95)
       requires ``final <= initial * ratio``. A default below 1.0 means the
       final window mean must sit at least 5% below the initial window mean
       — real learning, not noise that happened to land below start. Only
       raise this for configs where a smaller drop is expected *and*
       documented (e.g. DPO with near-trivial pairs); in that case you are
       intentionally weakening the test.

    ``min_delta`` is optional; when set, additionally requires
    ``final + min_delta <= initial`` — use for configs with enough signal
    to demand a specific minimum absolute drop.
    """
    tb_log_path = most_recent_subdir(temp_run_dir)
    event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
    reader = SummaryReader(event_file)
    df = reader.scalars

    if tag is None:
        candidates = ["train/loss", "train/train_loss"]
    else:
        candidates = [tag]

    required = initial_window + final_window
    chosen_tag, values = None, None
    for candidate in candidates:
        sub = df[df.tag == candidate]
        if len(sub) >= required:
            chosen_tag = candidate
            values = sub.value.values
            break

    available = sorted({t for t in df.tag.unique() if "loss" in t.lower()})
    assert values is not None, (
        f"None of the tags {candidates} had ≥{required} logged steps. "
        f"Loss tags present: {available}"
    )

    initial = float(values[:initial_window].mean())
    final = float(values[-final_window:].mean())
    print(
        f"[check_tensorboard_loss_decreased] tag={chosen_tag} n={len(values)} "
        f"initial_mean{initial_window}={initial:.4f} final_mean{final_window}={final:.4f}"
    )
    assert final > 1e-5, "Expected loss to be greater than zero"
    assert final <= initial * max_loss_ratio, (
        f"Loss did not decrease for {chosen_tag}: "
        f"initial(mean of first {initial_window})={initial:.4f}, "
        f"final(mean of last {final_window})={final:.4f}, "
        f"ratio={final / initial:.4f} (max allowed {max_loss_ratio}). "
        f"Expected final <= initial — training did not learn."
    )
    if min_delta is not None:
        assert final + min_delta <= initial, (
            f"Expected loss to decrease by at least {min_delta} for {chosen_tag}: "
            f"initial={initial:.4f}, final={final:.4f}, delta={initial - final:.4f}"
        )
    if max_initial is not None:
        assert initial <= max_initial, (
            f"Initial loss {initial:.4f} is above the expected max {max_initial}. "
            f"Absolute scale is wrong — probably a silent regression "
            f"(e.g. bad label masking) that bumped the starting point."
        )
    if max_final is not None:
        assert final <= max_final, (
            f"Final loss {final:.4f} is above the expected max {max_final}. "
            f"Absolute scale is wrong — probably a silent regression "
            f"(e.g. bad label masking) that bumped the endpoint."
        )


def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
    """
    helper function to check if a model output file exists after training
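As a worked example of the ratio gate above (the loss values are invented for illustration, not taken from any run), the default `max_loss_ratio=0.95` behaves like this:

```python
# Worked example of the default ratio gate (max_loss_ratio=0.95).
# The loss values are illustrative only.
initial = 4.90  # mean of the first `initial_window` logged train/loss points
final = 4.55    # mean of the last `final_window` logged train/loss points

assert final > 1e-5              # loss must stay strictly positive
assert final <= initial * 0.95   # 4.55 <= 4.655, a ~7% drop, so this passes
# final = 4.70 would fail: 4.70 > 4.655, i.e. less than the required 5% drop
```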
160
tests/integrations/mora/test_mora.py
Normal file
@@ -0,0 +1,160 @@
"""Integration tests for the MoRA / ReMoRA adapter path."""

from types import SimpleNamespace
from unittest.mock import Mock

import pytest
import torch

from axolotl.integrations.base import PluginManager
from axolotl.integrations.mora import plugin as mora_plugin
from axolotl.loaders import adapter as adapter_module
from axolotl.loaders.adapter import load_adapter
from axolotl.utils.dict import DictDefault


class TestMoraAdapterLoading:
    """MoRA adapter selection and config wiring."""

    def test_load_adapter_uses_plugin_lora_like_registration(self, monkeypatch):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )

        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )

        calls = []

        def fake_load_lora(*args, **kwargs):
            calls.append((args, kwargs))
            return args[0], "adapter-config"

        monkeypatch.setattr(adapter_module, "load_lora", fake_load_lora)

        _, config = load_adapter(model, cfg, "mora")

        assert config == "adapter-config"
        assert calls[0][1]["config_only"] is False

    def test_mora_plugin_raises_when_peft_missing_support(self):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )

        with pytest.raises(ImportError, match="MoRA support"):
            load_adapter(model, cfg, "mora", config_only=True)

    def test_mora_plugin_rejects_quantized_base_model(self):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "load_in_4bit": True,
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )

        with pytest.raises(ValueError, match="full-precision base model"):
            load_adapter(model, cfg, "mora", config_only=True)

    def test_mora_plugin_builds_mora_config_when_supported(self, monkeypatch):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {
                    "use_mora": True,
                    "mora_type": "rope",
                },
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.0,
            }
        )

        captured = {}

        class FakeLoraConfig:
            def __init__(self, **kwargs):
                captured.update(kwargs)
                self.__dict__.update(kwargs)

        fake_model = SimpleNamespace(print_trainable_parameters=Mock())
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )
        monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
        monkeypatch.setattr(
            adapter_module, "get_peft_model", Mock(return_value=fake_model)
        )

        _, config = load_adapter(model, cfg, "mora", config_only=True)

        assert captured["use_mora"] is True
        assert captured["mora_type"] == 6
        assert captured["task_type"].name == "CAUSAL_LM"
        assert config is not None
        assert config.use_mora is True
        assert config.mora_type == 6

    def test_mora_plugin_uses_lora_model_dir_resume_path(self, monkeypatch):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
                "lora_model_dir": "adapter-checkpoint",
                "lora_on_cpu": False,
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.0,
            }
        )

        class FakeLoraConfig:
            def __init__(self, **kwargs):
                self.__dict__.update(kwargs)

        class FakePeftModel:
            def print_trainable_parameters(self):
                pass

            def named_parameters(self):
                return []

        from_pretrained = Mock(return_value=FakePeftModel())
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )
        monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
        monkeypatch.setattr(
            adapter_module.PeftModel,
            "from_pretrained",
            from_pretrained,
        )

        peft_model, config = load_adapter(model, cfg, "mora")

        assert isinstance(peft_model, FakePeftModel)
        assert config.use_mora is True
        from_pretrained.assert_called_once()
        assert from_pretrained.call_args.args[:2] == (model, "adapter-checkpoint")
        assert from_pretrained.call_args.kwargs["is_trainable"] is True
73
tests/integrations/test_adapter_plugin_registry.py
Normal file
@@ -0,0 +1,73 @@
"""Core adapter plugin registry tests."""

from unittest.mock import Mock

import pytest
import torch

from axolotl.integrations.base import AdapterCapabilities, BasePlugin, PluginManager
from axolotl.loaders import adapter as adapter_module
from axolotl.loaders.adapter import load_adapter
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


class FakeAdapterPlugin(BasePlugin):
    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
        return [AdapterCapabilities(name="fake-adapter", lora_like=True, relora=True)]

    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
        if cfg.adapter != "fake-adapter":
            return {}
        return {"fake_kwarg": "from-plugin"}


class TestAdapterPluginRegistry:
    def test_lora_like_plugin_adapter_contributes_peft_kwargs(self, monkeypatch):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "fake-adapter",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.0,
            }
        )
        PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()
        captured = {}

        class FakeLoraConfig:
            def __init__(self, **kwargs):
                captured.update(kwargs)
                self.__dict__.update(kwargs)

        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
        monkeypatch.setattr(adapter_module, "get_peft_model", Mock())

        _, config = load_adapter(model, cfg, "fake-adapter", config_only=True)

        assert config is not None
        assert captured["fake_kwarg"] == "from-plugin"
        assert captured["task_type"].name == "CAUSAL_LM"

    def test_unknown_adapter_error_mentions_plugin_registry(self):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault({"adapter": "missing-adapter"})

        with pytest.raises(NotImplementedError, match="registered by a plugin"):
            load_adapter(model, cfg, "missing-adapter")

    def test_relora_accepts_plugin_adapter_capability(self, min_base_cfg):
        PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "fake-adapter",
                "relora": True,
                "jagged_restart_steps": 100,
            }
        )

        validated = validate_config(cfg)

        assert validated.adapter == "fake-adapter"
        assert validated.relora is True
186
tests/monkeypatch/test_relora.py
Normal file
@@ -0,0 +1,186 @@
"""Unit tests for axolotl.monkeypatch.relora.reset_optimizer."""

import math

import pytest
import torch
import torch.nn as nn

from axolotl.monkeypatch.relora import (
    magnitude_pruning_,
    random_pruning_,
    reset_optimizer,
)

ADAM_KEYS = ["exp_avg", "exp_avg_sq"]


def _build_optimizer_with_state(seed: int = 0):
    """Build a tiny optimizer over LoRA-shaped + non-LoRA params with populated state."""
    torch.manual_seed(seed)
    lora_a = nn.Parameter(torch.randn(8, 32))
    lora_b = nn.Parameter(torch.randn(32, 8))
    extra = nn.Parameter(torch.randn(64, 32))

    optimizer = torch.optim.AdamW([lora_a, lora_b, extra], lr=1e-3)
    for _ in range(2):
        loss = (
            (lora_a * torch.randn_like(lora_a)).sum()
            + (lora_b * torch.randn_like(lora_b)).sum()
            + (extra * torch.randn_like(extra)).sum()
        )
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return optimizer, lora_a, lora_b, extra


def test_reset_optimizer_only_touches_reset_params():
    """State for params NOT in reset_params must be byte-identical after reset."""
    optimizer, lora_a, lora_b, extra = _build_optimizer_with_state()

    extra_avg_before = optimizer.state[extra]["exp_avg"].clone()
    extra_avg_sq_before = optimizer.state[extra]["exp_avg_sq"].clone()

    reset_optimizer(
        optimizer,
        reset_params=[lora_a, lora_b],
        optimizer_state_keys=ADAM_KEYS,
        prune_method="magnitude",
        prune_ratio=0.9,
    )

    assert torch.equal(optimizer.state[extra]["exp_avg"], extra_avg_before)
    assert torch.equal(optimizer.state[extra]["exp_avg_sq"], extra_avg_sq_before)


def test_reset_optimizer_actually_prunes_lora_state():
    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()

    reset_optimizer(
        optimizer,
        reset_params=[lora_a, lora_b],
        optimizer_state_keys=ADAM_KEYS,
        prune_method="magnitude",
        prune_ratio=0.9,
    )

    for param in (lora_a, lora_b):
        for key in ADAM_KEYS:
            zero_frac = (optimizer.state[param][key] == 0).float().mean().item()
            assert zero_frac >= 0.85


@pytest.mark.parametrize(
    "method,ratio,expected_zero_frac",
    [
        ("magnitude", 0.9, 0.9),
        ("magnitude", 0.99, 0.99),
        ("random", 0.9, 0.9),
        ("random", 0.5, 0.5),
        # reset uses random pruning; relora_prune_ratio must be honored, not ignored.
        ("reset", 0.9, 0.9),
        ("reset", 0.5, 0.5),
    ],
)
def test_prune_methods(method, ratio, expected_zero_frac):
    """Each method zeros approximately the expected fraction."""
    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state(seed=42)

    reset_optimizer(
        optimizer,
        reset_params=[lora_a, lora_b],
        optimizer_state_keys=ADAM_KEYS,
        prune_method=method,
        prune_ratio=ratio,
    )

    total = 0
    zeros = 0
    for param in (lora_a, lora_b):
        for key in ADAM_KEYS:
            tensor = optimizer.state[param][key]
            total += tensor.numel()
            zeros += (tensor == 0).sum().item()

    actual = zeros / total
    tolerance = 0.02 if method == "magnitude" else 0.05
    assert math.isclose(actual, expected_zero_frac, abs_tol=tolerance)


def test_reset_optimizer_skips_keys_not_in_state_keys():
    """Keys present in optimizer state but not in optimizer_state_keys are untouched."""
    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()

    exp_avg_sq_before = optimizer.state[lora_a]["exp_avg_sq"].clone()

    reset_optimizer(
        optimizer,
        reset_params=[lora_a, lora_b],
        optimizer_state_keys=["exp_avg"],
        prune_method="magnitude",
        prune_ratio=0.9,
    )

    assert torch.equal(optimizer.state[lora_a]["exp_avg_sq"], exp_avg_sq_before)


def test_reset_optimizer_handles_param_with_empty_state():
    """Params with no optimizer state are skipped silently."""
    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
    orphan = nn.Parameter(torch.randn(4, 4))

    reset_optimizer(
        optimizer,
        reset_params=[lora_a, lora_b, orphan],
        optimizer_state_keys=ADAM_KEYS,
        prune_method="magnitude",
        prune_ratio=0.9,
    )

    assert orphan not in optimizer.state or not optimizer.state[orphan]


def test_unknown_prune_method_raises():
    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()

    with pytest.raises(ValueError, match="Unknown prune_method"):
        reset_optimizer(
            optimizer,
            reset_params=[lora_a, lora_b],
            optimizer_state_keys=ADAM_KEYS,
            prune_method="bogus",  # type: ignore[arg-type]
            prune_ratio=0.9,
        )


def test_pruning_helpers_are_inplace():
    """magnitude_pruning_ and random_pruning_ must mutate via tensor.mul_."""
    tensor = torch.randn(64)
    ptr_before = tensor.data_ptr()
    magnitude_pruning_(tensor, 0.5)
    assert tensor.data_ptr() == ptr_before

    tensor = torch.randn(64)
    ptr_before = tensor.data_ptr()
    random_pruning_(tensor, 0.5)
    assert tensor.data_ptr() == ptr_before


def test_pruning_helpers_support_uint8_tensors():
    """Both pruning helpers must work on uint8 optimizer state tensors."""
    tensor = torch.arange(1, 129, dtype=torch.uint8)
    magnitude_pruning_(tensor, 0.9)

    assert tensor.dtype == torch.uint8
    magnitude_zero_frac = (tensor == 0).float().mean().item()
    assert 0.85 <= magnitude_zero_frac <= 0.95

    tensor = torch.arange(1, 129, dtype=torch.uint8)
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(1234)
        random_pruning_(tensor, 0.9)

    assert tensor.dtype == torch.uint8
    random_zero_frac = (tensor == 0).float().mean().item()
    assert 0.85 <= random_zero_frac <= 0.95
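For orientation, a minimal sketch of the in-place contract these tests pin down: same storage mutated via `mul_`, a `prune_ratio` fraction of smallest-magnitude entries zeroed, and uint8 state supported. This is an illustrative reimplementation under those assumptions, not the `axolotl.monkeypatch.relora` code itself:

```python
import torch


def magnitude_pruning_sketch_(tensor: torch.Tensor, prune_ratio: float) -> None:
    """Zero the `prune_ratio` fraction of smallest-|x| entries, in place."""
    # Work in float so integer (e.g. uint8) optimizer state is handled too.
    magnitudes = tensor.detach().abs().float()
    threshold = torch.quantile(magnitudes.flatten(), prune_ratio)
    # mul_ mutates the original storage, so data_ptr() stays unchanged,
    # which is exactly what test_pruning_helpers_are_inplace asserts.
    tensor.mul_((magnitudes > threshold).to(tensor.dtype))
```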
@@ -1,11 +1,5 @@
"""Tests for attn_implementation: input normalization, canonical-value
acceptance, capability flags, backend registration, and downstream validators.

Test classes are organized by feature concern, not by the layer of the schema
where the behavior is implemented (classmethod normalizer vs. field validator
vs. full `validate_config` pipeline). Each class covers a single contract end
to end, dropping into the lower layer only where it gives faster or sharper
coverage of an isolated branch.
"""Tests for attn_implementation: normalization, canonical-value acceptance,
capability flags, backend registration, and downstream validators.
"""

import logging

@@ -53,12 +47,7 @@ def _xformers_available():


class TestCapabilityTables:
    """Backend capability classification.

    Asserts both the static frozensets in `enums.py` and the `computed_field`
    properties on a validated config read consistently from those tables, and
    that user YAML cannot override the computed flags.
    """
    """Backend capability classification via frozensets and computed_field properties."""

    @pytest.mark.parametrize(
        "impl",

@@ -359,8 +348,6 @@ class TestGemma4HybridMode:
        assert result["attn_implementation"] == "flash_attention_2"

    def test_non_fa2_raises(self):
        """The hybrid path requires FA2 under the hood — any other backend is
        a configuration error."""
        with pytest.raises(
            ValueError, match="requires attn_implementation=flash_attention_2"
        ):

@@ -370,11 +357,7 @@ class TestGemma4HybridMode:


class TestSamplePackingValidation:
    """`sample_packing` requires a varlen-capable backend.

    Non-varlen backends (eager, sdpa) warn about cross-sample contamination;
    s2 raises outright because shifted-sparse attention has no varlen path.
    """
    """`sample_packing` warns for non-varlen backends; s2 raises outright."""

    def test_eager_warns(self, min_base_cfg, caplog):
        cfg = min_base_cfg | DictDefault(
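The truncated test above builds a config of roughly this shape (a sketch; only `attn_implementation` and `sample_packing` are load-bearing here, and the fixture's other keys are not shown):

```python
# Hypothetical shape of the config test_eager_warns constructs.
cfg = min_base_cfg | DictDefault(
    {
        "attn_implementation": "eager",  # non-varlen backend
        "sample_packing": True,          # expected to warn, not raise
    }
)
validate_config(cfg)  # s2 in the same position raises instead
```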
163
tests/test_mm_chat_collator.py
Normal file
@@ -0,0 +1,163 @@
"""
Regression tests for MultiModalChatDataCollator shape contracts.

Guard against the transformers 5.x breakage where apply_chat_template's
own `return_dict` parameter (default False) caused it to return the raw
input_ids tensor instead of the full BatchFeature dict, leading to
IndexError: too many indices for tensor of dimension 2
when downstream code did batch["input_ids"] on the resulting tensor.
"""

from unittest.mock import MagicMock, patch

import pytest
import torch
from transformers import BatchFeature


@pytest.fixture(name="mock_processor")
def fixture_mock_processor():
    """
    A mock processor whose apply_chat_template returns a BatchFeature
    when called with return_dict=True (the correct call convention),
    or a raw input_ids tensor when called without return_dict=True
    (the broken call convention that the bug introduced).
    """
    processor = MagicMock()
    processor.tokenizer = MagicMock()
    processor.tokenizer.pad_token_id = 0
    processor.image_token = "<|image|>"
    processor.tokenizer.convert_tokens_to_ids = MagicMock(return_value=128256)

    batch_size, seq_len = 2, 16
    input_ids = torch.ones(batch_size, seq_len, dtype=torch.long)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    batch_feature = BatchFeature(
        data={
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
    )

    def _apply_chat_template(*args, **kwargs):
        if kwargs.get("return_dict", False):
            return batch_feature
        # Simulate transformers 5.x default behaviour: returns out["input_ids"]
        return input_ids

    processor.apply_chat_template = MagicMock(side_effect=_apply_chat_template)
    processor.chat_template = None
    return processor


@pytest.fixture(name="mock_processing_strategy")
def fixture_mock_processing_strategy(mock_processor):
    from axolotl.processing_strategies import ProcessingStrategy

    strategy = ProcessingStrategy(processor=mock_processor)
    return strategy


class TestMultiModalChatDataCollatorShapeContract:
    """
    Verify that MultiModalChatDataCollator.process_rows returns a dict with
    2-D input_ids and labels, not a raw tensor. This is the shape contract
    that process_labels depends on.
    """

    def _make_collator(self, mock_processing_strategy):
        from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator

        tokenizer = mock_processing_strategy.processor.tokenizer
        return MultiModalChatDataCollator(
            tokenizer=tokenizer,
            processing_strategy=mock_processing_strategy,
        )

    def _make_examples(self):
        return [
            {
                "messages": [
                    {"role": "user", "content": "Hello"},
                    {"role": "assistant", "content": "Hi there"},
                ]
            }
        ]

    def test_process_rows_returns_dict(self, mock_processing_strategy):
        """batch must be a dict, not a raw tensor."""
        collator = self._make_collator(mock_processing_strategy)
        examples = self._make_examples()

        with patch.object(
            mock_processing_strategy,
            "__call__",
            return_value=examples,
        ):
            batch = collator.process_rows(examples)

        assert isinstance(batch, dict), (
            "process_rows must return a dict (BatchFeature), not a raw tensor. "
            "If it returns a tensor, apply_chat_template was called without "
            "return_dict=True at the top level."
        )

    def test_process_rows_input_ids_shape(self, mock_processing_strategy):
        """batch['input_ids'] must be a 2-D tensor (batch, seq_len)."""
        collator = self._make_collator(mock_processing_strategy)
        examples = self._make_examples()

        with patch.object(
            mock_processing_strategy,
            "__call__",
            return_value=examples,
        ):
            batch = collator.process_rows(examples)

        assert "input_ids" in batch
        assert isinstance(batch["input_ids"], torch.Tensor)
        assert batch["input_ids"].ndim == 2, (
            f"input_ids must be 2-D (batch, seq_len), got shape {batch['input_ids'].shape}"
        )

    def test_process_rows_labels_shape(self, mock_processing_strategy):
        """batch['labels'] must be a 2-D tensor matching input_ids shape."""
        collator = self._make_collator(mock_processing_strategy)
        examples = self._make_examples()

        with patch.object(
            mock_processing_strategy,
            "__call__",
            return_value=examples,
        ):
            batch = collator.process_rows(examples)

        assert "labels" in batch
        assert isinstance(batch["labels"], torch.Tensor)
        assert batch["labels"].ndim == 2
        assert batch["labels"].shape == batch["input_ids"].shape

    def test_apply_chat_template_called_with_return_dict_true(
        self, mock_processing_strategy
    ):
        """apply_chat_template must be called with return_dict=True as a keyword arg."""
        collator = self._make_collator(mock_processing_strategy)
        examples = self._make_examples()

        with patch.object(
            mock_processing_strategy,
            "__call__",
            return_value=examples,
        ):
            collator.process_rows(examples)

        call_kwargs = (
            mock_processing_strategy.processor.apply_chat_template.call_args.kwargs
        )
        assert call_kwargs.get("return_dict") is True, (
            "apply_chat_template must be called with return_dict=True as a top-level "
            "keyword argument (not inside processor_kwargs). In transformers 5.x, "
            "apply_chat_template has its own return_dict param (default False) that "
            "controls whether it returns the full BatchFeature or just input_ids."
        )
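The call convention the tests enforce, as a minimal sketch (the `processor` and `messages` names are placeholders; the keyword placement is the point):

```python
# return_dict=True must be a top-level keyword so apply_chat_template
# returns the full BatchFeature instead of a raw input_ids tensor.
batch = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,  # top-level, not tucked inside processor_kwargs
)
input_ids = batch["input_ids"]  # 2-D (batch, seq_len); a raw tensor here breaks indexing
```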
@@ -1,16 +1,8 @@
"""Guard the attn_implementation source-of-truth invariant.
"""Enforce attn_implementation as the single source of truth.

`cfg.attn_implementation` is the single source of truth for the attention
backend on the validated config. Legacy boolean flags (`flash_attention`,
`sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`,
`s2_attention`, `eager_attention`) are input-only deprecated aliases — they
are stripped from `data` by `normalize_attn_implementation` and must never be
read downstream.

This test greps `src/` and fails if it finds a `cfg.<legacy>_attention` read.
If you're here because this test failed, migrate the read site to
`cfg.attn_implementation` or one of the `attn_supports_packing /
attn_uses_flash_lib / attn_needs_dtype_cast` computed capability flags.
Fails if src/ contains a cfg.<legacy>_attention read. Migrate offending sites
to cfg.attn_implementation or the attn_supports_packing/attn_uses_flash_lib/
attn_needs_dtype_cast computed flags.
"""

from __future__ import annotations
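In practice the migration at a flagged read site looks roughly like this; the packing gate is a hypothetical example, and only the cfg field names come from the docstring above:

```python
# Before (now fails the grep): gating on a deprecated boolean alias.
#   if cfg.flash_attention and cfg.sample_packing:
#       ...

# After: read the canonical backend name, or better, a capability flag.
if cfg.attn_implementation == "flash_attention_2" and cfg.sample_packing:
    ...

# Backend-agnostic version using a computed capability flag.
if cfg.attn_supports_packing and cfg.sample_packing:
    ...
```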
1164
tests/test_processing_strategies.py
Normal file
File diff suppressed because it is too large
100
tests/utils/schemas/validation/mora/test_mora_validation.py
Normal file
@@ -0,0 +1,100 @@
"""Validation tests for the MoRA / ReMoRA integration."""

import pytest

from axolotl.integrations.mora import MoraType
from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault


class TestMoraValidation:
    """MoRA-specific config validation."""

    def test_mora_block_round_trips(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "mora",
                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
                "mora": {
                    "use_mora": True,
                    "mora_type": "rope",
                },
            }
        )

        prepare_plugins(cfg)
        validated = validate_config(cfg)

        assert validated.adapter == "mora"
        assert validated.mora.use_mora is True
        assert validated.mora.mora_type == MoraType.ROPE

    def test_mora_type_accepts_legacy_supported_numbers(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "mora",
                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
                "mora": {
                    "use_mora": True,
                    "mora_type": 1,
                },
            }
        )

        prepare_plugins(cfg)
        validated = validate_config(cfg)

        assert validated.mora.mora_type == MoraType.SHARING

    def test_mora_rejects_unsupported_variant_numbers(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "mora",
                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
                "mora": {
                    "use_mora": True,
                    "mora_type": 2,
                },
            }
        )

        prepare_plugins(cfg)
        with pytest.raises(ValueError, match="mora_type"):
            validate_config(cfg)

    def test_remora_uses_core_relora_fields(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "mora",
                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
                "relora": True,
                "jagged_restart_steps": 2000,
                "mora": {
                    "use_mora": True,
                    "mora_type": "rope",
                },
            }
        )

        prepare_plugins(cfg)
        validated = validate_config(cfg)

        assert validated.relora is True
        assert validated.jagged_restart_steps == 2000

    def test_remora_still_requires_core_restart_steps(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "mora",
                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
                "relora": True,
                "mora": {
                    "use_mora": True,
                    "mora_type": "rope",
                },
            }
        )

        prepare_plugins(cfg)
        with pytest.raises(ValueError, match="jagged_restart_steps"):
            validate_config(cfg)