From 0d689bb4215b87037bc976aad7433774734f2934 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Mon, 15 Sep 2025 15:22:11 -0400
Subject: [PATCH] cache, example

---
 examples/moe/qwen2-moe-qlora-10gb.yaml | 57 ++++++++++++++++++++++++++
 src/axolotl/kernels/moe/hf_triton.py   | 16 +++++++-
 2 files changed, 72 insertions(+), 1 deletion(-)
 create mode 100644 examples/moe/qwen2-moe-qlora-10gb.yaml

diff --git a/examples/moe/qwen2-moe-qlora-10gb.yaml b/examples/moe/qwen2-moe-qlora-10gb.yaml
new file mode 100644
index 000000000..6496b825a
--- /dev/null
+++ b/examples/moe/qwen2-moe-qlora-10gb.yaml
@@ -0,0 +1,57 @@
+base_model: Qwen/Qwen1.5-MoE-A2.7B
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+trust_remote_code: true
+
+# Keep VRAM low
+load_in_8bit: false
+load_in_4bit: true
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./outputs/qwen2-moe-qlora-10gb
+
+# Train small to fit 10GB
+sequence_len: 512
+sample_packing: false
+pad_to_sequence_len: false
+
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+
+gradient_accumulation_steps: 8
+micro_batch_size: 1
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+resume_from_checkpoint:
+logging_steps: 5
+flash_attention: true
+
+warmup_ratio: 0.03
+evals_per_epoch: 2
+saves_per_epoch: 1
+weight_decay: 0.0
+
+# Enable router logits if you want aux loss/analysis
+model_config:
+  output_router_logits: true
+
+# ZeRO-3 with CPU offload keeps VRAM within ~10GB
+deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
+
+special_tokens:
diff --git a/src/axolotl/kernels/moe/hf_triton.py b/src/axolotl/kernels/moe/hf_triton.py
index b0fbd29a5..08b8a740a 100644
--- a/src/axolotl/kernels/moe/hf_triton.py
+++ b/src/axolotl/kernels/moe/hf_triton.py
@@ -25,15 +25,29 @@ def available() -> bool:
         return False
 
 
+# Cache loaded handles so we don't trigger repeated hub fetches
+_CACHED_HANDLES: Optional[HFTritonHandles] = None
+_LOAD_ATTEMPTED: bool = False
+
+
 def load() -> Optional[HFTritonHandles]:
+    global _CACHED_HANDLES, _LOAD_ATTEMPTED
+    if _CACHED_HANDLES is not None:
+        return _CACHED_HANDLES
+    if _LOAD_ATTEMPTED:
+        # Previously failed; avoid spamming retries per call
+        return None
+    _LOAD_ATTEMPTED = True
     try:
         from kernels import get_kernel
 
         tk = get_kernel("kernels-community/triton_kernels")
-        return HFTritonHandles(
+        _CACHED_HANDLES = HFTritonHandles(
             routing=tk.routing, matmul_ogs=tk.matmul_ogs, swiglu=tk.swiglu
         )
+        return _CACHED_HANDLES
     except Exception:
+        # Keep None in cache state to prevent repeated fetch attempts
         return None
 
 
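The new example config can be launched with the usual axolotl entry point; the exact CLI varies by axolotl version, so treat the invocation below as an assumption:

    accelerate launch -m axolotl.cli.train examples/moe/qwen2-moe-qlora-10gb.yaml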
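A minimal sketch (not part of the patch) of the caching contract the second hunk adds: the first load() call may fetch kernels-community/triton_kernels from the hub, and every later call is answered from module state. The module path follows this diff; the final assertion is an assumption about the intended behavior, holding whether load succeeded (same handles object) or failed (None both times, since _LOAD_ATTEMPTED suppresses retries):

    # Sketch: exercising the memoized load() added in this patch.
    from axolotl.kernels.moe import hf_triton

    first = hf_triton.load()   # may fetch kernels-community/triton_kernels once
    second = hf_triton.load()  # served from _CACHED_HANDLES; no second hub fetch
    assert second is first     # same cached handles, or None on both calls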