From 753906cfc7a669e9b1788193174304bbd9059a8e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 5 Mar 2026 21:58:09 +0700 Subject: [PATCH] feat: add doc for expert quantization, glm45 air example configs, and update readme for release (#3452) [skip ci] * chore: rename without period * feat: add glm45 air * feat: add doc on expert quantization * feat: update base readme with new changes * chore: cleanup * chore: cleanup * chore: cleanup * fix: disable quantize_moe_expert on merge per comment * chore: add kernel info to optimizations doc --- README.md | 32 ++++++--- _quarto.yml | 1 + docs/expert_quantization.qmd | 66 +++++++++++++++++ docs/optimizations.qmd | 16 +++++ examples/glm45/README.md | 72 +++++++++++++++++++ examples/glm45/glm-45-air-qlora.yaml | 64 +++++++++++++++++ .../{glm4.7-flash => glm47-flash}/README.md | 24 ++----- .../{glm4.7-flash => glm47-flash}/lora.yaml | 0 .../lora_fsdp.yaml | 0 .../{glm4.7-flash => glm47-flash}/qlora.yaml | 0 .../qlora_fsdp.yaml | 0 src/axolotl/cli/merge_lora.py | 1 + src/axolotl/monkeypatch/multipack.py | 1 + 13 files changed, 248 insertions(+), 29 deletions(-) create mode 100644 docs/expert_quantization.qmd create mode 100644 examples/glm45/README.md create mode 100644 examples/glm45/glm-45-air-qlora.yaml rename examples/{glm4.7-flash => glm47-flash}/README.md (67%) rename examples/{glm4.7-flash => glm47-flash}/lora.yaml (100%) rename examples/{glm4.7-flash => glm47-flash}/lora_fsdp.yaml (100%) rename examples/{glm4.7-flash => glm47-flash}/qlora.yaml (100%) rename examples/{glm4.7-flash => glm47-flash}/qlora_fsdp.yaml (100%) diff --git a/README.md b/README.md index b56cdf0e8..9c7a8a493 100644 --- a/README.md +++ b/README.md @@ -29,8 +29,23 @@ ## 🎉 Latest Updates -- 2025/12: Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 
3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html). +- 2026/03: + - New model support has been added in Axolotl for Qwen3.5, Qwen3.5 MoE, [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45). + - [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat). +- 2026/02: + - [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels. + - Axolotl now has support for [SageAttention](https://github.com/axolotl-ai-cloud/axolotl/pull/2823) and [GDPO](https://github.com/axolotl-ai-cloud/axolotl/pull/3353) (Generalized DPO). +- 2026/01: + - New integration for [EAFT](https://github.com/axolotl-ai-cloud/axolotl/pull/3366) (Entropy-Aware Focal Training), weights loss by entropy of the top-k logit distribution, and [Scalable Softmax](https://github.com/axolotl-ai-cloud/axolotl/pull/3338), improves long context in attention. +- 2025/12: + - Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html). 
+ - [Distributed Muon Optimizer](https://github.com/axolotl-ai-cloud/axolotl/pull/3264) support has been added for FSDP2 pretraining. - 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html). + +
+ +Expand older updates + - 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion). - 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107). - 2025/07: @@ -39,15 +54,10 @@ - FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)! - [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl! - TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl! -- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more! - -
- -Expand older updates - -- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning. - 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl! +- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more! - 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version! +- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning. - 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own! - 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try. - 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun! 
@@ -62,10 +72,10 @@ Axolotl is a free and open-source tool designed to streamline post-training and Features: - **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub. -- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support. -- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM). +- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support. +- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM). - **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference. -- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more! 
+- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [SageAttention](https://github.com/thu-ml/SageAttention), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more! - **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets. - **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware. diff --git a/_quarto.yml b/_quarto.yml index 4534c0a0e..5e1169102 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -331,6 +331,7 @@ website: - docs/sequence_parallelism.qmd - docs/gradient_checkpointing.qmd - docs/nd_parallelism.qmd + - docs/expert_quantization.qmd - section: "Troubleshooting" contents: diff --git a/docs/expert_quantization.qmd b/docs/expert_quantization.qmd new file mode 100644 index 000000000..7271e8864 --- /dev/null +++ b/docs/expert_quantization.qmd @@ -0,0 +1,66 @@ +--- +title: "MoE Expert Quantization" +description: "Reduce VRAM usage when training MoE model adapters by quantizing expert weights on load" +--- + +Transformers v5 changed MoE expert layers from `nn.Linear` to fused `nn.Parameter` (3D+ tensors). 
+This means `bitsandbytes` can no longer quantize them during model loading, resulting in all expert
+weights being loaded in full bf16 precision and causing massive VRAM usage.
+
+`quantize_moe_experts` solves this by quantizing expert weights during model loading.
+It intercepts the weight loading process, quantizes each expert tensor on the fly, and
+immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory.
+For example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.
+
+## Usage
+
+Enable expert quantization in your Axolotl config:
+
+```yaml
+quantize_moe_experts: true
+```
+
+This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.
+
+### Expert LoRA targeting
+
+You can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:
+
+```yaml
+lora_target_parameters:
+  - mlp.experts.gate_up_proj
+  - mlp.experts.down_proj
+  # - mlp.gate.weight # router
+```
+
+::: {.callout-note}
+`lora_dropout` must be `0` when using `lora_target_parameters`.
+:::
+
+## Requirements
+
+- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`)
+- CUDA GPUs only (not tested with ROCm or other backends)
+- FSDP2 compatible for distributed training
+
+## Limitations
+
+- `cpu_ram_efficient_loading` hangs / takes a long time with FSDP2 + QLoRA.
+- Total model parameter count may display incorrectly (trainable param count is correct).
+- FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps, which then drops. QLoRA does not exhibit this.
+- FSDP2 may use more VRAM per GPU than single GPU training due to not all layers being properly sharded across ranks.
+- Model loading takes longer due to on-demand quantization, even on consecutive runs.
+- DeepSpeed has not been tested.
+
+## Implementation details
+
+The quantization is applied by patching transformers to intercept weight loading. 
+When a 3D+ CUDA tensor with "expert" in its name is detected: + +- **4-bit mode:** Uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`). +- **8-bit mode:** Uses a custom row-wise int8 parametrization with bitsandbytes dequantization. + +The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to +transformers, PEFT and accelerate FSDP2 to support these parametrized expert modules. + +For full implementation details, see [PR #3439](https://github.com/axolotl-ai-cloud/axolotl/pull/3439). diff --git a/docs/optimizations.qmd b/docs/optimizations.qmd index 967ec2d34..624738de7 100644 --- a/docs/optimizations.qmd +++ b/docs/optimizations.qmd @@ -66,6 +66,15 @@ Provides efficient Triton kernels to improve training speed and reduce memory us - **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels) +### Expert Kernels + +Optimized kernel implementations for Mixture of Experts (MoE) model training. + +- **ScatterMoE**: Triton-based MoE kernels with fused LoRA support. +- **SonicMoE**: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs. + +- **Learn more:** [Custom Integrations - Kernels Integration](custom_integrations.qmd#kernels-integration) + ## Long Context Models Techniques to train models on sequences longer than their original context window. @@ -131,3 +140,10 @@ Simulates quantization effects during training, helping the model adapt and pote Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method. - **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml) + +### MoE Expert Quantization + +Quantizes MoE expert weights on load to reduce VRAM when training MoE models with adapters. Required for Transformers v5+ MoE models where experts use fused `nn.Parameter` tensors. 
+ +- **Config:** `quantize_moe_experts: true` +- **Learn more:** [MoE Expert Quantization](expert_quantization.qmd) diff --git a/examples/glm45/README.md b/examples/glm45/README.md new file mode 100644 index 000000000..06c7834fc --- /dev/null +++ b/examples/glm45/README.md @@ -0,0 +1,72 @@ +# Finetune Z.ai's GLM-4.5-Air with Axolotl + +[GLM-4.5-Air](https://huggingface.co/zai-org/GLM-4.5-Air) is a MoE model by Z.ai. + +This guide shows how to fine-tune it with Axolotl. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage. + +3. Run the finetuning example: + +```bash +# QLoRA (1x80GB @ ~63.4GiB/GPU) +axolotl train examples/glm45/glm-45-air-qlora.yaml +``` + +### Dataset + +In addition to the standard OpenAI Messages format, GLM-4.5 supports an extra parameter for thinking in the assistant section. + +```json +{ + "role": "assistant", + "reasoning_content": "...", // or have ... in `content` + "content": "..." +} +``` + +Make sure you set the below extra attributes if needed: + +```yaml +datasets: + - path: ... + type: chat_template + message_property_mappings: + role: role + content: content + + # tool_calls: tool_calls # uncomment if using tools + # reasoning_content: reasoning_content # uncomment if have reasoning + +# Uncomment if training on tool role (you would rarely if ever need this) +# eot_tokens: +# - <|observation|> +``` + +### Tips + +- The role name for tools in this template is `tool`. +- You will see this Axolotl WARNING — this is expected as the template does not use EOS: + ``` + EOS token '<|endoftext|>' not found in chat_template. Please check if your template/EOS token is correct. + ``` +- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. 
+- **LoRA kernels**: Incompatible with this model. Must be explicitly disabled (`lora_*_kernel: false`). +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). + +## Optimization Guides + +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). + +## Related Resources + +- [GLM-4.5-Air on HuggingFace](https://huggingface.co/zai-org/GLM-4.5-Air) +- [GLM-4.5 Blog](https://z.ai/blog/glm-4.5) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/glm45/glm-45-air-qlora.yaml b/examples/glm45/glm-45-air-qlora.yaml new file mode 100644 index 000000000..accb8898f --- /dev/null +++ b/examples/glm45/glm-45-air-qlora.yaml @@ -0,0 +1,64 @@ +base_model: zai-org/GLM-4.5-Air + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +quantize_moe_experts: true # important + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 16 +lora_alpha: 8 +lora_dropout: 0 +lora_target_modules: + - q_proj + - v_proj + - k_proj + - o_proj + +# lora_target_parameters: +# - mlp.experts.gate_up_proj +# - mlp.experts.down_proj + +lora_mlp_kernel: false +lora_qkv_kernel: false +lora_o_kernel: false + +gradient_accumulation_steps: 2 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 
+evals_per_epoch: 1
+saves_per_epoch: 1
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/glm4.7-flash/README.md b/examples/glm47-flash/README.md
similarity index 67%
rename from examples/glm4.7-flash/README.md
rename to examples/glm47-flash/README.md
index 6d7fd437a..2e5e21010 100644
--- a/examples/glm4.7-flash/README.md
+++ b/examples/glm47-flash/README.md
@@ -16,40 +16,28 @@ This guide shows how to fine-tune it with Axolotl.
 
 # QLoRA
 # - no target experts (1x48GB @ ~24GiB/GPU)
 # - target experts (1x48GB @ ~34GiB/GPU)
-axolotl train examples/glm4.7-flash/qlora.yaml
+axolotl train examples/glm47-flash/qlora.yaml
 
 # QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
-axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
+axolotl train examples/glm47-flash/qlora_fsdp.yaml
 ```
 
 ```bash
 # LoRA
 # - no target experts (1x48GB @ ~35GiB/GPU)
 # - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
-axolotl train examples/glm4.7-flash/lora.yaml
+axolotl train examples/glm47-flash/lora.yaml
 
 # LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
-axolotl train examples/glm4.7-flash/lora_fsdp.yaml
+axolotl train examples/glm47-flash/lora_fsdp.yaml
 ```
 
-### Expert LoRA
+### MoE Expert Quantization & Expert LoRA
 
-To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.
-
-Note: `lora_dropout` must be `0` when using `lora_target_parameters`.
-
-```yaml
-lora_target_parameters:
-  - mlp.experts.gate_up_proj
-  - mlp.experts.down_proj
-  # - mlp.gate.weight # router, untested but should work, not normally targeted
-```
+This model quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
 
 ## Limitations
 
-- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single GPU training. 
We suspect not all layers are properly sharded across ranks. -- **FSDP initial spike**: FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this. -- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2 — causes hang otherwise. - **lora_target_linear**: Incompatible for this model. - **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`). diff --git a/examples/glm4.7-flash/lora.yaml b/examples/glm47-flash/lora.yaml similarity index 100% rename from examples/glm4.7-flash/lora.yaml rename to examples/glm47-flash/lora.yaml diff --git a/examples/glm4.7-flash/lora_fsdp.yaml b/examples/glm47-flash/lora_fsdp.yaml similarity index 100% rename from examples/glm4.7-flash/lora_fsdp.yaml rename to examples/glm47-flash/lora_fsdp.yaml diff --git a/examples/glm4.7-flash/qlora.yaml b/examples/glm47-flash/qlora.yaml similarity index 100% rename from examples/glm4.7-flash/qlora.yaml rename to examples/glm47-flash/qlora.yaml diff --git a/examples/glm4.7-flash/qlora_fsdp.yaml b/examples/glm47-flash/qlora_fsdp.yaml similarity index 100% rename from examples/glm4.7-flash/qlora_fsdp.yaml rename to examples/glm47-flash/qlora_fsdp.yaml diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index e7ad89036..bc2dc84c7 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -71,6 +71,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: merge_lora=True, load_in_8bit=False, load_in_4bit=False, + quantize_moe_experts=False, flash_attention=False, context_parallel_size=None, deepspeed=None, diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 6a6b935be..3208325eb 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -37,6 +37,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES 
= [ "deepseek_v3", "glm", "glm4", + "glm4_moe", "smollm3", "granite", "granitemoe",