Compare commits

..

73 Commits

Author SHA1 Message Date
Dan Saunders
8564961423 fix compile 2025-09-19 13:59:57 -04:00
Dan Saunders
ce21da9177 fix compile 2025-09-19 13:55:54 -04:00
Dan Saunders
b5dc58373f fix compile 2025-09-19 13:52:42 -04:00
Dan Saunders
7327144344 compile 2025-09-19 13:41:12 -04:00
Dan Saunders
fb11f696e9 bench sweep 2025-09-19 13:24:40 -04:00
Dan Saunders
105c817b0b default fix 2025-09-19 16:59:20 +00:00
Dan Saunders
64345e7707 recurse fix 2025-09-19 12:58:58 -04:00
Dan Saunders
0f8b921399 contig 2025-09-19 12:47:53 -04:00
Dan Saunders
336616d659 defaults 2025-09-19 16:45:39 +00:00
Dan Saunders
d2f1e23bcd fix 2025-09-19 12:45:18 -04:00
Dan Saunders
42aadc5069 bench fix 2025-09-19 12:34:08 -04:00
Dan Saunders
1e7302d30a bench fix 2025-09-19 12:20:35 -04:00
Dan Saunders
63544ce709 fix 2025-09-19 11:34:27 -04:00
Dan Saunders
3bfed0aac8 shared expert detection 2025-09-19 11:24:26 -04:00
Dan Saunders
bfc848f81d bits and pieces 2025-09-19 02:12:57 +00:00
Dan Saunders
abe1cad6bc another bench 2025-09-18 13:45:19 -04:00
Dan Saunders
354389caef torchtitan bench 2025-09-18 13:29:20 -04:00
Dan Saunders
efcd032fce yet another refactor 2025-09-18 13:03:28 -04:00
Dan Saunders
7500641601 yet another refactor 2025-09-18 12:47:15 -04:00
Dan Saunders
0295df5bca precompute fuse 2025-09-18 12:10:46 -04:00
Dan Saunders
b39ef54833 combine mult 2025-09-18 12:08:03 -04:00
Dan Saunders
ad4cd39bcd remove contig 2025-09-18 11:55:15 -04:00
Dan Saunders
5c197275ad inplace 2025-09-18 11:51:17 -04:00
Dan Saunders
19c91e3675 refactor 2025-09-18 11:44:21 -04:00
Dan Saunders
2a176e4923 fix 2025-09-18 11:29:33 -04:00
Dan Saunders
7d867de9b2 refactor 2025-09-18 11:23:15 -04:00
Dan Saunders
01b6792c2e refactor 2025-09-18 11:20:08 -04:00
Dan Saunders
bbf1f14ca4 dtype issues 2025-09-17 23:52:18 +00:00
Dan Saunders
c6878beb7d simplify 2025-09-17 19:15:34 -04:00
Dan Saunders
e62979d11d fix 2025-09-17 18:53:07 -04:00
Dan Saunders
d57b9c67c2 log 2025-09-17 18:52:27 -04:00
Dan Saunders
eaaf16aa00 cumulative offsets 2025-09-17 18:45:15 -04:00
Dan Saunders
f3b953e222 fix? 2025-09-17 18:42:10 -04:00
Dan Saunders
7935dc0911 dtype fix 2025-09-17 18:36:22 -04:00
Dan Saunders
d2b49b2670 error msg 2025-09-17 18:29:30 -04:00
Dan Saunders
b5cb345ca4 fix test 2025-09-17 18:24:00 -04:00
Dan Saunders
03d4c2683e fix perf degradation 2025-09-17 18:20:37 -04:00
Dan Saunders
fd87eed501 minify 2025-09-17 16:42:35 -04:00
Dan Saunders
129db67705 fix 2025-09-17 16:24:29 -04:00
Dan Saunders
38b890a36b fix 2025-09-17 16:16:41 -04:00
Dan Saunders
180920c7bf simplify 2025-09-17 19:49:18 +00:00
Dan Saunders
d024048d74 logs + fix 2025-09-17 14:50:49 -04:00
Dan Saunders
98dc945838 fix 2025-09-17 14:42:53 -04:00
Dan Saunders
108600cd69 update config 2025-09-17 14:36:24 -04:00
Dan Saunders
0e9387c395 fix 2025-09-17 14:35:36 -04:00
Dan Saunders
db61e0d4ff fix 2025-09-17 14:26:25 -04:00
Dan Saunders
51e565f60a logs 2025-09-17 14:15:51 -04:00
Dan Saunders
c774dd0409 refactor + fix 2025-09-17 14:01:39 -04:00
Dan Saunders
7289e0cb55 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
8d483c11f7 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
9c1829cf57 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
135b09d1de logs, qwen2 support 2025-09-17 13:44:26 -04:00
Dan Saunders
de4344a56e patch 2025-09-17 13:44:26 -04:00
Dan Saunders
7d572b58d1 just grouped_mm for now 2025-09-17 13:44:26 -04:00
Dan Saunders
773d7e4291 update 2025-09-17 13:44:26 -04:00
Dan Saunders
fef47a5b7c hardening 2025-09-17 13:44:26 -04:00
Dan Saunders
f6ed8ddc01 fix 2025-09-17 13:44:26 -04:00
Dan Saunders
556d6448fe fix 2025-09-17 13:44:26 -04:00
Dan Saunders
5c2229721d diag 2025-09-17 13:44:26 -04:00
Dan Saunders
d7de6b0e96 grouped_mm 2025-09-17 13:44:26 -04:00
Dan Saunders
3c6648678f numerics 2025-09-17 13:44:26 -04:00
Dan Saunders
5b19a1ea9c improve 2025-09-17 13:44:26 -04:00
Dan Saunders
cfefad1eea fix 2025-09-17 13:44:26 -04:00
Dan Saunders
125e7b5fe6 fast path 2025-09-17 13:44:26 -04:00
Dan Saunders
479b6144df tflops 2025-09-17 13:44:26 -04:00
Dan Saunders
68da65cba2 update 2025-09-17 13:44:26 -04:00
Dan Saunders
0d689bb421 cache, example 2025-09-17 13:44:26 -04:00
Dan Saunders
43ada1278a moe kernels init scaffold 2025-09-17 13:44:26 -04:00
Dan Saunders
4065bc14c6 Debug log, logging improvements (#3159)
* simplify logging

* remove comment

* progress on debug.log

* add debug-level logger for file log

* simplify

* case insensitivity; 3rd party logging improvements

* simplify

* fix

* tests

* lint

* nits

* nit

* Update tests/test_utils_tee.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* cleanup / comments

* fix

* oops

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2025-09-17 13:27:03 -04:00
salman
e5c427f6de qat doc updates (#3162) [skip-ci] 2025-09-17 10:38:15 +01:00
Wing Lian
86d6ee7c05 upgrade trl and accelerate (#3161)
* upgrade trl==0.23.0

* upgrade accelerate patch fix

* add hints when using gradient_checkpointing with DPO

* set gradient-checpointing properly
2025-09-16 14:53:01 -04:00
Wing Lian
d4cff1b7bb improve setting of NCCL_P2P_DISABLE on runpod (#3132) [skip ci]
* improve setting of NCCL_P2P_DISABLE on runpod

* use recs from review
2025-09-16 14:52:45 -04:00
Wing Lian
1ef6c196f7 setup env vars for ray train for FSDP (#3130) [skip ci] 2025-09-16 14:52:29 -04:00
48 changed files with 9180 additions and 170 deletions

3
.gitignore vendored
View File

@@ -190,3 +190,6 @@ out/
# vim
*.swp
# scm auto-versioning
src/axolotl/_version.py

View File

@@ -14,7 +14,7 @@ repos:
rev: v0.12.12
hooks:
- id: ruff
args: [--fix, --select, I]
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1

View File

@@ -285,6 +285,7 @@ website:
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/moe_backends.md
- docs/nd_parallelism.qmd
- section: "Troubleshooting"

18
docs/moe_backends.md Normal file
View File

@@ -0,0 +1,18 @@
MoE Backends in Axolotl
Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):
- Set `moe_backend: auto|torch_grouped|naive`
Behavior
- auto (default): prefers PyTorch 2.8+ grouped GEMM; otherwise naive.
- torch_grouped: targets PyTorch 2.8+ grouped GEMM (H100/SM90+ recommended).
- naive: keeps the reference per-expert loop.
Notes
- Current implementation wires the backend selector and routes Mixtral MoE through it. Torch grouped uses cuBLASLt grouped GEMM when available; otherwise, the code falls back to the naive per-expert loop.
- No changes to training scripts are required; selection happens inside the model forward.
Example
moe_backend: torch_grouped
accelerate launch -m axolotl.cli.train path/to/config.yaml

View File

@@ -23,10 +23,17 @@ To enable QAT in axolotl, add the following to your configuration file:
```yaml
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
```
We support the following quantization schemas:
- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
- `Int8DynamicActivationInt4Weight`
- `Float8DynamicActivationFloat8Weight`
- `Float8DynamicActivationInt4Weight`
- `NVFP4`
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this.

View File

@@ -22,8 +22,8 @@ Quantization is configured using the `quantization` key in your configuration fi
```yaml
base_model: # The path to the model to quantize.
quantization:
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
@@ -39,9 +39,8 @@ you used to train the model:
# qat.yml
qat:
activation_dtype: int8
weight_dtype: int8
weight_dtype: int4
group_size: 256
quantize_embedding: true
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
```

View File

@@ -251,10 +251,10 @@
},
"outputs": [],
"source": [
"from axolotl.utils import patch_optimized_env\n",
"from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
"\n",
"# speedup downloads from HF 🤗 and set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
"patch_optimized_env()"
"# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
"set_pytorch_cuda_alloc_conf()"
]
},
{

View File

@@ -0,0 +1,53 @@
base_model: Qwen/Qwen1.5-MoE-A2.7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
# Keep VRAM low
load_in_8bit: false
load_in_4bit: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/qwen2-moe-qlora-10gb
# Train small to fit 10GB
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 5
flash_attention: true
warmup_ratio: 0.03
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
model_config:
output_router_logits: true
special_tokens:

View File

@@ -32,7 +32,7 @@ line-length = 88
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "W", "C90", "B"]
select = ["E", "F", "W", "C90", "B", "I"]
ignore = [
"E203", # Whitespace before ':'
"E501", # Line too long

View File

@@ -15,10 +15,10 @@ huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.56.1
tokenizers>=0.21.1
accelerate==1.10.0
accelerate==1.10.1
datasets==4.0.0
deepspeed>=0.17.0
trl==0.21.0
trl==0.23.0
hf_xet==1.1.5
kernels==0.9.0
trackio

209
scripts/bench_moe.py Normal file
View File

@@ -0,0 +1,209 @@
#!/usr/bin/env python
"""Benchmark Hugging Face Qwen2 MoE block with and without grouped_mm."""
from __future__ import annotations
import argparse
import sys
import time
import weakref
from pathlib import Path
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover
tg = None
def bench(run, *, iters: int, warmup: int, sync: bool = True) -> float:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for _ in range(warmup):
run()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times = []
for _ in range(iters):
if sync and device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
run()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def load_hf_block(
hidden: int,
inter: int,
experts: int,
top_k: int,
*,
device: torch.device,
dtype: torch.dtype,
):
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
cfg = Qwen2MoeConfig(
hidden_size=hidden,
moe_intermediate_size=inter,
shared_expert_intermediate_size=inter,
num_experts=experts,
num_experts_per_tok=top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block.state_dict())
return block, block_grouped
def main() -> None:
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=32)
p.add_argument("--top_k", type=int, default=4)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
p.add_argument("--profile", action="store_true")
p.add_argument(
"--compile",
action="store_true",
help="Torch.compile both paths before benchmarking",
)
args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
block_naive, block_grouped = load_hf_block(
args.hidden,
args.inter,
args.experts,
args.top_k,
device=device,
dtype=dtype,
)
tokens = args.bsz * args.seq
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
f"experts={args.experts} top_k={args.top_k}"
)
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
# Optional torch.compile
run_grouped_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile naive failed ({exc}); using eager")
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp, block.gate, block.experts, block.top_k
)
return y
try:
run_grouped_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile grouped failed ({exc}); using eager")
run_grouped_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y
def run_grouped(block=block_grouped, data=x, impl=run_grouped_impl):
if impl is not None:
return impl(data)
if tg is None or not tg.available():
return torch.empty(0)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(data, block.gate, block.experts, block.top_k)
return y if y is not None else torch.empty(0)
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
)
with torch.no_grad():
y_ref = run_naive()
if tg is None or not tg.available():
print("torch_grouped\tN/A (unavailable)")
return
y_grouped = run_grouped()
if y_grouped.numel() == 0:
print("torch_grouped\tN/A (op not callable)")
return
t_grouped = bench(run_grouped, iters=args.iters, warmup=args.warmup)
tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped
print(
f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
diff = (y_ref.float() - y_grouped.float()).abs()
print(
"torch_grouped_check: "
f"max_abs={diff.max().item():.3e} mean_abs={diff.mean().item():.3e} "
f"rel_l2={(diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item():.3e}"
)
if args.profile:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof:
run_naive()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof:
run_grouped()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
if __name__ == "__main__":
main()

311
scripts/bench_moe_sweep.py Normal file
View File

@@ -0,0 +1,311 @@
#!/usr/bin/env python
"""Sweep grouped_mm vs naive performance for Qwen2 MoE block."""
from __future__ import annotations
import argparse
import csv
import sys
import time
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import List
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover
tg = None
def _parse_list(arg: str) -> List[int]:
return [int(v) for v in arg.split(",") if v]
def _bench(run, *, iters: int, warmup: int, device: torch.device) -> float:
for _ in range(warmup):
run()
if device.type == "cuda":
torch.cuda.synchronize()
times: List[float] = []
for _ in range(iters):
if device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
run()
if device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def _load_block(
hidden: int,
inter: int,
experts: int,
top_k: int,
*,
device: torch.device,
dtype: torch.dtype,
):
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
cfg = Qwen2MoeConfig(
hidden_size=hidden,
moe_intermediate_size=inter,
shared_expert_intermediate_size=inter,
num_experts=experts,
num_experts_per_tok=top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block.state_dict())
return block, block_grouped
@dataclass
class Result:
bsz: int
seq: int
hidden: int
inter: int
experts: int
top_k: int
dtype: str
naive_ms: float
grouped_ms: float
speedup: float
naive_tflops: float
grouped_tflops: float
max_abs: float
mean_abs: float
rel_l2: float
def main() -> None:
p = argparse.ArgumentParser(description="Grouped MoE sweep")
p.add_argument("--batch-sizes", default="4,8,16")
p.add_argument("--seq-lens", default="512,1024,2048")
p.add_argument("--hidden", default="2048,4096")
p.add_argument("--inter", default="5632,8192,14336")
p.add_argument("--experts", default="8,16,32")
p.add_argument("--top-k", default="1,2,4")
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5)
p.add_argument("--csv", type=Path, default=None)
p.add_argument("--compile", action="store_true")
args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
if tg is None or not tg.available():
print("torch_grouped unavailable; sweep aborted")
return
bs_list = _parse_list(args.batch_sizes)
seq_list = _parse_list(args.seq_lens)
hidden_list = _parse_list(args.hidden)
inter_list = _parse_list(args.inter)
expert_list = _parse_list(args.experts)
topk_list = _parse_list(args.top_k)
results: List[Result] = []
print(
"bsz\tseq\thidden\tinter\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
)
for bsz in bs_list:
for seq in seq_list:
tokens = bsz * seq
for hidden in hidden_list:
for inter in inter_list:
for experts in expert_list:
for top_k in topk_list:
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
block_naive, block_grouped = _load_block(
hidden,
inter,
experts,
top_k,
device=device,
dtype=dtype,
)
x = torch.randn(
bsz, seq, hidden, device=device, dtype=dtype
)
compiled_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile naive failed ({exc}); using eager"
)
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = (
weakref.ref(block)
) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp,
block.gate,
block.experts,
block.top_k,
)
return y
try:
compiled_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile grouped failed ({exc}); using eager"
)
compiled_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y
def run_grouped(
block=block_grouped, data=x, impl=compiled_impl
):
if impl is not None:
return impl(data)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
data,
block.gate,
block.experts,
block.top_k,
)
return y
naive_ms = _bench(
run_naive,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_naive = run_naive()
grouped_ms = _bench(
run_grouped,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
res = Result(
bsz,
seq,
hidden,
inter,
experts,
top_k,
args.dtype,
naive_ms,
grouped_ms,
naive_ms / grouped_ms,
_estimate_flops(tokens, hidden, inter, top_k)
/ ((naive_ms / 1000.0) * 1e12),
_estimate_flops(tokens, hidden, inter, top_k)
/ ((grouped_ms / 1000.0) * 1e12),
diff.max().item(),
diff.mean().item(),
(
(
diff.pow(2).sum()
/ (y_naive.float().pow(2).sum() + 1e-12)
)
.sqrt()
.item()
),
)
results.append(res)
print(
f"{bsz}\t{seq}\t{hidden}\t{inter}\t{experts}\t{top_k}\t{res.naive_ms:.2f}\t"
f"{res.grouped_ms:.2f}\t{res.speedup:.2f}\t{res.naive_tflops:.2f}\t"
f"{res.grouped_tflops:.2f}\t{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
)
if args.csv:
fieldnames = [
"bsz",
"seq",
"hidden",
"inter",
"experts",
"top_k",
"dtype",
"naive_ms",
"grouped_ms",
"speedup",
"naive_tflops",
"grouped_tflops",
"max_abs",
"mean_abs",
"rel_l2",
]
with args.csv.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for r in results:
writer.writerow(
{
"bsz": r.bsz,
"seq": r.seq,
"hidden": r.hidden,
"inter": r.inter,
"experts": r.experts,
"top_k": r.top_k,
"dtype": r.dtype,
"naive_ms": f"{r.naive_ms:.4f}",
"grouped_ms": f"{r.grouped_ms:.4f}",
"speedup": f"{r.speedup:.4f}",
"naive_tflops": f"{r.naive_tflops:.4f}",
"grouped_tflops": f"{r.grouped_tflops:.4f}",
"max_abs": f"{r.max_abs:.6e}",
"mean_abs": f"{r.mean_abs:.6e}",
"rel_l2": f"{r.rel_l2:.6e}",
}
)
if __name__ == "__main__":
import weakref
main()

View File

@@ -0,0 +1,205 @@
#!/usr/bin/env python
"""Benchmark Torchtitan MoE grouped vs naive expert execution."""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
import torch
# Ensure torchtitan is importable when running from the axolotl tree
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
if str(_TITAN_PATH) not in sys.path:
sys.path.append(str(_TITAN_PATH))
from torchtitan.models.moe import MoE, MoEArgs
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE microbenchmark")
p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=8)
p.add_argument("--top_k", type=int, default=2)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
p.add_argument("--init-std", type=float, default=0.02)
p.add_argument(
"--score-before",
action="store_true",
help="Apply routing scores before expert computation (default: after)",
)
p.add_argument(
"--score-func",
choices=["softmax", "sigmoid"],
default="softmax",
)
p.add_argument(
"--route-norm",
action="store_true",
help="Enable Torchtitan router normalization when using sigmoid scores.",
)
return p.parse_args()
def _map_dtype(arg: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[arg]
def _estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
# Two up projections + one down projection per expert/token combination.
return 6.0 * tokens * top_k * hidden * inter
def _prepare_module(
moe: MoE,
*,
device: torch.device,
dtype: torch.dtype,
) -> MoE:
moe = moe.to(device=device)
for param in moe.parameters():
param.data = param.data.to(dtype)
if param.grad is not None:
param.grad = None
buffers = dict(moe.named_buffers())
for name, buf in buffers.items():
if name == "tokens_per_expert":
moe._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
elif name == "expert_bias" and buf is not None:
moe._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
else:
moe._buffers[name] = buf.to(device=device, dtype=dtype)
moe.eval()
return moe
@torch.inference_mode()
def _forward_fn(module: MoE, x: torch.Tensor) -> torch.Tensor:
return module(x)
def _bench(fn, *, iters: int, warmup: int, sync: bool = True) -> float:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
for _ in range(warmup):
fn()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times = []
for _ in range(iters):
if sync and device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
fn()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def main() -> None:
args = _parse_args()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = _map_dtype(args.dtype)
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
moe_args_grouped = MoEArgs(
num_experts=args.experts,
num_shared_experts=0,
score_func=args.score_func,
route_norm=args.route_norm,
top_k=args.top_k,
use_grouped_mm=True,
score_before_experts=args.score_before,
load_balance_coeff=None,
)
moe_grouped = MoE(moe_args_grouped, dim=args.hidden, hidden_dim=args.inter)
moe_grouped.init_weights(args.init_std, buffer_device=device)
moe_args_naive = MoEArgs(
num_experts=args.experts,
num_shared_experts=0,
score_func=args.score_func,
route_norm=args.route_norm,
top_k=args.top_k,
use_grouped_mm=False,
score_before_experts=args.score_before,
load_balance_coeff=None,
)
moe_naive = MoE(moe_args_naive, dim=args.hidden, hidden_dim=args.inter)
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
tokens = args.bsz * args.seq
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} "
f"inter={args.inter} experts={args.experts} top_k={args.top_k}"
)
def run_naive():
return _forward_fn(moe_naive, x)
def run_grouped():
return _forward_fn(moe_grouped, x)
if hasattr(moe_naive, "tokens_per_expert"):
moe_naive.tokens_per_expert.zero_()
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
t_naive = _bench(run_naive, iters=args.iters, warmup=args.warmup)
flops = _estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
tflops_naive = flops / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t"
f"{tflops_naive:.2f} TFLOP/s"
)
y_naive = run_naive()
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
t_grouped = _bench(run_grouped, iters=args.iters, warmup=args.warmup)
tflops_grouped = flops / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped if t_grouped > 0 else float("nan")
print(
f"grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
print(
f"grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,328 @@
#!/usr/bin/env python
"""Sweep Torchtitan MoE grouped vs naive configurations and report performance."""
from __future__ import annotations
import argparse
import csv
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List
import torch
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
if str(_TITAN_PATH) not in sys.path:
sys.path.append(str(_TITAN_PATH))
from torchtitan.models.moe import MoE, MoEArgs
def _parse_int_list(value: str) -> List[int]:
return [int(v) for v in value.split(",") if v]
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
p.add_argument(
"--batch-sizes", default="4,8,16", help="Comma separated batch sizes"
)
p.add_argument(
"--seq-lens", default="1024,2048", help="Comma separated sequence lengths"
)
p.add_argument(
"--experts", default="8,16,32,64", help="Comma separated expert counts"
)
p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices")
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5)
p.add_argument("--init-std", type=float, default=0.02)
p.add_argument("--score-before", action="store_true")
p.add_argument("--score-func", choices=["softmax", "sigmoid"], default="softmax")
p.add_argument("--route-norm", action="store_true")
p.add_argument("--csv", type=Path, default=None, help="Optional CSV output path")
return p.parse_args()
def _map_dtype(arg: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[arg]
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def _prepare_module(module: MoE, *, device: torch.device, dtype: torch.dtype) -> MoE:
module = module.to(device=device)
for param in module.parameters():
param.data = param.data.to(dtype)
if param.grad is not None:
param.grad = None
for name, buf in module.named_buffers():
if name == "tokens_per_expert":
module._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
elif name == "expert_bias" and buf is not None:
module._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
else:
module._buffers[name] = buf.to(device=device, dtype=dtype)
module.eval()
return module
@torch.inference_mode()
def _forward(module: MoE, x: torch.Tensor) -> torch.Tensor:
return module(x)
def _bench(callable_, *, iters: int, warmup: int, device: torch.device) -> float:
for _ in range(warmup):
callable_()
if device.type == "cuda":
torch.cuda.synchronize()
timings: List[float] = []
for _ in range(iters):
if device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
callable_()
if device.type == "cuda":
torch.cuda.synchronize()
timings.append((time.perf_counter() - start) * 1000.0)
return sum(timings) / len(timings)
@dataclass
class SweepResult:
bsz: int
seq: int
experts: int
top_k: int
dtype: str
naive_ms: float
grouped_ms: float
speedup: float
naive_tflops: float
grouped_tflops: float
max_abs: float
mean_abs: float
rel_l2: float
def _run_case(
*,
bsz: int,
seq: int,
experts: int,
top_k: int,
hidden: int,
inter: int,
dtype: torch.dtype,
device: torch.device,
iters: int,
warmup: int,
init_std: float,
score_before: bool,
score_func: str,
route_norm: bool,
) -> SweepResult:
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
moe_args_grouped = MoEArgs(
num_experts=experts,
num_shared_experts=0,
score_func=score_func,
route_norm=route_norm,
top_k=top_k,
use_grouped_mm=True,
score_before_experts=score_before,
load_balance_coeff=None,
)
moe_grouped = MoE(moe_args_grouped, dim=hidden, hidden_dim=inter)
moe_grouped.init_weights(init_std, buffer_device=device)
moe_args_naive = MoEArgs(
num_experts=experts,
num_shared_experts=0,
score_func=score_func,
route_norm=route_norm,
top_k=top_k,
use_grouped_mm=False,
score_before_experts=score_before,
load_balance_coeff=None,
)
moe_naive = MoE(moe_args_naive, dim=hidden, hidden_dim=inter)
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
x = torch.randn(bsz, seq, hidden, device=device, dtype=dtype)
def run_naive():
if hasattr(moe_naive, "tokens_per_expert"):
moe_naive.tokens_per_expert.zero_()
return _forward(moe_naive, x)
def run_grouped():
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
return _forward(moe_grouped, x)
naive_ms = _bench(run_naive, iters=iters, warmup=warmup, device=device)
y_naive = run_naive()
grouped_ms = _bench(run_grouped, iters=iters, warmup=warmup, device=device)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
tokens = bsz * seq
flops = _estimate_flops(tokens, hidden, inter, top_k)
naive_tflops = flops / ((naive_ms / 1000.0) * 1e12)
grouped_tflops = flops / ((grouped_ms / 1000.0) * 1e12)
speedup = naive_ms / grouped_ms if grouped_ms > 0 else float("nan")
return SweepResult(
bsz=bsz,
seq=seq,
experts=experts,
top_k=top_k,
dtype=str(dtype),
naive_ms=naive_ms,
grouped_ms=grouped_ms,
speedup=speedup,
naive_tflops=naive_tflops,
grouped_tflops=grouped_tflops,
max_abs=max_abs,
mean_abs=mean_abs,
rel_l2=rel_l2,
)
def _print_header(
hidden: int, inter: int, dtype: torch.dtype, device: torch.device
) -> None:
print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
print(
"bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
)
def _print_result(res: SweepResult) -> None:
print(
f"{res.bsz}\t{res.seq}\t{res.experts}\t{res.top_k}\t"
f"{res.naive_ms:.2f}\t{res.grouped_ms:.2f}\t{res.speedup:.2f}\t"
f"{res.naive_tflops:.2f}\t{res.grouped_tflops:.2f}\t"
f"{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
)
def _write_csv(path: Path, results: Iterable[SweepResult]) -> None:
fieldnames = [
"batch_size",
"seq_len",
"experts",
"top_k",
"dtype",
"naive_ms",
"grouped_ms",
"speedup",
"naive_tflops",
"grouped_tflops",
"max_abs",
"mean_abs",
"rel_l2",
]
with path.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for r in results:
writer.writerow(
{
"batch_size": r.bsz,
"seq_len": r.seq,
"experts": r.experts,
"top_k": r.top_k,
"dtype": r.dtype,
"naive_ms": f"{r.naive_ms:.4f}",
"grouped_ms": f"{r.grouped_ms:.4f}",
"speedup": f"{r.speedup:.4f}",
"naive_tflops": f"{r.naive_tflops:.4f}",
"grouped_tflops": f"{r.grouped_tflops:.4f}",
"max_abs": f"{r.max_abs:.6e}",
"mean_abs": f"{r.mean_abs:.6e}",
"rel_l2": f"{r.rel_l2:.6e}",
}
)
def main() -> None:
args = _parse_args()
dtype = _map_dtype(args.dtype)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_sizes = _parse_int_list(args.batch_sizes)
seq_lens = _parse_int_list(args.seq_lens)
experts_list = _parse_int_list(args.experts)
top_ks = _parse_int_list(args.top_ks)
results: List[SweepResult] = []
_print_header(args.hidden, args.inter, dtype, device)
for bsz in batch_sizes:
for seq in seq_lens:
for experts in experts_list:
for top_k in top_ks:
try:
res = _run_case(
bsz=bsz,
seq=seq,
experts=experts,
top_k=top_k,
hidden=args.hidden,
inter=args.inter,
dtype=dtype,
device=device,
iters=args.iters,
warmup=args.warmup,
init_std=args.init_std,
score_before=args.score_before,
score_func=args.score_func,
route_norm=args.route_norm,
)
except RuntimeError as err:
print(
f"{bsz}\t{seq}\t{experts}\t{top_k}\tERROR: {err}",
file=sys.stderr,
)
continue
results.append(res)
_print_result(res)
if args.csv and results:
_write_csv(args.csv, results)
print(f"Wrote {len(results)} rows to {args.csv}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python
"""Inspect Qwen2 MoE expert implementations for grouped-mm debugging."""
from __future__ import annotations
import sys
from pathlib import Path
import torch
ROOT = Path(__file__).resolve().parents[2]
sys.path.extend(
[
str(ROOT / "transformers" / "src"),
str(ROOT / "src"),
]
)
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from axolotl.kernels.moe.torch_grouped import _iter_expert_impls
def main() -> None:
cfg = Qwen2MoeConfig(
hidden_size=4096,
moe_intermediate_size=14336,
shared_expert_intermediate_size=14336,
num_experts=32,
num_experts_per_tok=4,
)
block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
experts = block.experts
experts._ax_parent_block = block
impls = _iter_expert_impls(experts)
print(f"impl count: {len(impls)}")
for idx, impl in enumerate(impls[:8]):
has_gate = hasattr(impl, "gate_proj")
has_up = hasattr(impl, "up_proj")
print(
f"impl[{idx}] type={impl.__class__.__name__} has_gate={has_gate} has_up={has_up}"
)
if has_gate:
print(f" gate shape {tuple(impl.gate_proj.weight.shape)}")
print(f" up shape {tuple(impl.up_proj.weight.shape)}")
print(f" down shape {tuple(impl.down_proj.weight.shape)}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python
"""
Probe PyTorch for grouped GEMM operator names and namespaces.
Run: python scripts/probe_torch_grouped_ops.py
"""
import sys
def main():
try:
import torch
except Exception as e:
print("Failed to import torch:", e)
sys.exit(1)
print("torch version:", torch.__version__)
namespaces = [n for n in dir(torch.ops) if not n.startswith("_")]
print("ops namespaces:", namespaces)
found_any = False
for ns in namespaces:
obj = getattr(torch.ops, ns, None)
ops = []
if obj is not None:
try:
ops = dir(obj)
except Exception as e:
print(f"warning: failed to list ops for namespace {ns}: {e}")
cands = [
o
for o in ops
if ("group" in o.lower())
or ("mm_grouped" in o.lower())
or ("matmul_grouped" in o.lower())
or ("grouped" in o.lower())
]
if cands:
found_any = True
print(f"namespace {ns} candidates:", cands)
if not found_any:
print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.")
if __name__ == "__main__":
main()

View File

@@ -4,5 +4,7 @@ import os
from axolotl.logging_config import configure_logging
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
configure_logging()

View File

@@ -23,7 +23,8 @@ from axolotl.utils.config import (
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.tee import prepare_debug_log
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__)
@@ -227,8 +228,11 @@ def load_cfg(
},
)
# NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we
# have to wait for cfg.output to be resolved. We could call this earlier if we write
# to a temporary file, and then move it later.
prepare_debug_log(cfg)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)
setup_wandb_env_vars(cfg)
@@ -241,7 +245,6 @@ def load_cfg(
for k, v in cfg.items()
if v is not None
}
LOG.info(
"config:\n%s",
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),

View File

@@ -17,8 +17,6 @@ from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.cli.utils.diffusion import (
diffusion_inference,
launch_diffusion_gradio_ui,
render_html,
run_diffusion,
)
from axolotl.integrations.base import PluginManager
from axolotl.utils.chat_templates import get_chat_template_from_config

View File

@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
launch_training,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -44,7 +44,7 @@ def cli():
"""Axolotl CLI - Train and fine-tune large language models"""
print_axolotl_text_art()
load_dotenv()
patch_optimized_env()
set_pytorch_cuda_alloc_conf()
@cli.command()

View File

@@ -17,6 +17,7 @@ from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import prepare_optim_env
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
@@ -59,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
@@ -92,6 +92,7 @@ def ray_train_func(kwargs: dict):
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
# also renormalize the config now that TorchTrainer has spawned distributed workers
cfg = DictDefault(kwargs["cfg"])
prepare_optim_env(cfg)
normalize_config(cfg)
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import gradio as gr
import torch
from colorama import Fore, Style
from axolotl.integrations.diffusion import generate, resolve_mask_token_id

View File

@@ -435,7 +435,7 @@ class TrainerBuilderBase(abc.ABC):
# don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False
training_args_kwargs["activation_offloading"] = True
elif self.cfg.gradient_checkpointing:
elif self.cfg.gradient_checkpointing is not None:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)

View File

@@ -0,0 +1,3 @@
from .backends import MOEBackend, get_moe_backend_name
__all__ = ["get_moe_backend_name", "MOEBackend"]

View File

@@ -0,0 +1,47 @@
import warnings
from enum import Enum
class MOEBackend(str, Enum):
AUTO = "auto"
TORCH_GROUPED = "torch_grouped"
NAIVE = "naive"
def _probe_torch_grouped() -> bool:
try:
import torch # noqa: F401
# Prefer a simple version check; exact APIs may vary across 2.8+.
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
return ver >= (2, 8)
except Exception:
return False
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
"""
Resolve the desired MoE backend using, in order of precedence:
- explicit preferred argument (e.g., from config)
- auto detection
"""
choice = (preferred or "auto").lower()
try:
selected = MOEBackend(choice)
except ValueError:
warnings.warn(
f"Unknown moe backend '{choice}', falling back to auto", stacklevel=2
)
selected = MOEBackend.AUTO
if selected == MOEBackend.AUTO:
if _probe_torch_grouped():
return MOEBackend.TORCH_GROUPED
return MOEBackend.NAIVE
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
warnings.warn(
"torch_grouped requested but torch>=2.8 not detected; falling back to naive",
stacklevel=2,
)
return MOEBackend.NAIVE
return selected

View File

@@ -0,0 +1,371 @@
"""Minimal grouped GEMM fast path for MoE experts using PyTorch _grouped_mm."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
_LOGGER = logging.getLogger("axolotl.moe.grouped")
def available() -> bool:
try:
major, minor = map(int, torch.__version__.split("+")[0].split(".")[:2])
if (major, minor) < (2, 8):
return False
if not torch.cuda.is_available():
return False
sm, _ = torch.cuda.get_device_capability()
if sm < 9:
return False
return hasattr(torch.ops, "_grouped_mm")
except Exception:
return False
def _iter_expert_impls(
experts_module, visited: Optional[set[int]] = None
) -> List[torch.nn.Module]:
if visited is None:
visited = set()
module_id = id(experts_module)
if module_id in visited:
return []
visited.add(module_id)
impls: List[torch.nn.Module] = []
for exp in experts_module:
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
if hasattr(candidate, "gate_proj") and hasattr(candidate, "up_proj"):
impls.append(candidate)
continue
nested = getattr(candidate, "experts", None)
if nested is not None:
impls.extend(_iter_expert_impls(nested, visited))
continue
raise RuntimeError(
"torch_grouped: unable to resolve expert implementation for module"
)
return impls
@dataclass
class _GroupedWeightStorage:
pattern: str
gate: torch.Tensor
up: torch.Tensor
down: torch.Tensor
fused_gate_up: torch.Tensor
dtype: torch.dtype
device: torch.device
def _allocate_fused_gate_up(
num_experts: int,
gate_shape: torch.Size,
up_shape: torch.Size,
*,
device: torch.device,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if gate_shape[1] != up_shape[1]:
raise RuntimeError(
"torch_grouped: gate and up projections must share the hidden dimension"
)
fused = torch.empty(
(num_experts, gate_shape[0] + up_shape[0], gate_shape[1]),
device=device,
dtype=dtype,
)
gate_view = fused[:, : gate_shape[0]]
up_view = fused[:, gate_shape[0] : gate_shape[0] + up_shape[0]]
return fused, gate_view, up_view
def _ensure_grouped_weights(
experts_module, expert_impls: List[torch.nn.Module], sample_mod: torch.nn.Module
) -> _GroupedWeightStorage:
storage: Optional[_GroupedWeightStorage] = getattr(
experts_module, "_ax_grouped_storage", None
)
def _store(new_storage: _GroupedWeightStorage) -> _GroupedWeightStorage:
experts_module._ax_grouped_storage = new_storage
return new_storage
# Identify expert parameter layout
if (
hasattr(sample_mod, "w1")
and hasattr(sample_mod, "w3")
and hasattr(sample_mod, "w2")
):
pattern = "swi_glu"
num_experts = len(expert_impls)
w1_shape = sample_mod.w1.weight.shape
w3_shape = sample_mod.w3.weight.shape
w2_shape = sample_mod.w2.weight.shape
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == sample_mod.w1.weight.dtype
and storage.device == sample_mod.w1.weight.device
and storage.gate.shape[1:] == w1_shape
):
return storage
fused, gate, up = _allocate_fused_gate_up(
num_experts,
w1_shape,
w3_shape,
device=sample_mod.w1.weight.device,
dtype=sample_mod.w1.weight.dtype,
)
down = torch.empty(
(num_experts, *w2_shape),
device=sample_mod.w2.weight.device,
dtype=sample_mod.w2.weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate[idx].copy_(mod.w1.weight.detach())
up[idx].copy_(mod.w3.weight.detach())
down[idx].copy_(mod.w2.weight.detach())
mod.w1.weight.detach_()
mod.w1.weight.set_(gate[idx])
mod.w3.weight.detach_()
mod.w3.weight.set_(up[idx])
mod.w2.weight.detach_()
mod.w2.weight.set_(down[idx])
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=fused,
dtype=gate.dtype,
device=gate.device,
)
)
if hasattr(sample_mod, "gate_up_proj") and hasattr(sample_mod, "down_proj"):
pattern = "fused_gate_up"
gate_weight = sample_mod.gate_up_proj.weight
down_weight = sample_mod.down_proj.weight
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == gate_weight.dtype
and storage.device == gate_weight.device
and storage.gate.shape[1:]
== (gate_weight.shape[0] // 2, gate_weight.shape[1])
):
return storage
num_experts = len(expert_impls)
gate_full = torch.empty(
(num_experts, *gate_weight.shape),
device=gate_weight.device,
dtype=gate_weight.dtype,
)
down = torch.empty(
(num_experts, *down_weight.shape),
device=down_weight.device,
dtype=down_weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate_full[idx].copy_(mod.gate_up_proj.weight.detach())
down[idx].copy_(mod.down_proj.weight.detach())
mod.gate_up_proj.weight.detach_()
mod.gate_up_proj.weight.set_(gate_full[idx])
mod.down_proj.weight.detach_()
mod.down_proj.weight.set_(down[idx])
inter = gate_weight.shape[0] // 2
gate = gate_full[:, :inter]
up = gate_full[:, inter:]
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=gate_full,
dtype=gate.dtype,
device=gate.device,
)
)
if (
hasattr(sample_mod, "up_proj")
and hasattr(sample_mod, "gate_proj")
and hasattr(sample_mod, "down_proj")
):
pattern = "dual_proj"
up_weight = sample_mod.up_proj.weight
gate_weight = sample_mod.gate_proj.weight
down_weight = sample_mod.down_proj.weight
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == sample_mod.up_proj.weight.dtype
and storage.device == sample_mod.up_proj.weight.device
and storage.gate.shape[1:] == gate_weight.shape
):
return storage
num_experts = len(expert_impls)
fused, gate, up = _allocate_fused_gate_up(
num_experts,
gate_weight.shape,
up_weight.shape,
device=gate_weight.device,
dtype=gate_weight.dtype,
)
down = torch.empty(
(num_experts, *down_weight.shape),
device=down_weight.device,
dtype=down_weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate[idx].copy_(mod.gate_proj.weight.detach())
up[idx].copy_(mod.up_proj.weight.detach())
down[idx].copy_(mod.down_proj.weight.detach())
mod.up_proj.weight.detach_()
mod.up_proj.weight.set_(up[idx])
mod.gate_proj.weight.detach_()
mod.gate_proj.weight.set_(gate[idx])
mod.down_proj.weight.detach_()
mod.down_proj.weight.set_(down[idx])
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=fused,
dtype=gate.dtype,
device=gate.device,
)
)
raise RuntimeError(
"torch_grouped: unsupported expert module layout for grouped weights"
)
def moe_ffn_forward_grouped(
hidden_states: torch.Tensor,
gate_linear: torch.nn.Linear,
experts_module,
top_k: int,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if not available():
return None, None
bsz, seqlen, hdim = hidden_states.shape
tokens = bsz * seqlen
device = hidden_states.device
routing_dtype = gate_linear.weight.dtype
expert_dtype = hidden_states.dtype
if expert_dtype not in (torch.bfloat16, torch.float16):
_LOGGER.debug(
"torch_grouped: unsupported expert dtype %s; falling back to naive",
expert_dtype,
)
return None, None
parent_block = None
parent_ref = getattr(experts_module, "_ax_parent_block_ref", None)
if parent_ref is not None:
try:
parent_block = parent_ref()
except TypeError:
parent_block = None
expert_container = getattr(experts_module, "experts", experts_module)
expert_impls = _iter_expert_impls(expert_container)
sample_mod = expert_impls[0]
storage = _ensure_grouped_weights(expert_container, expert_impls, sample_mod)
w_gate = storage.gate
w_up = storage.up
w2 = storage.down
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
router_logits = gate_linear(x_flat.to(routing_dtype))
shared_out_flat: Optional[torch.Tensor] = None
shared_owner = parent_block if parent_block is not None else experts_module
if hasattr(shared_owner, "shared_expert"):
shared_expert = shared_owner.shared_expert
shared_out_flat = shared_expert(x_flat)
shared_out_flat = shared_out_flat.to(expert_dtype)
shared_gate = getattr(shared_owner, "shared_expert_gate", None)
if shared_gate is not None:
gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype))
gate_vals = torch.sigmoid(gate_input)
shared_out_flat.mul_(gate_vals.to(expert_dtype))
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
flat_idx = topk_idx.view(-1)
num_experts = len(expert_impls)
if flat_idx.numel() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
sorted_experts, perm = torch.sort(flat_idx)
assignments = torch.bincount(sorted_experts, minlength=num_experts)
if assignments.sum() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
routed_input = torch.gather(x_flat, 0, gather_index)
counts_i32 = assignments.to(device=device, dtype=torch.int32)
offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32)
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
routed_in = routed_input.to(mm_dtype)
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
w2_t = w2.transpose(-2, -1).to(mm_dtype)
routed_in = routed_in.contiguous()
w_gate_t = w_gate_t.contiguous()
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
torch.ops.aten.silu_(gate_out)
w_up_t = w_up_t.contiguous()
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
gate_out.mul_(up_out)
gate_out = gate_out.contiguous()
w2_t = w2_t.contiguous()
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
down_out.mul_(weights)
combined = torch.zeros_like(x_flat)
combined.scatter_add_(0, gather_index, down_out)
output = combined.view(bsz, seqlen, hdim)
if shared_out_flat is not None:
output = output + shared_out_flat.view(bsz, seqlen, hdim)
return output, router_logits

View File

@@ -12,6 +12,7 @@ import transformers
from transformers import PretrainedConfig, PreTrainedModel
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.moe_grouped import apply_grouped_to_moe_blocks
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
@@ -57,6 +58,8 @@ class PatchManager:
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
# Apply MoE grouped GEMM patches (cfg.moe_backend)
apply_grouped_to_moe_blocks(self.cfg)
self._apply_fp8_patches()
self._apply_flash_attention_peft_patches()
self._apply_gradient_checkpointing_patches()
@@ -269,6 +272,7 @@ class PatchManager:
self.cfg.model_config_type,
model_name=self.cfg.base_model,
has_remote_code=has_remote_code,
cfg=self.cfg,
)
if self.cfg.sample_packing:

View File

@@ -1,10 +1,7 @@
"""
Common logging module for axolotl
"""
"""Common logging module for axolotl."""
import logging
import os
import sys
from logging import Formatter, Logger, LogRecord
from logging.config import dictConfig
from typing import Any, Dict
@@ -17,9 +14,9 @@ DEFAULT_LOG_LEVEL = "WARNING"
class AxolotlOrWarnErrorFilter(logging.Filter):
"""
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL)
Allows axolotl.* at INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL)
Drops all other records (i.e. non-axolotl.INFO, DEBUG, etc. by default)
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at
INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records
(i.e. non-axolotl.INFO, DEBUG, etc. by default).
"""
def __init__(self, **kwargs):
@@ -52,13 +49,12 @@ class AxolotlOrWarnErrorFilter(logging.Filter):
class AxolotlLogger(Logger):
"""A Logger that automatically rejects non-axolotl INFOs."""
"""Logger that applies filtering to non-axolotl loggers."""
def __init__(self, name: str, level: int = logging.NOTSET):
super().__init__(name, level)
# set global filter on the logger itself
self.addFilter(AxolotlOrWarnErrorFilter())
if not name.startswith("axolotl"):
self.addFilter(AxolotlOrWarnErrorFilter())
class ColorfulFormatter(Formatter):
@@ -74,6 +70,7 @@ class ColorfulFormatter(Formatter):
def format(self, record):
record.rank = int(os.getenv("LOCAL_RANK", "0"))
record.rank_fmt = f" [RANK:{record.rank}]" if record.rank != 0 else ""
log_message = super().format(record)
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
@@ -87,32 +84,54 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
},
"colorful": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d]%(rank_fmt)s %(message)s",
},
"concise": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
},
"concise_color": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s]%(rank_fmt)s %(message)s",
},
},
"filters": {
"ax_or_warn": {
"()": "axolotl.logging_config.AxolotlOrWarnErrorFilter",
},
},
"filters": {},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "simple",
"filters": [],
"stream": sys.stdout,
"formatter": "concise",
"filters": ["ax_or_warn"],
"stream": "ext://sys.stdout",
},
"color_console": {
"class": "logging.StreamHandler",
"formatter": "colorful",
"filters": [],
"stream": sys.stdout,
"formatter": "concise_color",
"filters": ["ax_or_warn"],
"stream": "ext://sys.stdout",
},
"ax_file_only": {
"class": "logging.StreamHandler",
"level": "DEBUG",
"formatter": "simple",
"stream": "ext://axolotl.utils.tee.file_only_stream",
},
"root_file_only": {
"class": "logging.StreamHandler",
"level": "DEBUG",
"formatter": "simple",
"stream": "ext://axolotl.utils.tee.file_only_stream",
},
},
# log level will be superseded by the AxolotlLogger
"root": {
"handlers": ["console"],
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL),
"handlers": ["console", "root_file_only"],
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper(),
},
"loggers": {
"axolotl": {
"handlers": ["color_console"],
"handlers": ["color_console", "ax_file_only"],
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
"propagate": False,
},
@@ -123,9 +142,15 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
def configure_logging():
"""Configure with default logging"""
init() # Initialize colorama
dictConfig(DEFAULT_LOGGING_CONFIG)
logging.setLoggerClass(AxolotlLogger)
# set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
# Route Python warnings through logging so they reach file handlers
logging.captureWarnings(True)
# Set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
if "ACCELERATE_LOG_LEVEL" not in os.environ:
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv(
"LOG_LEVEL", DEFAULT_LOG_LEVEL
).upper()

View File

@@ -180,38 +180,6 @@ def get_state_dict(self, model, unwrap=True):
return state_dict
def cast_lora_module(module):
base_layer_dtype = module.base_layer.weight.dtype
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to(
module.base_layer.weight.dtype
)
for active_adapter in module.active_adapters:
if module.lora_A:
module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_B:
module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_embedding_A:
module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_embedding_B:
module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_magnitude_vector:
module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
@@ -227,37 +195,18 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
module.base_layer.bias.data = module.base_layer.bias.data.to(
module.base_layer.weight.dtype
)
fully_shard(module, **fsdp2_kwargs)
module.set_reshard_after_forward(False)
module.set_reshard_after_backward(False)
# for active_adapter in module.active_adapters:
# for adapter_name in [
# "lora_A",
# "lora_B",
# "lora_embedding_A",
# "lora_embedding_B",
# "lora_magnitude_vector",
# ]:
# adapter_module = getattr(module, adapter_name, None)
# # print(adapter_module, adapter_name)
# # torch.distributed.breakpoint()
# if not adapter_module:
# continue
# fsdp_adapter_module = fully_shard(adapter_module[active_adapter], **fsdp2_kwargs)
# # fsdp_adapter_module.unshard()
# fsdp_adapter_module.set_reshard_after_backward(False)
# fsdp_adapter_module.set_reshard_after_forward(False)
# torch.distributed.breakpoint()
# if module.lora_A:
# fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
# if module.lora_B:
# fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
# if module.lora_embedding_A:
# fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
# if module.lora_embedding_B:
# fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
# if module.lora_magnitude_vector:
# fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
for active_adapter in module.active_adapters:
if module.lora_A:
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
if module.lora_B:
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_A:
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_B:
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
if module.lora_magnitude_vector:
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
return log_bias_dtype_mismatch
@@ -371,26 +320,16 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
model.tie_weights()
is_peft_model = isinstance(model, PeftModel)
# TODO - this doesn't actually do anything
for name, module in model.named_children():
if name == "experts":
# torch.distributed.breakpoint()
for expert in module.children():
# torch.distributed.breakpoint()
print(f"expert: {expert}")
for lora_module in expert.children():
print(f"lora {lora_module}")
# torch.distributed.breakpoint()
cast_lora_module(lora_module)
_process_lora_module_for_fsdp(lora_module, fsdp2_kwargs)
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
log_bias_dtype_mismatch = False
if auto_wrap_policy is not None:
for module in get_module_children_bottom_up(model)[:-1]:
if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
# torch.distributed.breakpoint()
cast_lora_module(module)
# torch.distributed.breakpoint()
if is_peft_model and isinstance(module, LoraLayer):
module_log_bias_mismatch = _process_lora_module_for_fsdp(
module, fsdp2_kwargs
)
log_bias_dtype_mismatch |= module_log_bias_mismatch
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
fully_shard(module, **fsdp2_kwargs)
@@ -407,9 +346,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
)
# for module in model.named_modules():
# if "Lora" in
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# We re-register the buffers, as they may not be in the state_dict
for fqn, buffer_tensor in original_non_persistent_buffers.items():

View File

@@ -5,9 +5,14 @@ Patches to support multipack for mixtral
import torch
def patch_mixtral_moe_forward_zero3() -> None:
def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
import warnings
import torch.nn.functional as F
from axolotl.kernels.moe import backends as _moe_backends
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
def mlp_forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
hidden_states
@@ -21,21 +26,32 @@ def patch_mixtral_moe_forward_zero3() -> None:
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if (
backend == MOEBackend.TORCH_GROUPED
and not _moe_backends._probe_torch_grouped()
):
warnings.warn(
"torch_grouped selected but not available; falling back to naive",
stacklevel=2,
)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(
routing_weights, self.top_k, dim=-1, sorted=False
)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
topk_weight = topk_weight.to(hidden_states.dtype)
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states)
hidden_states_rep = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states_rep)
flat_topk_idx = topk_idx.view(-1)
for i in range(self.num_experts):
expert = self.experts[i]
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
sel = flat_topk_idx == i
if sel.any():
y[sel] = expert(hidden_states_rep[sel])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
@@ -46,4 +62,23 @@ def patch_mixtral_moe_forward_zero3() -> None:
)
MixtralBlockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
# Wrap forward to support optional torch_grouped backend via config
from axolotl.kernels.moe import torch_grouped as _tg
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if backend == MOEBackend.TORCH_GROUPED and _tg.available():
def moe_forward_grouped(self, hidden_states: torch.Tensor) -> torch.Tensor:
bsz, seqlen, hdim = hidden_states.shape
y, router_logits = _tg.moe_ffn_forward_grouped(
hidden_states, self.gate, self.experts, self.top_k
)
if y is None:
return moe_forward(self, hidden_states)
return y, router_logits
MixtralSparseMoeBlock.forward = moe_forward_grouped
else:
MixtralSparseMoeBlock.forward = moe_forward

View File

@@ -0,0 +1,133 @@
import logging
import weakref
from functools import wraps
import torch
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
_LOG = logging.getLogger("axolotl.moe.patch")
def _patch_block_forward(block_cls, grouped_fn):
"""Replace block_cls.forward with grouped_fn preserving signature."""
block_cls.forward = grouped_fn
def apply_grouped_to_moe_blocks(cfg=None) -> None:
"""
Attempt to patch all known MoE block classes to use the torch_grouped backend
when cfg.moe_backend resolves to 'torch_grouped' and the op is available.
Falls back to original forwards otherwise.
"""
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if backend != MOEBackend.TORCH_GROUPED:
_LOG.info(
f"moe_backend is '{backend}', not 'torch_grouped'; skipping grouped patches"
)
return
try:
from axolotl.kernels.moe import torch_grouped as _tg
except Exception:
_LOG.warning("torch_grouped backend import failed; skipping grouped patches")
return
if not _tg.available():
_LOG.warning(
"torch_grouped requested but unavailable (op smoke test failed); skipping grouped patches"
)
return
# Map of architecture key to (modeling module path, class name or list of class names)
model_mods = {
"mixtral": (
"transformers.models.mixtral.modeling_mixtral",
MOE_ARCH_BLOCK.get("mixtral"),
),
"qwen2_moe": (
"transformers.models.qwen2_moe.modeling_qwen2_moe",
MOE_ARCH_BLOCK.get("qwen2_moe"),
),
"qwen3_moe": (
"transformers.models.qwen3_moe.modeling_qwen3_moe",
MOE_ARCH_BLOCK.get("qwen3_moe"),
),
"jamba": (
"transformers.models.jamba.modeling_jamba",
MOE_ARCH_BLOCK.get("jamba"),
),
"deepseek_v2": (
"transformers.models.deepseek_v2.modeling_deepseek_v2",
MOE_ARCH_BLOCK.get("deepseek_v2"),
),
# Others may not follow standard paths; best-effort import
"dbrx": ("transformers.models.dbrx.modeling_dbrx", MOE_ARCH_BLOCK.get("dbrx")),
"jetmoe": (
"transformers.models.jetmoe.modeling_jetmoe",
MOE_ARCH_BLOCK.get("jetmoe"),
),
"gpt_oss": (
"transformers.models.gpt_oss.modeling_gpt_oss",
MOE_ARCH_BLOCK.get("gpt_oss"),
),
}
def make_grouped_forward(orig_forward):
@wraps(orig_forward)
def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
bsz, seqlen, hdim = hidden_states.shape
# expose parent block so grouped backend can access shared expert context
try:
self.experts._ax_parent_block_ref = weakref.ref(self)
except Exception:
pass
y, router_logits = _tg.moe_ffn_forward_grouped(
hidden_states, self.gate, self.experts, self.top_k
)
# One-time log per block instance indicating whether grouped engaged or fallback occurred
if not getattr(self, "_ax_grouped_wrapper_logged", False):
if y is None:
_LOG.warning(
"Grouped wrapper active but fell back to naive for %s",
self.__class__.__name__,
)
else:
_LOG.info(
f"Grouped wrapper engaged for {self.__class__.__name__} (top_k={self.top_k})"
)
self._ax_grouped_wrapper_logged = True
if y is None:
return orig_forward(self, hidden_states, *args, **kwargs)
return y, router_logits
return _grouped_forward
patched = 0
for key, (mod_path, cls_names) in model_mods.items():
if not cls_names:
continue
try:
import importlib
modeling = importlib.import_module(mod_path)
names = cls_names if isinstance(cls_names, list) else [cls_names]
for name in names:
if not hasattr(modeling, name):
continue
block_cls = getattr(modeling, name)
orig_forward = getattr(block_cls, "forward", None)
if orig_forward is None:
continue
_patch_block_forward(block_cls, make_grouped_forward(orig_forward))
patched += 1
_LOG.info(f"Patched MoE block for grouped GEMM: {mod_path}.{name}")
except Exception as e:
# Best effort; log and skip this entry
_LOG.warning(f"Skipping MoE patch for arch '{key}' ({mod_path}): {e}")
if patched == 0:
_LOG.warning(
"No MoE blocks patched for grouped GEMM; model may not use known MoE classes"
)
else:
_LOG.info(f"Grouped GEMM patches applied to {patched} MoE block class(es)")

View File

@@ -46,7 +46,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
def patch_for_multipack(model_type, model_name=None, has_remote_code=False, cfg=None):
if has_remote_code:
patch_remote(model_name)
elif hasattr(transformers, "modeling_flash_attention_utils"):
@@ -57,7 +57,7 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
patch_mixtral_moe_forward_zero3(cfg)
def patch_remote(model_name):

View File

@@ -41,7 +41,7 @@ def patch_evaluation_loop():
"""Patch the evaluation_loop method."""
# Check if already patched
if hasattr(Trainer, "_original_evaluation_loop"):
LOG.info("Trainer.evaluation_loop already patched")
LOG.debug("Trainer.evaluation_loop already patched")
return
# Check if the patterns exist
@@ -84,7 +84,7 @@ def patch_evaluation_loop():
)
exec(evaluation_loop_source, globals())
LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
LOG.debug("Patched Trainer.evaluation_loop with nanmean loss calculation")
Trainer.evaluation_loop = axolotl_evaluation_loop
@@ -135,5 +135,5 @@ def patch_maybe_log_save_evaluate():
)
exec(maybe_log_source, globals())
LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
LOG.debug("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate

View File

@@ -196,10 +196,11 @@ def execute_training(
)
)
LOG.info("Starting trainer...")
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
# if cfg.bf16:
# torch.set_default_dtype(torch.bfloat16)
LOG.info("Starting trainer...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
plugin_manager = PluginManager.get_instance()

View File

@@ -44,15 +44,6 @@ def set_pytorch_cuda_alloc_conf():
)
def patch_optimized_env():
"""
Patch environment variables to improve VRAM usage and increase download speed
"""
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
set_pytorch_cuda_alloc_conf()
def get_not_null(value, default=None):
"""
return the value if it's not None, otherwise return the default value

View File

@@ -2,6 +2,8 @@
utils to get GPU info for the current environment
"""
import os
import subprocess # nosec B404
from importlib.metadata import version
from accelerate.utils.environment import (
@@ -14,6 +16,8 @@ from packaging.version import Version, parse
def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
if not check_runpod_p2p_support():
return False
unsupported_devices = {"RTX 6000 Ada", "L40S"}
try:
device_names, device_count = get_gpu_info()
@@ -29,6 +33,39 @@ def check_cuda_p2p_ib_support():
return True
def check_runpod_p2p_support() -> bool:
if "RUNPOD_GPU_COUNT" not in os.environ:
return True
try:
gpu_count = int(os.environ.get("RUNPOD_GPU_COUNT", "1"))
except ValueError:
return True
if gpu_count >= 2:
# run `nvidia-smi topo -p2p n` and inspect the GPU0 row
try:
result = subprocess.run( # nosec B603 B607
["nvidia-smi", "topo", "-p2p", "n"],
check=True,
capture_output=True,
text=True,
timeout=5,
)
except (
subprocess.CalledProcessError,
FileNotFoundError,
subprocess.TimeoutExpired,
):
return True # fail-open if detection fails
output_lines = result.stdout.strip().split("\n")
# filter rows that start with "GPU0" (avoid header row)
gpu0_rows = [line for line in output_lines if line.lstrip().startswith("GPU0")]
if not gpu0_rows:
return True
# consider P2P supported if any OK is present in the GPU0 row
return "OK" in gpu0_rows[-1]
return True
def get_package_version(package: str) -> Version:
version_str = version(package)
return parse(version_str)

View File

@@ -2,7 +2,6 @@
import functools
import logging
import os
from axolotl.utils.distributed import is_main_process
@@ -40,10 +39,6 @@ class MultiProcessAdapter(logging.LoggerAdapter):
def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter:
if log_level is None:
log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None)
logger = logging.getLogger(name)
if log_level is not None:
logger.setLevel(log_level.upper())
logger.root.setLevel(log_level.upper())
logger.setLevel(logging.DEBUG)
return MultiProcessAdapter(logger, extra={})

View File

@@ -132,6 +132,14 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(),
)
moe_backend: Literal["auto", "torch_grouped", "naive"] | None = Field(
default=None,
json_schema_extra={
"description": "Mixture-of-Experts backend to use: 'auto', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
},
)
# Value is constrained by the Literal type; no normalization needed.
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = Field(

View File

@@ -1378,6 +1378,21 @@ class ComplexValidationMixin:
return self
def hint_gradient_checkpointing_dpo_lora_ddp(self):
if (
(self.gradient_checkpointing is True or self.gradient_checkpointing is None)
and self.capabilities
and self.capabilities.get("n_gpu", 1) > 1
and self.adapter in ("lora", "qlora")
and self.rl == RLType.DPO
and not self.fsdp
and not self.deepspeed
):
LOG.warning(
"gradient_checkpointing with DPO + DDP + LoRA is not recommended."
)
return self
class DistributedValidationMixin:
"""validation for distributed training."""

166
src/axolotl/utils/tee.py Normal file
View File

@@ -0,0 +1,166 @@
"""
Utilities for managing the debug log file and providing a file-only stream for logging
handlers.
"""
from __future__ import annotations
import io
import os
import sys
import threading
from pathlib import Path
from typing import TextIO, cast
_lock = threading.Lock()
_file_handle: io.TextIOWrapper | None = None
_log_path: str | None = None
_tee_installed: bool = False
_orig_stdout: TextIO | None = None
_orig_stderr: TextIO | None = None
class _FileOnlyWriter(io.TextIOBase):
"""A stream-like object that writes only to the tee file.
Before the file is prepared, writes are dropped (no-op).
"""
def write(self, s: str) -> int: # type: ignore[override]
with _lock:
if _file_handle is not None:
_file_handle.write(s)
return len(s)
return len(s)
def flush(self) -> None: # type: ignore[override]
with _lock:
if _file_handle is not None:
try:
_file_handle.flush()
except Exception:
pass
file_only_stream: io.TextIOBase = _FileOnlyWriter()
class _StreamTee(io.TextIOBase):
"""A minimal tee that mirrors writes to the debug log file.
Installed only after the debug log is prepared; no buffering.
"""
def __init__(self, stream: io.TextIOBase):
self._stream = stream
def write(self, s: str) -> int: # type: ignore[override]
with _lock:
n = self._stream.write(s)
if _file_handle is not None:
_file_handle.write(s)
return n
def flush(self) -> None: # type: ignore[override]
with _lock:
self._stream.flush()
if _file_handle is not None:
try:
_file_handle.flush()
except Exception:
pass
@property
def encoding(self): # type: ignore[override]
return getattr(self._stream, "encoding", None)
@property
def errors(self): # type: ignore[override]
return getattr(self._stream, "errors", None)
def isatty(self): # type: ignore[override]
return getattr(self._stream, "isatty", lambda: False)()
def fileno(self): # type: ignore[override]
if hasattr(self._stream, "fileno"):
return self._stream.fileno()
raise OSError("Underlying stream has no fileno")
def prepare_debug_log(cfg, filename: str = "debug.log") -> str:
"""
Prepare the debug log.
Creates the output directory, handles append/truncate logic based on cfg, and opens
the debug log file for subsequent writes via file-only handlers.
"""
global _file_handle, _log_path, _tee_installed
with _lock:
# If already initialized, reuse existing path
if _log_path is not None:
return _log_path
output_dir = cfg.output_dir
os.makedirs(output_dir, exist_ok=True)
log_path = Path(output_dir) / filename
append = bool(
cfg.get("resume_from_checkpoint") or cfg.get("auto_resume_from_checkpoints")
)
if not append and log_path.exists():
log_path.unlink()
fh = open(log_path, "a", encoding="utf-8")
fh.flush()
_file_handle = fh
_log_path = str(log_path)
# Install a tee so stdout/stderr are mirrored to the debug file
# Allow disabling via env for testing or advanced usage.
tee_enabled = os.getenv("AXOLOTL_TEE_STDOUT", "1").lower() not in {
"0",
"false",
"no",
}
if tee_enabled and not _tee_installed:
# Save originals so we can restore later (e.g., tests)
global _orig_stdout, _orig_stderr
_orig_stdout = sys.stdout
_orig_stderr = sys.stderr
sys.stdout = _StreamTee(cast(io.TextIOBase, sys.stdout))
sys.stderr = _StreamTee(cast(io.TextIOBase, sys.stderr))
_tee_installed = True
return _log_path
def close_debug_log() -> None:
"""Flush and close the debug log and uninstall the stdout/stderr tee.
Safe to call even if not initialized.
"""
global _file_handle, _log_path, _tee_installed, _orig_stdout, _orig_stderr
with _lock:
# Restore original stdout/stderr if we installed a tee
if _tee_installed:
if _orig_stdout is not None:
sys.stdout = _orig_stdout
if _orig_stderr is not None:
sys.stderr = _orig_stderr
_tee_installed = False
_orig_stdout = None
_orig_stderr = None
# Close the file handle if open
if _file_handle is not None:
try:
_file_handle.flush()
_file_handle.close()
except Exception:
pass
finally:
_file_handle = None
_log_path = None

View File

@@ -31,6 +31,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | No
if checkpoints:
last_checkpoint = str(checkpoints[-1])
if not update:
LOG.info(f"Resuming from last checkpoint at {last_checkpoint}")
return last_checkpoint
if (
@@ -40,6 +41,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | No
):
cfg.resume_from_checkpoint = last_checkpoint
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
"Using auto-resume functionality to resume from checkpoint at "
f"{cfg.resume_from_checkpoint}"
)
return cfg.resume_from_checkpoint

View File

@@ -655,15 +655,6 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
def prepare_opinionated_env(cfg):
if cfg.qlora_sharded_model_loading:
# model loading is forked after the tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.sample_packing:
# multipack parallel packing sampler defaults to using fork
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def setup_trainer(
cfg,
train_dataset,

View File

@@ -199,7 +199,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"gradient_checkpointing": False,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0,
@@ -278,7 +278,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"gradient_checkpointing": False,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0,

View File

@@ -0,0 +1,258 @@
import sys
import types
import torch
import torch.nn as nn
from axolotl.kernels.moe import (
backends as moe_backends,
torch_grouped as torch_grouped_module,
)
from axolotl.monkeypatch import moe_grouped
class DummyExperts(nn.Module):
def __init__(self, layers):
super().__init__()
self.layers = nn.ModuleList(layers)
self.num_experts = len(layers)
def __getitem__(self, idx):
return self.layers[idx]
class DummyQwenMLP(nn.Module):
def __init__(self, idx: int, hidden: int, intermediate: int):
super().__init__()
self.gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)
self.down_proj = nn.Linear(intermediate, hidden, bias=False)
nn.init.constant_(self.gate_up_proj.weight, float(idx + 1))
nn.init.constant_(self.down_proj.weight, float((idx + 1) * 10))
class DummyQwenExpert(nn.Module):
def __init__(self, idx: int, hidden: int, intermediate: int):
super().__init__()
self.mlp = DummyQwenMLP(idx, hidden, intermediate)
def _make_transformers_stub(monkeypatch, block_cls):
# ensure we start from the original forward for each test
if block_cls is DummyMixtralBlock:
DummyMixtralBlock.forward = _DUMMY_MIXTRAL_ORIG_FORWARD
transformers_mod = types.ModuleType("transformers")
models_mod = types.ModuleType("transformers.models")
mixtral_mod = types.ModuleType("transformers.models.mixtral")
modeling_mixtral = types.ModuleType("transformers.models.mixtral.modeling_mixtral")
modeling_mixtral.MixtralSparseMoeBlock = block_cls
transformers_mod.models = models_mod
models_mod.mixtral = mixtral_mod
mixtral_mod.modeling_mixtral = modeling_mixtral
monkeypatch.setitem(sys.modules, "transformers", transformers_mod)
monkeypatch.setitem(sys.modules, "transformers.models", models_mod)
monkeypatch.setitem(sys.modules, "transformers.models.mixtral", mixtral_mod)
monkeypatch.setitem(
sys.modules,
"transformers.models.mixtral.modeling_mixtral",
modeling_mixtral,
)
def test_grouped_uses_per_expert_nested_modules(monkeypatch):
hidden = 4
intermediate = 2
num_experts = 2
experts = DummyExperts(
[DummyQwenExpert(i, hidden, intermediate) for i in range(num_experts)]
)
gate = nn.Linear(hidden, num_experts, bias=False)
nn.init.zeros_(gate.weight)
captured = []
def fake_grouped_mm(As, Bs, dtype):
captured.append([b.detach().clone() for b in Bs])
return [
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
for a, b in zip(As, Bs, strict=False)
]
monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
hidden_states = torch.randn(1, 2, hidden)
y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
hidden_states, gate, experts, top_k=2
)
assert y is not None
assert router_logits is not None
assert captured, "Grouped GEMM path should have been invoked"
first_call = captured[0]
expected0 = experts[0].mlp.gate_up_proj.weight.t()
expected1 = experts[1].mlp.gate_up_proj.weight.t()
assert torch.equal(first_call[0], expected0)
assert torch.equal(first_call[1], expected1)
assert not torch.equal(first_call[0], first_call[1])
def test_grouped_accepts_module_list_experts(monkeypatch):
hidden = 4
intermediate = 2
experts = nn.ModuleList(
[DummyQwenExpert(i, hidden, intermediate) for i in range(2)]
)
gate = nn.Linear(hidden, len(experts), bias=False)
nn.init.zeros_(gate.weight)
calls = {"count": 0}
def fake_grouped_mm(As, Bs, dtype):
calls["count"] += 1
return [
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
for a, b in zip(As, Bs, strict=False)
]
monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
hidden_states = torch.randn(1, 2, hidden)
y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
hidden_states, gate, experts, top_k=2
)
assert y is not None
assert router_logits is not None
assert calls["count"] > 0
class _DummyCfg:
moe_backend = "torch_grouped"
class DummyMixtralBlock(nn.Module):
def __init__(self):
super().__init__()
self.top_k = 1
self.gate = lambda x: x
self.experts = object()
self._calls = []
def forward(self, hidden_states: torch.Tensor, attention_mask=None):
self._calls.append((hidden_states, attention_mask))
tokens = hidden_states.shape[0] * hidden_states.shape[1]
router = torch.ones(
tokens, 2, device=hidden_states.device, dtype=hidden_states.dtype
)
return hidden_states + 5, router
_DUMMY_MIXTRAL_ORIG_FORWARD = DummyMixtralBlock.forward
def test_apply_grouped_forward_handles_args(monkeypatch):
_make_transformers_stub(monkeypatch, DummyMixtralBlock)
import axolotl.common.architectures as arch
original_map = arch.MOE_ARCH_BLOCK.copy()
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
for key in list(original_map.keys()):
if key != "mixtral":
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
monkeypatch.setattr(
moe_grouped,
"get_moe_backend_name",
lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
)
results = {}
def fake_grouped_forward(hidden_states, gate, experts, top_k):
results["called"] = True
router = torch.zeros(
hidden_states.shape[0] * hidden_states.shape[1],
2,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
return hidden_states + 1, router
monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
monkeypatch.setattr(
torch_grouped_module,
"moe_ffn_forward_grouped",
fake_grouped_forward,
)
cfg = _DummyCfg()
moe_grouped.apply_grouped_to_moe_blocks(cfg)
block = DummyMixtralBlock()
hidden_states = torch.ones(1, 2, 3)
mask = torch.zeros(1, 2)
out, router = block.forward(hidden_states, attention_mask=mask)
assert results.get("called") is True
assert torch.equal(out, hidden_states + 1)
assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
def test_apply_grouped_forward_fallback(monkeypatch):
_make_transformers_stub(monkeypatch, DummyMixtralBlock)
import axolotl.common.architectures as arch
original_map = arch.MOE_ARCH_BLOCK.copy()
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
for key in list(original_map.keys()):
if key != "mixtral":
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
monkeypatch.setattr(
moe_grouped,
"get_moe_backend_name",
lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
)
monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
monkeypatch.setattr(
torch_grouped_module,
"moe_ffn_forward_grouped",
lambda *args, **kwargs: (None, None),
)
cfg = _DummyCfg()
moe_grouped.apply_grouped_to_moe_blocks(cfg)
block = DummyMixtralBlock()
hidden_states = torch.ones(1, 2, 3)
mask = torch.zeros(1, 2)
out, router = block.forward(hidden_states, attention_mask=mask)
assert torch.equal(out, hidden_states + 5)
assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
assert block._calls, "Original forward should have been invoked"
call_hidden, call_mask = block._calls[-1]
assert torch.equal(call_hidden, hidden_states)
assert torch.equal(call_mask, mask)
def test_get_moe_backend_name_prefers_probe(monkeypatch):
monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: True)
assert moe_backends.get_moe_backend_name() == moe_backends.MOEBackend.TORCH_GROUPED
def test_get_moe_backend_name_falls_back(monkeypatch):
warnings_captured = []
def fake_warn(msg, *, stacklevel=None): # noqa: ARG001
warnings_captured.append(msg)
monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: False)
monkeypatch.setattr(moe_backends.warnings, "warn", fake_warn)
backend = moe_backends.get_moe_backend_name("torch_grouped")
assert backend == moe_backends.MOEBackend.NAIVE
assert warnings_captured, "Expected warning when torch_grouped unavailable"

View File

@@ -0,0 +1,103 @@
import logging
import tempfile
import pytest
def read(path: str) -> str:
with open(path, "r", encoding="utf-8") as f:
return f.read()
@pytest.fixture(autouse=True)
def _reset_logging_state():
# Ensure a clean slate for logging between tests
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.shutdown()
# Note: dictConfig in configure_logging will set up handlers again
yield
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.shutdown()
def test_axolotl_logs_captured_at_all_levels(monkeypatch):
from axolotl.logging_config import configure_logging
from axolotl.utils import tee
from axolotl.utils.logging import get_logger
with tempfile.TemporaryDirectory() as td:
# Avoid stdout tee in this test to simplify interaction with pytest capture
monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
configure_logging()
path = tee.prepare_debug_log(
type("Cfg", (), {"output_dir": td, "get": lambda *_: False})
)
log = get_logger("axolotl.test")
log.info("AX-INFO")
log.debug("AX-DEBUG")
tee.file_only_stream.flush()
data = read(path)
assert "AX-INFO" in data
assert "AX-DEBUG" in data
tee.close_debug_log()
def test_third_party_logs_filtered_and_warning_captured(monkeypatch):
from axolotl.logging_config import configure_logging
from axolotl.utils import tee
with tempfile.TemporaryDirectory() as td:
monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
configure_logging()
path = tee.prepare_debug_log(
type("Cfg", (), {"output_dir": td, "get": lambda *_: False})
)
# Third-party logger (non-axolotl)
other = logging.getLogger("thirdparty.lib")
other.info("TP-INFO")
other.warning("TP-WARN")
# Simulate Python warnings routed through logging
logging.getLogger("py.warnings").warning("PY-WARN")
# Push through buffers
tee.file_only_stream.flush()
data = read(path)
# INFO from non-axolotl should be filtered out (not present)
assert "TP-INFO" not in data
# WARNING+ should be present
assert "TP-WARN" in data
# Python warnings captured (via py.warnings logger)
assert "PY-WARN" in data
tee.close_debug_log()
tee.close_debug_log()
def test_prepare_debug_log_idempotent_and_no_duplicate(monkeypatch):
from axolotl.logging_config import configure_logging
from axolotl.utils import tee
from axolotl.utils.logging import get_logger
with tempfile.TemporaryDirectory() as td:
monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
configure_logging()
cfg = type("Cfg", (), {"output_dir": td, "get": lambda *_: False})
p1 = tee.prepare_debug_log(cfg)
p2 = tee.prepare_debug_log(cfg)
assert p1 == p2
log = get_logger("axolotl.test")
marker = "UNIQUE-MARKER-12345"
log.info(marker)
tee.file_only_stream.flush()
data = read(p1)
# Ensure the marker appears once (not duplicated via propagation)
assert data.count(marker) == 1
tee.close_debug_log()

107
tests/test_utils_tee.py Normal file
View File

@@ -0,0 +1,107 @@
import os
import tempfile
def _dummy_cfg(output_dir: str, append: bool = False):
# Minimal object with attributes used by prepare_debug_log
class Cfg:
def __init__(self, out, append):
self.output_dir = out
self._append = append
def get(self, key, default=None):
if key in {"resume_from_checkpoint", "auto_resume_from_checkpoints"}:
return self._append
return default
return Cfg(output_dir, append)
def read(path: str) -> str:
with open(path, "r", encoding="utf-8") as f:
return f.read()
def test_file_only_stream_writes_after_prepare(monkeypatch):
from axolotl.utils import tee
with tempfile.TemporaryDirectory() as td:
# Avoid stdout tee in this test
monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
cfg = _dummy_cfg(td, append=False)
# before prepare: writing to file_only_stream creates no file
tee.file_only_stream.write("before\n")
tee.file_only_stream.flush()
assert not os.path.exists(os.path.join(td, "debug.log"))
# prepare and write
path = tee.prepare_debug_log(cfg)
assert os.path.basename(path) == "debug.log"
tee.file_only_stream.write("hello\n")
tee.file_only_stream.flush()
content = read(path)
assert "hello" in content
tee.close_debug_log()
def test_stdout_is_mirrored_after_prepare(capsys, monkeypatch):
from axolotl.utils import tee
with tempfile.TemporaryDirectory() as td:
cfg = _dummy_cfg(td, append=False)
try:
# Install tee while capture is disabled so stdout tee wraps real stdout.
with capsys.disabled():
monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "1")
path = tee.prepare_debug_log(cfg)
import sys
print("printed-line")
sys.stdout.flush()
# Now verify file contains the line
content = read(path)
assert "printed-line" in content
finally:
tee.close_debug_log()
def test_truncate_vs_append_behavior(monkeypatch):
from axolotl.utils import tee
with tempfile.TemporaryDirectory() as td:
# Avoid stdout tee in this test
monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
# First run creates file with A
cfg = _dummy_cfg(td, append=False)
_ = tee.prepare_debug_log(cfg)
try:
tee.file_only_stream.write("A\n")
tee.file_only_stream.flush()
finally:
tee.close_debug_log()
# Second run with append=False truncates
cfg2 = _dummy_cfg(td, append=False)
path2 = tee.prepare_debug_log(cfg2)
try:
tee.file_only_stream.write("B\n")
tee.file_only_stream.flush()
content = read(path2)
assert "A\n" not in content and "B\n" in content
finally:
tee.close_debug_log()
# Third run with append=True preserves existing
cfg3 = _dummy_cfg(td, append=True)
path3 = tee.prepare_debug_log(cfg3)
try:
tee.file_only_stream.write("C\n")
tee.file_only_stream.flush()
content = read(path3)
assert "B\n" in content and "C\n" in content
finally:
tee.close_debug_log()

6545
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff