diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb
index 2a00cd1d0..2cc27f211 100644
--- a/examples/colab-notebooks/colab-axolotl-example.ipynb
+++ b/examples/colab-notebooks/colab-axolotl-example.ipynb
@@ -40,7 +40,7 @@
 "%%capture\n",
 "# This step can take ~5-10 minutes to install dependencies\n",
 "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572\""
+"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\""
 ]
 },
 {
diff --git a/examples/glm4.7-flash/README.md b/examples/glm4.7-flash/README.md
new file mode 100644
index 000000000..6d7fd437a
--- /dev/null
+++ b/examples/glm4.7-flash/README.md
@@ -0,0 +1,77 @@
+# Finetune Z.ai's GLM-4.7-Flash with Axolotl
+
+[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
+
+This guide shows how to fine-tune it with Axolotl.
+
+## Getting started
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
+
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
+
+3. Run the finetuning example:
+
+```bash
+# QLoRA
+# - no target experts (1x48GB @ ~24GiB/GPU)
+# - target experts (1x48GB @ ~34GiB/GPU)
+axolotl train examples/glm4.7-flash/qlora.yaml
+
+# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
+axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
+```
+
+```bash
+# LoRA
+# - no target experts (1x48GB @ ~35GiB/GPU)
+# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
+axolotl train examples/glm4.7-flash/lora.yaml
+
+# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
+axolotl train examples/glm4.7-flash/lora_fsdp.yaml
+```
+
+### Expert LoRA
+
+To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.
+
+Note: `lora_dropout` must be `0` when using `lora_target_parameters`.
+
+```yaml
+lora_target_parameters:
+  - mlp.experts.gate_up_proj
+  - mlp.experts.down_proj
+  # - mlp.gate.weight # router, untested but should work, not normally targeted
+```
+
+## Limitations
+
+- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single-GPU training. We suspect not all layers are properly sharded across ranks.
+- **FSDP initial spike**: FSDP LoRA (8-bit) may show a large VRAM spike during the first 1-2 steps, which then drops. FSDP QLoRA (4-bit) does not exhibit this.
+- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2; otherwise loading hangs.
+- **lora_target_linear**: Incompatible with this model.
+- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
+
+## Tips
+
+- For inference, the official Z.ai team recommends these default settings for most tasks (see the sketch under Quick Inference Example below):
+  - `temperature: 1.0`
+  - `top_p: 0.95`
+  - `max_new_tokens: 131072`
+- You can run a full finetune by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is resource-intensive, so we have not tested it.
+- See the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html) to learn how to load your own dataset.
+
+## Optimization Guides
+
+Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
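+
+## Quick Inference Example
+
+A minimal smoke-test sketch using the recommended sampling settings above. It is not part of the training configs: it assumes a recent transformers release (the `dtype=` keyword) and loads the base model id, so swap in the directory of your merged finetune as needed. The prompt and the small `max_new_tokens` are illustrative.
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "zai-org/GLM-4.7-Flash"  # or the directory of your merged finetune
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", device_map="auto")
+
+messages = [{"role": "user", "content": "In one sentence, what is a MoE model?"}]
+inputs = tokenizer.apply_chat_template(
+    messages, add_generation_prompt=True, return_tensors="pt"
+).to(model.device)
+
+# Z.ai's recommended defaults: temperature 1.0, top_p 0.95
+outputs = model.generate(
+    inputs,
+    do_sample=True,
+    temperature=1.0,
+    top_p=0.95,
+    max_new_tokens=256,  # recommended ceiling is 131072; kept small for a smoke test
+)
+print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
+```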
+
+## Related Resources
+
+- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
+- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
+- [Axolotl Docs](https://docs.axolotl.ai)
+- [Axolotl Website](https://axolotl.ai)
+- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
+- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
diff --git a/examples/glm4.7-flash/lora.yaml b/examples/glm4.7-flash/lora.yaml
new file mode 100644
index 000000000..2586babb7
--- /dev/null
+++ b/examples/glm4.7-flash/lora.yaml
@@ -0,0 +1,65 @@
+base_model: zai-org/GLM-4.7-Flash
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+load_in_8bit: true
+quantize_moe_experts: true
+
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./outputs/glm4.7-flash-lora-8bit-out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0
+lora_target_modules:
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+# Uncomment to also target MoE expert weights:
+# lora_target_parameters:
+#   - mlp.experts.gate_up_proj
+#   - mlp.experts.down_proj
+
+# LoRA kernels incompatible with DSA attention
+lora_mlp_kernel: false
+lora_qkv_kernel: false
+lora_o_kernel: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_torch_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: false
+
+gradient_checkpointing: true
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
diff --git a/examples/glm4.7-flash/lora_fsdp.yaml b/examples/glm4.7-flash/lora_fsdp.yaml
new file mode 100644
index 000000000..bee20bf02
--- /dev/null
+++ b/examples/glm4.7-flash/lora_fsdp.yaml
@@ -0,0 +1,75 @@
+base_model: zai-org/GLM-4.7-Flash
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+load_in_8bit: true
+quantize_moe_experts: true
+
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0
+lora_target_modules:
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+# Uncomment to also target MoE expert weights:
+# lora_target_parameters:
+#   - mlp.experts.gate_up_proj
+#   - mlp.experts.down_proj
+
+# LoRA kernels incompatible with DSA attention
+lora_mlp_kernel: false
+lora_qkv_kernel: false
+lora_o_kernel: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_torch_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: false
+
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+
+fsdp_config:
+  fsdp_version: 2
+  offload_params: false
+  cpu_ram_efficient_loading: false
+  auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
+  state_dict_type: FULL_STATE_DICT
+  sharding_strategy: FULL_SHARD
+  reshard_after_forward: true
+  activation_checkpointing: true
diff --git a/examples/glm4.7-flash/qlora.yaml b/examples/glm4.7-flash/qlora.yaml
new file mode 100644
index 000000000..834c46af8
--- /dev/null
+++ b/examples/glm4.7-flash/qlora.yaml
@@ -0,0 +1,65 @@
+base_model: zai-org/GLM-4.7-Flash
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+load_in_4bit: true
+quantize_moe_experts: true
+
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./outputs/glm4.7-flash-qlora-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0
+lora_target_modules:
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+# Uncomment to also target MoE expert weights:
+# lora_target_parameters:
+#   - mlp.experts.gate_up_proj
+#   - mlp.experts.down_proj
+
+# LoRA kernels incompatible with DSA attention
+lora_mlp_kernel: false
+lora_qkv_kernel: false
+lora_o_kernel: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_torch_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: false
+
+gradient_checkpointing: true
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
diff --git a/examples/glm4.7-flash/qlora_fsdp.yaml b/examples/glm4.7-flash/qlora_fsdp.yaml
new file mode 100644
index 000000000..0bb87813f
--- /dev/null
+++ b/examples/glm4.7-flash/qlora_fsdp.yaml
@@ -0,0 +1,75 @@
+base_model: zai-org/GLM-4.7-Flash
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+load_in_4bit: true
+quantize_moe_experts: true
+
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0
+lora_target_modules:
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+# Uncomment to also target MoE expert weights:
+# lora_target_parameters:
+#   - mlp.experts.gate_up_proj
+#   - mlp.experts.down_proj
+
+# LoRA kernels incompatible with DSA attention
+lora_mlp_kernel: false
+lora_qkv_kernel: false
+lora_o_kernel: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_torch_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: false
+
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+
+fsdp_config:
+  fsdp_version: 2
+  offload_params: false
+  cpu_ram_efficient_loading: false
+  auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
+  state_dict_type: FULL_STATE_DICT
+  sharding_strategy: FULL_SHARD
+  reshard_after_forward: true
+  activation_checkpointing: true
diff --git a/examples/trinity/README.md b/examples/trinity/README.md
index 4bbfcf29c..e9710915c 100644
--- a/examples/trinity/README.md
+++ b/examples/trinity/README.md
@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 
 1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
 
-2. Run the finetuning example:
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
+
+3. Run the finetuning example:
 
 ```bash
 axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
 ```
 
-This config uses about 24.9 GiB VRAM.
+This config uses about 24.9 GiB VRAM (without CCE).
 
 Let us know how it goes. Happy finetuning! 🚀
 
@@ -29,10 +31,6 @@ Let us know how it goes. Happy finetuning! 🚀
 
 Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
 
-## Limitations
-
-**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
-
 ## Related Resources
 
 - [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
diff --git a/examples/trinity/trinity-nano-preview-qlora.yaml b/examples/trinity/trinity-nano-preview-qlora.yaml
index de54fc8ac..d8bf9f073 100644
--- a/examples/trinity/trinity-nano-preview-qlora.yaml
+++ b/examples/trinity/trinity-nano-preview-qlora.yaml
@@ -1,5 +1,4 @@
 base_model: arcee-ai/Trinity-Nano-Preview
-trust_remote_code: true
 revision_of_model: 2ee94b0
 
 # Automatically upload checkpoint and final model to HF
diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py
index afc172e22..f6cd0c495 100644
--- a/scripts/cutcrossentropy_install.py
+++ b/scripts/cutcrossentropy_install.py
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
 
 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"'
 )
diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py
index f4d6ca928..a409ed9f4 100644
--- a/src/axolotl/common/architectures.py
+++ b/src/axolotl/common/architectures.py
@@ -18,4 +18,7 @@ MOE_ARCH_BLOCK = {
     "gpt_oss": "GptOssDecoderLayer",
     "lfm2_moe": "Lfm2MoeSparseMoeBlock",
     "afmoe": "AfmoeMoE",
+    "glm4_moe": "Glm4MoeDecoderLayer",
+    "glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
+    "glm_moe_dsa": "GlmMoeDsaDecoderLayer",
 }
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 414abeb4d..76e8f105f 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -720,12 +720,16 @@ class AxolotlTrainer(
         os.makedirs(output_dir, exist_ok=True)
         LOG.info(f"Saving model checkpoint to {output_dir}")
 
-        # fix for Context Parallel save
-        if state_dict is None:
-            state_dict = self.accelerator.get_state_dict(self.model)
-        if state_dict is not None:
+        # fix for Context Parallel save: CP eval invalidates tensor storage
+        # pointers, so clone to CPU to get fresh valid storage for safetensors
+        if (
+            state_dict is not None
+            and self.axolotl_cfg
+            and self.axolotl_cfg.context_parallel_size
+            and self.axolotl_cfg.context_parallel_size > 1
+        ):
             state_dict = {
-                k: v.clone() if isinstance(v, torch.Tensor) else v
+                k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
                 for k, v in state_dict.items()
             }
@@ -761,7 +765,11 @@
                     metadata={"format": "pt"},
                 )
             else:
-                self.model.save_pretrained(output_dir, state_dict=state_dict)
+                self.model.save_pretrained(
+                    output_dir,
+                    state_dict=state_dict,
+                    is_main_process=self.accelerator.is_main_process,
+                )
 
         if self.processing_class is not None:
             self.processing_class.save_pretrained(output_dir)
diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md
index b3f475dc2..b892033da 100644
--- a/src/axolotl/integrations/cut_cross_entropy/README.md
+++ b/src/axolotl/integrations/cut_cross_entropy/README.md
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
 
 - If you are installing from pip
 
 ```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
 ```
 
 ## Usage
@@ -88,9 +88,9 @@ plugins:
 - qwen2_vl
 - qwen3
 - qwen3_5
+- qwen3_5_text
 - qwen3_5_moe
-- qwen3_5_moe_vl
-- qwen3_5_vl
+- qwen3_5_moe_text
 - qwen3_moe
 - qwen3_next
 - qwen3_vl
diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py
index 36681d770..5c207e0fc 100644
--- a/src/axolotl/integrations/cut_cross_entropy/__init__.py
+++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
 
 _CCE_INSTALL_MESSAGE = (
     "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
 )
diff --git a/src/axolotl/integrations/kernels/README.md b/src/axolotl/integrations/kernels/README.md
index 96ff7b328..237d653cf 100644
--- a/src/axolotl/integrations/kernels/README.md
+++ b/src/axolotl/integrations/kernels/README.md
@@ -39,6 +39,8 @@ This works for any MoE model in transformers that uses a `SparseMoeBlock` class
 
 ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
 
+ScatterMoE does not currently work with GLM-4.7-Flash (glm4_moe_lite).
+
 ## Note on MegaBlocks
 
 We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.
diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py
index 3b64b23db..eb7203c01 100644
--- a/src/axolotl/loaders/adapter.py
+++ b/src/axolotl/loaders/adapter.py
@@ -34,7 +34,7 @@ def setup_quantized_meta_for_peft(model: torch.nn.Module):
         return self
 
     for param in model.parameters():
-        if isinstance(param, Params4bit):
+        if isinstance(param, Params4bit) and param.quant_state is not None:
             param.quant_state._orig_to = param.quant_state.to
             param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 6c8885526..3be557a42 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -172,7 +172,10 @@
         # Build the model
         PLUGIN_MANAGER.pre_model_load(self.cfg)
         self.patch_manager.apply_post_plugin_pre_model_load_patches()
+
         skip_move_to_device = self._build_model()
+        self.patch_manager.apply_post_model_build_patches(self.model)
+        PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
 
         # Post-build model configuration
@@ -860,6 +863,10 @@
             # Make sure everything is in the same dtype
             skip_prepare_model_for_kbit_training = True
 
+        if getattr(self.model, "_moe_experts_quantized", False):
+            # Parametrized expert tensors dequantize on access — would OOM.
+            skip_prepare_model_for_kbit_training = True
+
         if (
             not skip_prepare_model_for_kbit_training
             and self.cfg.adapter in ["lora", "qlora"]
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 62dcbde7a..87520c06f 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -118,6 +118,7 @@ class PatchManager:
     def apply_post_plugin_pre_model_load_patches(self):
        """Apply post plugin-pre_model_load load patches based on config."""
        self._apply_tiled_mlp(self.cfg.model_config_type)
+       self._apply_moe_expert_quantization_patch()
 
     def _apply_transformers_patches(self):
         from axolotl.monkeypatch.transformers.trainer_loss_calc import (
@@ -135,6 +136,10 @@
         patch_prepare_context_parallel_inputs()
 
+    def apply_post_model_build_patches(self, model: PreTrainedModel):
+        """Apply patches right after model build, before post-load setup."""
+        self._finalize_moe_expert_quantization(model)
+
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
         self._apply_llama_flash_attn_patches(model)
@@ -170,9 +175,14 @@
             patch_parallelism_config()
 
         if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
-            from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
+            from axolotl.monkeypatch.accelerate.fsdp2 import (
+                patch_accelerate_fsdp2,
+                patch_tied_keys_for_meta_device,
+            )
 
             patch_accelerate_fsdp2()
+            if self.cfg.fsdp_config.cpu_ram_efficient_loading:
+                patch_tied_keys_for_meta_device()
 
         if self.cfg.rl:
             from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
@@ -352,15 +362,54 @@
         if (
             self.cfg.fsdp_config
             and str(self.cfg.fsdp_version) == "2"
-            and self.cfg.adapter == "qlora"
+            and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
         ):
             from axolotl.monkeypatch.fsdp2_qlora import (
+                apply_init_dtype_attrs_patch,
                 apply_init_sharded_param_patch,
                 apply_init_unsharded_param_patch,
+                apply_linear8bitlt_save_patch,
             )
 
             apply_init_sharded_param_patch()
             apply_init_unsharded_param_patch()
+            apply_init_dtype_attrs_patch()
+            if self.cfg.load_in_8bit:
+                apply_linear8bitlt_save_patch()
+
+    def _apply_moe_expert_quantization_patch(self):
+        """Patch transformers weight loading to quantize MoE expert params on-the-fly."""
+        if not self.cfg.quantize_moe_experts:
+            return
+
+        from axolotl.monkeypatch.moe_quant import (
+            patch_moe_quantization_on_load,
+            patch_peft_target_parameters_matching,
+        )
+
+        patch_moe_quantization_on_load(self.cfg)
+        patch_peft_target_parameters_matching()
+
+    def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
+        """Log quantization results and set model flag for downstream use."""
+        import torch
+
+        model._moe_experts_quantized = False
+        if self.cfg.quantize_moe_experts:
+            from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
+
+            count = get_moe_quantized_count()
+            if count > 0:
+                import gc
+
+                model._moe_experts_quantized = True
+                LOG.info(
+                    "Quantized %d MoE expert parameter(s) to %s during model loading",
+                    count,
+                    "4-bit" if self.cfg.load_in_4bit else "8-bit",
+                )
+                gc.collect()
+                torch.cuda.empty_cache()
 
     def _apply_tiled_mlp(self, model_type: str):
         if self.cfg.tiled_mlp:
diff --git a/src/axolotl/models/mamba/modeling_mamba.py b/src/axolotl/models/mamba/modeling_mamba.py
index e6158a0a9..b1847d6b5 100644
--- a/src/axolotl/models/mamba/modeling_mamba.py
+++ b/src/axolotl/models/mamba/modeling_mamba.py
@@ -111,6 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
         self,
         save_directory: Union[str, os.PathLike],
         state_dict: Optional[dict] = None,
+        **kwargs,
     ):
         if state_dict is None:
             state_dict = self.state_dict()
diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index af6f24a63..4a8d9840f 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -150,13 +150,17 @@ def get_state_dict(self, model, unwrap=True):
         )
     elif self.is_fsdp2:
         # https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
+        from torch.distributed.tensor import DTensor
+
         state_dict = {}
         sharded_state_dict = model.state_dict()
         for param_name, param in sharded_state_dict.items():
             if param.is_cpu:
                 param = param.to(torch.device("cuda"))
-            param = param.full_tensor()
+            if isinstance(param, DTensor):
+                param = param.full_tensor()
+
             if torch.distributed.get_rank() == 0:
                 state_dict[param_name] = param.cpu()
             torch.distributed.barrier()
@@ -182,10 +186,56 @@
     return state_dict
 
 
+def patch_peft_param_wrapper_for_fsdp2():
+    """Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
+
+    PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
+    delta_weight to the base weight W inside _LoraParameterProxy.forward().
+    Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
+    regular Tensor (or vice versa), causing a RuntimeError on mixed types.
+
+    This patch promotes the non-DTensor operand to match the DTensor's spec
+    using DTensor.from_local(), which is free for Replicate placement (just
+    metadata wrapping, no communication).
+    """
+    from peft.tuners.lora.layer import _LoraParameterProxy
+
+    if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
+        return
+
+    _original_forward = _LoraParameterProxy.forward
+
+    # NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.
+    def _patched_forward(self, W):
+        from torch.distributed.tensor import DTensor
+
+        delta = self.delta_weight
+        w_is_dt = isinstance(W, DTensor)
+        d_is_dt = isinstance(delta, DTensor)
+
+        with torch.nn.utils.parametrize.cached():
+            if w_is_dt == d_is_dt:
+                return W + delta
+            if w_is_dt:
+                return W + DTensor.from_local(delta, W.device_mesh, W.placements)
+            return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta
+
+    _LoraParameterProxy.forward = _patched_forward
+    _LoraParameterProxy._axolotl_fsdp2_patched = True
+    LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
+
+
 def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
     """Helper function to process LoRA modules for FSDP2."""
+    from peft.tuners.lora.layer import ParamWrapper
     from torch.distributed.fsdp import fully_shard
 
+    # Skip ParamWrapper — its lora_A/B must not be independently sharded.
+    # The parent decoder layer's FSDP wrapper handles unsharding them.
+    # TODO: review if we even need to shard them separately in the first place.
+    if isinstance(module, ParamWrapper):
+        return False
+
     log_bias_dtype_mismatch = False
 
     # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
@@ -327,6 +377,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
 
     is_peft_model = isinstance(model, PeftModel)
 
+    # Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
+    # ParamWrapper modules exist (used for target_parameters / 3D expert params).
+    if is_peft_model:
+        from peft.tuners.lora.layer import ParamWrapper
+
+        if any(isinstance(m, ParamWrapper) for m in model.modules()):
+            patch_peft_param_wrapper_for_fsdp2()
+
     auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
     log_bias_dtype_mismatch = False
     if auto_wrap_policy is not None:
@@ -376,6 +434,43 @@
     return model
 
 
+def patch_tied_keys_for_meta_device():
+    """Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.
+
+    Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly
+    grouped as "tied". Skipping them is safe since they have no real storage.
+    """
+    from collections import defaultdict
+
+    from transformers import PreTrainedModel
+
+    def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):
+        param_pointers = defaultdict(list)
+        for param_name, param_value in self.state_dict().items():
+            if param_value.is_meta:
+                continue
+            param_pointers[param_value.data_ptr()].append(param_name)
+
+        tied_param_names = [
+            names
+            for names in param_pointers.values()
+            if len(names) > 1
+            and not any(name in self.all_tied_weights_keys.keys() for name in names)
+            and not all(name in missing_keys for name in names)
+        ]
+
+        tied_weights_keys_by_pointers = {
+            param_name: group[0]
+            for group in tied_param_names
+            for param_name in group[1:]
+        }
+        self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)
+
+    PreTrainedModel._adjust_tied_keys_with_tied_pointers = (
+        _patched_adjust_tied_keys_with_tied_pointers
+    )
+
+
 def patch_accelerate_fsdp2():
     import accelerate
diff --git a/src/axolotl/monkeypatch/fsdp2_qlora.py b/src/axolotl/monkeypatch/fsdp2_qlora.py
index 04d0d1971..1887c0a8a 100644
--- a/src/axolotl/monkeypatch/fsdp2_qlora.py
+++ b/src/axolotl/monkeypatch/fsdp2_qlora.py
@@ -1,9 +1,10 @@
 """
-Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
-our LoRA / QLoRA Triton kernels to work with FSDP2.
+Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2
+and 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.
 
-This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
-Params4bit parameters.
+This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam
+to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization
+metadata through the FSDP2 shard/unshard cycle.
 """
 
 import importlib
@@ -17,6 +18,8 @@ LOG = get_logger(__name__)
 
 def apply_init_sharded_param_patch():
     """Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
+    if getattr(apply_init_sharded_param_patch, "_axolotl_patched", False):
+        return
     from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
 
     # Get original source
@@ -41,9 +44,20 @@
                 bnb_quantized=param.bnb_quantized,
             )
             self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
+        elif isinstance(param, bnb.nn.modules.Int8Params):
+            self.sharded_param = bnb.nn.modules.Int8Params(
+                data=sharded_param,
+                requires_grad=param.requires_grad,
+                has_fp16_weights=param.has_fp16_weights,
+                CB=None,
+                SCB=param.SCB,
+            )
+            self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
         else:
-            self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
-            self.sharded_param.requires_grad_(param.requires_grad)"""
+            self.sharded_param = nn.Parameter(
+                self.to_sharded_dtensor(sharded_param),
+                requires_grad=param.requires_grad,
+            )"""
 
     # Apply the replacement
     if original_param_creation in original_source:
@@ -73,6 +87,7 @@
 
         # Replace the method
         FSDPParam._init_sharded_param = patched_init_sharded_param
+        apply_init_sharded_param_patch._axolotl_patched = True
         LOG.info("Successfully applied FSDP _init_sharded_param patch")
     else:
         LOG.warning("Could not find target code for _init_sharded_param patching")
@@ -80,6 +95,8 @@
 
 def apply_init_unsharded_param_patch():
     """Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
+    if getattr(apply_init_unsharded_param_patch, "_axolotl_patched", False):
+        return
     from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
 
     # Get original source
@@ -105,6 +122,14 @@
                     module=local_tensor.module,
                     bnb_quantized=local_tensor.bnb_quantized,
                 )
+            elif isinstance(local_tensor, bnb.nn.modules.Int8Params):
+                self._unsharded_param = bnb.nn.modules.Int8Params(
+                    data=unsharded_param,
+                    requires_grad=self.sharded_param.requires_grad,
+                    has_fp16_weights=local_tensor.has_fp16_weights,
+                    CB=unsharded_param,
+                    SCB=local_tensor.SCB,
+                )
             else:
                 self._unsharded_param = nn.Parameter(
                     unsharded_param, requires_grad=self.sharded_param.requires_grad
@@ -138,6 +163,74 @@
 
         # Replace the method
         FSDPParam.init_unsharded_param = patched_init_unsharded_param
+        apply_init_unsharded_param_patch._axolotl_patched = True
         LOG.info("Successfully applied FSDP init_unsharded_param patch")
     else:
         LOG.warning("Could not find target code for patching")
+
+
+def apply_linear8bitlt_save_patch():
+    """Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.
+
+    After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
+    BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
+    doesn't proxy custom attribute access to its _local_tensor. This patch
+    temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
+    """
+    if getattr(apply_linear8bitlt_save_patch, "_axolotl_patched", False):
+        return
+    import bitsandbytes as bnb
+    from torch.distributed.tensor import DTensor
+
+    original_save = bnb.nn.Linear8bitLt._save_to_state_dict
+
+    def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
+        # Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
+        weight = self._parameters["weight"]
+        unwrapped = False
+        if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
+            self._parameters["weight"] = weight._local_tensor
+            unwrapped = True
+        try:
+            original_save(self, destination, prefix, keep_vars)
+        finally:
+            if unwrapped:
+                self._parameters["weight"] = weight
+
+    bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
+    apply_linear8bitlt_save_patch._axolotl_patched = True
+    LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
+
+
+def apply_init_dtype_attrs_patch():
+    """Prevent FSDP2 mixed precision from casting non-float quantized params.
+
+    When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets
+    param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts
+    the sharded param to param_dtype. For non-float params (uint8 packed 4-bit,
+    int8 quantized) without FSDP2 extensions, this destroys the quantized data.
+
+    Params4bit handles this via fsdp_pre/post_all_gather extensions, but our
+    parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)
+    without extensions.
+    """
+    if getattr(apply_init_dtype_attrs_patch, "_axolotl_patched", False):
+        return
+    from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+
+    original_init_dtype_attrs = FSDPParam.init_dtype_attrs
+
+    def patched_init_dtype_attrs(self, mp_policy):
+        original_init_dtype_attrs(self, mp_policy)
+        # Skip casting non-float quantized params (uint8/int8) without FSDP2
+        # extensions — the parametrization chain handles dequantization.
+        if self.param_dtype is not None and not self.sharded_param.is_floating_point():
+            local = self.sharded_param
+            if hasattr(local, "_local_tensor"):
+                local = local._local_tensor
+            if not hasattr(local, "fsdp_pre_all_gather"):
+                self.param_dtype = None
+
+    FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
+    apply_init_dtype_attrs_patch._axolotl_patched = True
+    LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")
diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py
new file mode 100644
index 000000000..42beec6a9
--- /dev/null
+++ b/src/axolotl/monkeypatch/moe_quant.py
@@ -0,0 +1,188 @@
+"""
+Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
+
+In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
+skips (only targets nn.Linear). This module patches weight loading to quantize them
+on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
+reducing peak VRAM from "all experts in bf16" to "one expert at a time."
+"""
+
+import bitsandbytes as bnb
+import torch
+import torch.nn.utils.parametrize as P
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+# Module-level state for the loading-time quantization patch.
+_moe_load_state = {
+    "count": 0,
+    "mode": "4bit",
+    "quant_type": "nf4",
+    "compress_statistics": True,
+    "patched": False,
+}
+
+
+class Bnb8bitParametrization(torch.nn.Module):
+    """Parametrization that dequantizes int8 row-wise quantized data on access."""
+
+    def __init__(self, row_stats: torch.Tensor):
+        super().__init__()
+        self.register_buffer("row_stats", row_stats)
+
+    @torch.no_grad()
+    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
+        # Flatten 3D+ to 2D for BnB's dequant, then reshape back.
+        orig_shape = quantized_param.shape
+        if quantized_param.ndim > 2:
+            quantized_param = quantized_param.reshape(-1, orig_shape[-1])
+        result = bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
+        return result.reshape(orig_shape)
+
+
+def _enable_parametrization_cache(module, inputs):
+    P._cache_enabled += 1
+
+
+def _disable_parametrization_cache(module, inputs, output):
+    P._cache_enabled -= 1
+    if not P._cache_enabled:
+        P._cache = {}
+
+
+def replace_parameter_8bit(module, param_name):
+    """Replace a module parameter with an 8-bit quantized version using parametrization."""
+    original_param = getattr(module, param_name)
+    int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
+        original_param.data.to(torch.float16)
+    )
+
+    setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
+    del original_param
+
+    P.register_parametrization(
+        module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
+    )
+
+    # Cache dequantized values during forward to avoid redundant dequantization.
+    if not getattr(module, "_axolotl_8bit_hooks_registered", False):
+        module.register_forward_pre_hook(_enable_parametrization_cache)
+        module.register_forward_hook(_disable_parametrization_cache)
+        module._axolotl_8bit_hooks_registered = True
+
+
+def patch_moe_quantization_on_load(cfg):
+    """Patch transformers' weight loading to quantize MoE expert params on-the-fly.
+
+    Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
+    name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
+    """
+    mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
+    _moe_load_state["mode"] = mode
+    _moe_load_state["count"] = 0
+
+    if _moe_load_state["patched"]:
+        LOG.debug("MoE loading-time quantization patch already active")
+        return
+
+    import transformers.core_model_loading
+    import transformers.modeling_utils
+
+    if mode == "4bit":
+        from bitsandbytes.nn.parametrize import replace_parameter_4bit
+
+        quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
+        compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
+        if compress_statistics is None:
+            compress_statistics = True
+
+        _moe_load_state["quant_type"] = quant_type
+        _moe_load_state["compress_statistics"] = compress_statistics
+
+    # Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
+    # size for all params, defeating our on-load quantization VRAM savings.
+    def _noop_warmup(*args, **kwargs):
+        pass
+
+    transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
+
+    original_set_param = transformers.core_model_loading.set_param_for_module
+
+    def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
+        original_set_param(model, target_name, param_value, *args, **kwargs)
+
+        # Quantize 3D+ expert params that BnB skipped (only on CUDA).
+        if param_value.ndim >= 3 and param_value.is_cuda:
+            mod_path, _, pname = target_name.rpartition(".")
+            mod = model.get_submodule(mod_path) if mod_path else model
+            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
+                if "expert" not in target_name.lower():
+                    LOG.debug(
+                        "Skipping non-expert 3D param: %s (shape=%s)",
+                        target_name,
+                        list(param_value.shape),
+                    )
+                    return
+
+                if _moe_load_state["mode"] == "4bit":
+                    replace_parameter_4bit(
+                        mod,
+                        pname,
+                        compress_statistics=_moe_load_state["compress_statistics"],
+                        quant_type=_moe_load_state["quant_type"],
+                    )
+                else:
+                    replace_parameter_8bit(mod, pname)
+                _moe_load_state["count"] += 1
+
+                # Release the bf16 tensor so CUDA memory is freed immediately.
+                param_value.data = torch.empty(0, device="cpu")
+                torch.cuda.empty_cache()
+
+    transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
+    _moe_load_state["patched"] = True
+
+
+def get_moe_quantized_count():
+    """Return the number of expert parameters quantized during loading."""
+    return _moe_load_state["count"]
+
+
+def patch_peft_target_parameters_matching():
+    """Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
+    if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
+        return
+    from peft.tuners.tuners_utils import BaseTuner
+
+    original_inject = BaseTuner._inject_parameters
+
+    def _patched_inject_parameters(
+        self, peft_config, model, adapter_name, low_cpu_mem_usage
+    ):
+        # Patch target_parameters to use full paths for parametrized modules
+        original_targets = list(peft_config.target_parameters)
+        expanded = set(original_targets)
+
+        for module_name, module in model.named_modules():
+            if not hasattr(module, "parametrizations"):
+                continue
+            for target in original_targets:
+                mod_path, _, param_name = target.rpartition(".")
+                if (
+                    module_name == mod_path or module_name.endswith("." + mod_path)
+                ) and hasattr(module, param_name):
+                    expanded.add(f"{module_name}.{param_name}")
+
+        peft_config.target_parameters = sorted(expanded)
+        try:
+            return original_inject(
+                self, peft_config, model, adapter_name, low_cpu_mem_usage
+            )
+        finally:
+            peft_config.target_parameters = original_targets
+
+    BaseTuner._inject_parameters = _patched_inject_parameters
+    patch_peft_target_parameters_matching._axolotl_patched = True
+    LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index b15b99955..5c0d31ff0 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -629,6 +629,17 @@ class AxolotlInputConfig(
         },
     )
 
+    quantize_moe_experts: bool = Field(
+        default=False,
+        json_schema_extra={
+            "description": "Quantize MoE expert weights on load to reduce VRAM. "
+            "Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. "
+            "Requires CUDA (not compatible with ROCm or other backends). "
+            "Note: total parameter count may be reported incorrectly when enabled "
+            "(trainable param count is correct)."
+        },
+    )
+
     scaling_softmax: bool | None = Field(
         default=None,
         json_schema_extra={
@@ -1289,6 +1300,26 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_quantize_moe_experts(cls, data):
+        if data.get("quantize_moe_experts"):
+            if data.get("adapter") not in ("lora", "qlora"):
+                raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
+            if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
+                raise ValueError(
+                    "quantize_moe_experts requires load_in_4bit or load_in_8bit"
+                )
+            if (
+                data.get("capabilities")
+                and data["capabilities"].get("compute_capability")
+                and not data["capabilities"]["compute_capability"].startswith("sm_")
+            ):
+                raise ValueError(
+                    "quantize_moe_experts requires CUDA (not compatible with ROCm or other backends)"
+                )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_auto_enable_lora_kernels(cls, data):
diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py
index a86de7822..5b90fb63f 100644
--- a/src/axolotl/utils/schemas/peft.py
+++ b/src/axolotl/utils/schemas/peft.py
@@ -209,6 +209,19 @@ class LoraConfig(BaseModel):
             data["lora_dropout"] = 0.0
         return data
 
+    @model_validator(mode="after")
+    def validate_lora_target_parameters_dropout(self):
+        if (
+            self.lora_target_parameters
+            and self.lora_dropout
+            and self.lora_dropout != 0.0
+        ):
+            raise ValueError(
+                "lora_dropout must be 0 when lora_target_parameters is set. "
+                "PEFT's ParamWrapper does not support lora_dropout != 0."
+            )
+        return self
+
 
 class ReLoRAConfig(BaseModel):
     """ReLoRA configuration subset"""
diff --git a/tests/utils/schemas/validation/test_moe_quant.py b/tests/utils/schemas/validation/test_moe_quant.py
new file mode 100644
index 000000000..b969cbb68
--- /dev/null
+++ b/tests/utils/schemas/validation/test_moe_quant.py
@@ -0,0 +1,142 @@
+"""Tests for MoE expert quantization config validation and PEFT patch idempotency."""
+
+import pytest
+
+from axolotl.utils.config import validate_config
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture()
+def gpu_caps():
+    return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}
+
+
+@pytest.fixture()
+def env_caps():
+    return {"torch_version": "2.7.0"}
+
+
+class TestQuantizeMoeExpertsValidation:
+    """Test suite for quantize_moe_experts config validator."""
+
+    def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps):
+        """quantize_moe_experts without adapter should fail."""
+        cfg = (
+            DictDefault(
+                quantize_moe_experts=True,
+            )
+            | min_base_cfg
+        )
+        with pytest.raises(ValueError, match="requires adapter"):
+            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
+
+    def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps):
+        """quantize_moe_experts without load_in_4bit/8bit should fail."""
+        cfg = (
+            DictDefault(
+                quantize_moe_experts=True,
+                adapter="lora",
+            )
+            | min_base_cfg
+        )
+        with pytest.raises(ValueError, match="requires load_in_4bit or load_in_8bit"):
+            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
+
+    def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps):
+        """quantize_moe_experts with qlora + 4bit should pass."""
+        cfg = (
+            DictDefault(
+                quantize_moe_experts=True,
+                adapter="qlora",
+                load_in_4bit=True,
+            )
+            | min_base_cfg
+        )
+        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
+        assert result["quantize_moe_experts"] is True
+
+    def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps):
+        """quantize_moe_experts with lora + 8bit should pass."""
+        cfg = (
+            DictDefault(
+                quantize_moe_experts=True,
+                adapter="lora",
+                load_in_8bit=True,
+            )
+            | min_base_cfg
+        )
+        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
+        assert result["quantize_moe_experts"] is True
+
+    def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps):
+        """quantize_moe_experts=false should not check adapter/quantization."""
+        cfg = (
+            DictDefault(
+                quantize_moe_experts=False,
+            )
+            | min_base_cfg
+        )
+        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
+        assert result["quantize_moe_experts"] is False
+
+    def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps):
+        """quantize_moe_experts should default to false."""
+        cfg = DictDefault({}) | min_base_cfg
+        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
+        assert result["quantize_moe_experts"] is False
+
+
+class TestLoraTargetParametersDropout:
+    """Test that lora_dropout must be 0 when lora_target_parameters is set."""
+
+    def test_rejects_nonzero_dropout(self, min_base_cfg):
+        """lora_dropout > 0 with lora_target_parameters should fail."""
+        cfg = (
+            DictDefault(
+                adapter="lora",
+                lora_target_parameters=["mlp.experts.gate_up_proj"],
+                lora_dropout=0.1,
+                load_in_8bit=True,
+            )
+            | min_base_cfg
+        )
+        with pytest.raises(ValueError, match="lora_dropout must be 0"):
+            validate_config(cfg)
+
+    def test_zero_dropout_passes(self, min_base_cfg):
+        """lora_dropout=0 with lora_target_parameters should pass."""
+        cfg = (
+            DictDefault(
+                adapter="lora",
+                lora_target_parameters=["mlp.experts.gate_up_proj"],
+                lora_dropout=0.0,
+                load_in_8bit=True,
+            )
+            | min_base_cfg
+        )
+        result = validate_config(cfg)
+        assert result["lora_dropout"] == 0.0
+
+
+class TestPeftPatchIdempotency:
+    """Test that patch_peft_target_parameters_matching is idempotent."""
+
+    def test_double_call_does_not_stack_wrappers(self):
+        """Calling patch twice should not double-wrap _inject_parameters."""
+        from peft.tuners.tuners_utils import BaseTuner
+
+        from axolotl.monkeypatch.moe_quant import (
+            patch_peft_target_parameters_matching,
+        )
+
+        original = BaseTuner._inject_parameters
+        try:
+            patch_peft_target_parameters_matching()
+            first_patched = BaseTuner._inject_parameters
+            patch_peft_target_parameters_matching()
+            second_patched = BaseTuner._inject_parameters
+            # Should be same function, not double-wrapped
+            assert first_patched is second_patched
+        finally:
+            BaseTuner._inject_parameters = original
+            patch_peft_target_parameters_matching._axolotl_patched = False