diff --git a/examples/nemotron-h/120b-a12b-qlora.yaml b/examples/nemotron-h/120b-a12b-qlora.yaml new file mode 100644 index 000000000..67dcdb96e --- /dev/null +++ b/examples/nemotron-h/120b-a12b-qlora.yaml @@ -0,0 +1,74 @@ +base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16 + +# LoRA kernel patches are incompatible with this architecture — see README. +lora_mlp_kernel: false +lora_qkv_kernel: false +lora_o_kernel: false + +chat_template: tokenizer_default +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:20%] + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out +dataset_prepared_path: last_run_prepared + +sequence_len: 4096 +sample_packing: true + +use_cut_cross_entropy: true + +load_in_4bit: true +quantize_moe_experts: true +adapter: qlora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.0 +lora_target_modules: + # Attention projection layers (present in ~12 attention layers out of 88) + - q_proj + - k_proj + - v_proj + - o_proj + # To also train MoE expert weights, add them via lora_target_parameters + # (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj): + # lora_target_parameters: + # - up_proj + # - down_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_4bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false + +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 2 +saves_per_epoch: 1 +weight_decay: 0.0 + +special_tokens: diff --git a/examples/nemotron-h/README.md b/examples/nemotron-h/README.md new file mode 100644 index 000000000..3f9071431 --- /dev/null +++ b/examples/nemotron-h/README.md @@ -0,0 +1,48 @@ +# Nemotron-H (nvidia/NVIDIA-Nemotron-3-*) + +Hybrid Mamba2 / Attention / MoE architecture (`model_type: nemotron_h`). + +| Model | Total params | Active params | Layers | +|---|---|---|---| +| NVIDIA-Nemotron-3-Super-120B-A12B-BF16 | 120B | ~12B | 88 | +| NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 | 30B | ~3B | — | + +## Requirements + +```bash +pip install mamba-ssm causal-conv1d # fast Mamba2 CUDA kernels +``` + +## Architecture notes + +- Three block types per layer: **Mamba2** (selective SSM), **Attention** (sparse), **MoE** (mixture-of-experts). +- Only ~12 out of 88 blocks are attention layers (120B variant). +- MLP activation is `relu2` via `mlp_hidden_act` (not the usual `hidden_act`). + +## LoRA kernel patches + +All three LoRA Triton kernel patches must be disabled: + +```yaml +lora_qkv_kernel: false # attention lives in NemotronHBlock.mixer, not layer.self_attn +lora_o_kernel: false # same reason +lora_mlp_kernel: false # relu2 (mlp_hidden_act) is not supported by lora_mlp_kernel +``` + +## MoE expert weights + +NemotronH experts store `up_proj` and `down_proj` as 3D `nn.Parameter` tensors +(shape `[num_experts, out_dim, in_dim]`), **not** `nn.Linear` modules — there is no +`gate_proj`. To fine-tune them alongside attention, use `lora_target_parameters` +instead of `lora_target_modules`: + +```yaml +lora_target_parameters: + - up_proj + - down_proj +``` + +## Limitations + +- **MoE Triton kernels**: `lora_mlp_kernel` is not supported for NemotronH's MoE expert layers. 
The expert weights are 3D `nn.Parameter` tensors (not `nn.Linear`), which the Triton kernel does not support. Keep `lora_mlp_kernel: false`. +- **Gradient checkpointing**: Only supported when `sample_packing: true`. Without sample packing the upstream model marks `supports_gradient_checkpointing = False`. diff --git a/examples/nemotron-h/nano-30b-a3b-qlora.yaml b/examples/nemotron-h/nano-30b-a3b-qlora.yaml new file mode 100644 index 000000000..2d7307f99 --- /dev/null +++ b/examples/nemotron-h/nano-30b-a3b-qlora.yaml @@ -0,0 +1,74 @@ +# See examples/nemotron-h/README.md for architecture notes and requirements. +base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 + +# LoRA kernel patches are incompatible with this architecture — see README. +lora_mlp_kernel: false +lora_qkv_kernel: false +lora_o_kernel: false + +chat_template: tokenizer_default +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:20%] + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out +dataset_prepared_path: last_run_prepared + +sequence_len: 4096 +sample_packing: true + +use_cut_cross_entropy: true + +load_in_4bit: true +quantize_moe_experts: true +adapter: qlora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.0 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + # To also train MoE expert weights, add them via lora_target_parameters + # (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj): + # lora_target_parameters: + # - up_proj + # - down_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_4bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false + +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 4 +saves_per_epoch: 1 +weight_decay: 0.0 + +special_tokens: diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py index 181667cb9..1f943852a 100644 --- a/src/axolotl/common/architectures.py +++ b/src/axolotl/common/architectures.py @@ -23,4 +23,5 @@ MOE_ARCH_BLOCK = { "glm4_moe": "Glm4MoeDecoderLayer", "glm4_moe_lite": "Glm4MoeLiteDecoderLayer", "glm_moe_dsa": "GlmMoeDsaDecoderLayer", + "nemotron_h": "NemotronHMoE", } diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index dd3f4ddfa..3bfda7e23 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -590,9 +590,11 @@ class ModelLoader: "bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_storage": torch.bfloat16, } - if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( - self.cfg.deepspeed or self.is_fsdp_enabled - ): + if self.cfg.model_config_type in [ + "jamba", + "qwen2_moe", + "nemotron_h", + ] and not (self.cfg.deepspeed or self.is_fsdp_enabled): # for some reason, this causes the loss to be off by an order of magnitude # but deepspeed needs this still in bfloat16 bnb_config["bnb_4bit_quant_storage"] = torch.float32 diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 756eef886..50c3b85f4 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -142,6 +142,12 @@ class PatchManager: def apply_post_model_build_patches(self, model: PreTrainedModel): """Apply patches right after model 
build, before post-load setup.""" + if self.cfg.model_config_type == "nemotron_h": + # Must run after model build because NemotronHForCausalLM.__init__ + # calls register_nemotron_h_conversion_mapping() with overwrite=True, + # which would clobber any earlier fix. + self._fix_nemotron_h_conversion_mapping() + self._finalize_moe_expert_quantization(model) def apply_post_model_load_patches(self, model: PreTrainedModel): @@ -291,6 +297,66 @@ class PatchManager: patch_kimi_model() + if self.cfg.model_config_type == "nemotron_h": + if self.cfg.sample_packing: + from transformers.models.nemotron_h.modeling_nemotron_h import ( + NemotronHPreTrainedModel, + ) + + from axolotl.monkeypatch.models.nemotron_h.modeling import ( + patch_nemotron_h_modeling_packing, + ) + + patch_nemotron_h_modeling_packing() + # supports_gradient_checkpointing is only enabled after + # patch_nemotron_h_modeling_packing() installs the GC-compatible + # NemotronHBlock.forward. Without the patch, upstream marks this + # False because the original block forward is not GC-safe. + NemotronHPreTrainedModel.supports_gradient_checkpointing = True + + @staticmethod + def _fix_nemotron_h_conversion_mapping(): + """Remove the spurious embedding→embeddings WeightRenaming from the + nemotron_h checkpoint conversion mapping. + + The nvidia Hub model registers: + WeightRenaming("embedding.weight", "embeddings.weight") + to handle a legacy checkpoint variant. Its reverse (applied on save) + converts ``embeddings`` back to ``embedding``, which silently renames + ``backbone.embeddings.weight`` → ``backbone.embedding.weight`` when + merging LoRA adapters back into the base model. + """ + try: + from transformers.conversion_mapping import ( + WeightRenaming, + get_checkpoint_conversion_mapping, + register_checkpoint_conversion_mapping, + ) + except ImportError: + return + + mapping = get_checkpoint_conversion_mapping("nemotron_h") + if mapping is None: + return + + filtered = [ + entry + for entry in mapping + if not ( + isinstance(entry, WeightRenaming) + and entry.source_patterns == ["embedding.weight"] + and entry.target_patterns == ["embeddings.weight"] + ) + ] + if len(filtered) != len(mapping): + register_checkpoint_conversion_mapping( + "nemotron_h", filtered, overwrite=True + ) + LOG.info( + "Removed embedding→embeddings WeightRenaming from nemotron_h " + "checkpoint conversion mapping" + ) + def _apply_fp8_patches(self): """Apply patches for FP8 support.""" if self.cfg.fp8: diff --git a/src/axolotl/loaders/utils.py b/src/axolotl/loaders/utils.py index 187784b93..ce4018014 100644 --- a/src/axolotl/loaders/utils.py +++ b/src/axolotl/loaders/utils.py @@ -234,4 +234,6 @@ def get_linear_embedding_layers(model_type: str) -> list[str]: return ["embed_in", "embed_out"] if model_type == "falcon": return ["word_embeddings", "lm_head"] + if model_type == "nemotron_h": + return ["embeddings", "lm_head"] return ["embed_tokens", "lm_head"] diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 5bb3a32eb..c5d552c03 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -394,15 +394,15 @@ def apply_lora_kernel_patches( activation = text_config.hidden_act elif hasattr(text_config, "hidden_activation"): activation = text_config.hidden_activation + elif hasattr(text_config, "mlp_hidden_act"): + # Hybrid models (e.g. 
nemotron_h) use mlp_hidden_act instead of hidden_act + activation = text_config.mlp_hidden_act # map activation to supported activation - if "gelu" in activation: + if activation and "gelu" in activation: # gemma3 uses gelu_pytorch_tanh activation = "gelu" - if activation not in SUPPORTED_ACTIVATIONS: - raise NotImplementedError(f"Activation {activation} is not supported") - layers = get_layers(model) # Patch each layer @@ -444,6 +444,15 @@ def apply_lora_kernel_patches( ) for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer): if cfg.lora_mlp_kernel: + # Check is inside lora_mlp_kernel guard so models with an + # unsupported activation (e.g. nemotron_h uses relu2) can set + # lora_mlp_kernel: false without hitting an error here. + if activation not in SUPPORTED_ACTIVATIONS: + raise NotImplementedError( + f"Activation {activation!r} is not supported by lora_mlp_kernel. " + f"Set `lora_mlp_kernel: false` in your config or use a model with " + f"a supported activation ({SUPPORTED_ACTIVATIONS})." + ) # MLP patching can_patch_mlp = all( hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj) diff --git a/src/axolotl/monkeypatch/models/nemotron_h/__init__.py b/src/axolotl/monkeypatch/models/nemotron_h/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/nemotron_h/modeling.py b/src/axolotl/monkeypatch/models/nemotron_h/modeling.py new file mode 100644 index 000000000..a36c34259 --- /dev/null +++ b/src/axolotl/monkeypatch/models/nemotron_h/modeling.py @@ -0,0 +1,315 @@ +"""Sample-packing patch for NemotronH (Mamba2/Attention/MoE hybrid). + +Threads seq_idx (derived from position_ids) into the Mamba2 SSM kernels so +packed-sequence boundaries reset SSM state. Upstream hard-codes seq_idx=None, +which leaks hidden state across boundaries. Attention and MoE blocks need no +changes — only the Mamba2 mixer is patched. +""" + +import importlib + +import torch + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def get_seq_idx(position_ids: torch.Tensor) -> torch.Tensor: + """Convert position_ids [B, T] → seq_idx [B, T] int32 for mamba-ssm kernels. + + Example: position_ids [[0,1,2,3,0,1,2]] → seq_idx [[0,0,0,0,1,1,1]] + """ + return (torch.cumsum((position_ids == 0).int(), dim=-1) - 1).to(torch.int32) + + +def patch_nemotron_h_modeling_packing(): + """Patch NemotronH for sample packing: seq_idx threading into Mamba2 SSM kernels. + + _get_unpad_data is handled by SUPPORTED_MULTIPACK_MODEL_TYPES / patch_for_multipack(). + This function only applies the seq_idx patches that are unique to nemotron_h. + """ + try: + mod = importlib.import_module( + "transformers.models.nemotron_h.modeling_nemotron_h" + ) + except ImportError: + LOG.warning("nemotron_h not found in transformers, skipping packing patches") + return + + NemotronHMamba2Mixer = mod.NemotronHMamba2Mixer + NemotronHBlock = mod.NemotronHBlock + + # Patch 1: cuda_kernels_forward — add seq_idx param and thread it to + # causal_conv1d_fn and mamba_chunk_scan_combined. Fused fast path is + # bypassed when seq_idx is set (requires causal_conv1d_cuda C extension). 
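+    # Illustrative sketch (comments only, nothing below is executed); the tensor
+    # values mirror the get_seq_idx docstring example above.
+    #
+    #   position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2]])  # two packed sequences
+    #   seq_idx = get_seq_idx(position_ids)                    # [[0, 0, 0, 0, 1, 1, 1]]
+    #
+    # NemotronHBlock.forward (Patch 3) derives seq_idx from position_ids,
+    # NemotronHMamba2Mixer.forward (Patch 2) threads it to the CUDA fast path,
+    # and cuda_kernels_forward (Patch 1) hands it to causal_conv1d_fn /
+    # mamba_chunk_scan_combined, which reset conv and SSM state wherever the
+    # index changes, so each packed sequence starts from fresh state.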
+ def patched_cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params=None, + attention_mask=None, + seq_idx=None, + ): + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + d_to_remove = ( + 2 * self.intermediate_size + + 2 * self.n_groups * self.ssm_state_size + + self.num_heads + ) + + if cache_params is not None and cache_params.has_previous_state: + in_projected_states = self.in_proj(hidden_states.squeeze(1)) + d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 + split_projection_dim = [ + d_mlp, + d_mlp, + self.intermediate_size, + self.conv_dim, + self.num_heads, + ] + _, _, gate, hidden_states_B_C, dt = torch.split( + in_projected_states, split_projection_dim, dim=-1 + ) + hidden_states_B_C = mod.causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size, + groups_time_state_size, + groups_time_state_size, + ], + dim=-1, + ) + A = -torch.exp(self.A_log.float()) + A = ( + A[:, None, ...][:, :, None] + .expand(-1, self.head_dim, self.ssm_state_size) + .to(dtype=torch.float32) + ) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view( + batch_size, self.num_heads, self.head_dim + ) + hidden_states = mod.selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view( + batch_size, self.num_heads * self.head_dim + ) + hidden_states = self.norm(hidden_states, gate) + out = self.out_proj(hidden_states)[:, None, ...] 
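+        # The cached branch above handles single-token decode and never sees
+        # packed sequences, so it ignores seq_idx; the prefill/training branch
+        # below is where the seq_idx-aware state resets matter.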
+ + else: + if attention_mask is not None and not torch.all(attention_mask == 1): + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + projected_states = self.in_proj(hidden_states) + A = -torch.exp(self.A_log.float()) + dt_limit_kwargs = ( + {} + if self.time_step_limit is None + else {"dt_limit": self.time_step_limit} + ) + if attention_mask is not None: + input_not_masked = torch.all(attention_mask == 1) + else: + input_not_masked = True + + if ( + self.use_mem_eff_path + and self.training + and cache_params is None + and input_not_masked + and seq_idx is None + ): + out, ssm_state = mod.mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=seq_idx, + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=True, + **dt_limit_kwargs, + ) + else: + gate, hidden_states_B_C, time_step = torch.split( + projected_states, + [self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, + ) + + if cache_params is not None: + hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2) + conv_state = torch.nn.functional.pad( + hidden_states_B_C_t, + (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0), + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) + + if mod.causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + ]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[ + :, :seq_len + ] + ) + else: + hidden_states_B_C = mod.causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=seq_idx, + ).transpose(1, 2)[:, :seq_len] + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size, + groups_time_state_size, + groups_time_state_size, + ], + dim=-1, + ) + + if attention_mask is not None and not torch.all(attention_mask == 1): + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to( + dtype + ) + + scan_output, ssm_state = mod.mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + time_step, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=seq_idx, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + scan_output = scan_output.view(batch_size, seq_len, -1) + scan_output = self.norm(scan_output, gate) + out = self.out_proj(scan_output) + + return out + + NemotronHMamba2Mixer.cuda_kernels_forward = patched_cuda_kernels_forward + + # Patch 2: Mamba2Mixer.forward — add seq_idx, guard on causal_conv1d_fn, + # restore the cuda stream context (matches upstream; avoids NaN on multi-GPU). + def patched_mixer_forward( + self, + hidden_states, + cache_params=None, + attention_mask=None, + seq_idx=None, + ): + if seq_idx is not None and mod.causal_conv1d_fn is None: + raise RuntimeError( + "Nemotron-H sample packing requires causal_conv1d_fn. 
" + "Install with: pip install mamba-ssm causal-conv1d" + ) + if ( + mod.is_fast_path_available + and "cuda" in self.in_proj.weight.device.type + and not mod.is_torchdynamo_compiling() + ): + with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)): + return self.cuda_kernels_forward( + hidden_states, cache_params, attention_mask, seq_idx=seq_idx + ) + return self.torch_forward(hidden_states, cache_params, attention_mask) + + NemotronHMamba2Mixer.forward = patched_mixer_forward + + # Patch 3: NemotronHBlock.forward — compute seq_idx from position_ids and + # pass it to the Mamba2 mixer. Skipped during decode (has_previous_state). + def patched_block_forward( + self, + hidden_states, + past_key_values=None, + cache_position=None, + attention_mask=None, + position_ids=None, + use_cache=False, + **kwargs, + ): + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + + if self.block_type == "mamba": + is_decoding = ( + past_key_values is not None and past_key_values.has_previous_state + ) + seq_idx = ( + get_seq_idx(position_ids) + if position_ids is not None and not is_decoding + else None + ) + hidden_states = self.mixer( + hidden_states, + cache_params=past_key_values, + attention_mask=attention_mask, + seq_idx=seq_idx, + ) + elif self.block_type == "attention": + hidden_states, _ = self.mixer( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + user_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + else: + hidden_states = self.mixer(hidden_states) + + hidden_states = residual + hidden_states + return hidden_states + + NemotronHBlock.forward = patched_block_forward + + LOG.info("Applied NemotronH sample packing patch (seq_idx threading into Mamba2)") diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py index 68c458f5a..d58b78af2 100644 --- a/src/axolotl/monkeypatch/moe_quant.py +++ b/src/axolotl/monkeypatch/moe_quant.py @@ -154,6 +154,8 @@ def patch_peft_target_parameters_matching(): 1. Expands short suffixes to full module paths for parametrized modules. 2. Iterates params in definition order (not alphabetical order) so saved adapters are compatible with standard PEFT, vLLM, etc. + 3. Skips ParametrizationList synthetic paths to prevent PEFT from mistakenly + targeting quantized expert params via name-suffix matching. """ if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False): return @@ -293,5 +295,23 @@ def patch_peft_target_parameters_matching(): self.targeted_parameter_names.append(key) BaseTuner._inject_parameters = _patched_inject_parameters + + # Skip ParametrizationList synthetic paths (e.g. "...parametrizations.up_proj") + # so PEFT suffix-matching doesn't try to wrap quantized expert params in LoRA. + # Previous MoE models (Mixtral, DeepSeek, etc.) stored experts as nn.Linear + # modules, so PEFT's normal target_modules path worked fine. NemotronH uses + # 3D nn.Parameter tensors via our quantize_moe_experts parametrization, which + # exposes synthetic ".parametrizations." paths that PEFT's suffix match + # would otherwise treat as target_modules candidates. + _original_check = BaseTuner._check_target_module_exists + + @staticmethod + def _patched_check_target_module_exists(config, key): + if ".parametrizations." 
in key: + return False + return _original_check(config, key) + + BaseTuner._check_target_module_exists = _patched_check_target_module_exists + patch_peft_target_parameters_matching._axolotl_patched = True LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering") diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 8566af526..9e2157ef4 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -62,6 +62,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "mistral4", "afmoe", "nemotron", + "nemotron_h", ] diff --git a/src/axolotl/utils/chat_templates/templates/nemotron_h.jinja b/src/axolotl/utils/chat_templates/templates/nemotron_h.jinja new file mode 100644 index 000000000..75dcd9d9f --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/nemotron_h.jinja @@ -0,0 +1,16 @@ +{%- if messages and messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- set messages = messages[1:] %} +{%- endif %} +{%- for message in messages %} + {%- if message.role == 'user' %} + {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} + {%- elif message.role == 'assistant' %} + {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} + {%- else %} + {{- raise_exception('Unexpected role: ' + message.role) }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 7ffa793f2..4b759237e 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -62,6 +62,7 @@ class ChatTemplate(str, Enum): qwen3 = "qwen3" qwen3_5 = "qwen3_5" falcon_h1 = "falcon_h1" + nemotron_h = "nemotron_h" tokenizer_default = "tokenizer_default" exaone = "exaone" exaone4 = "exaone4" diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index f665a99ff..ff7813600 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1258,6 +1258,21 @@ class ModelCompatibilityValidationMixin: raise ValueError("gradient_checkpointing is not supported for MPT models") return self + @model_validator(mode="after") + def check_nemotron_h_gradient_checkpointing(self): + if ( + self.base_model + and "nemotron-h" in self.base_model.lower() + and self.gradient_checkpointing + and not self.sample_packing + ): + raise ValueError( + "gradient_checkpointing for nemotron_h requires sample_packing: true. " + "The upstream model marks supports_gradient_checkpointing=False; " + "axolotl only enables it after applying the sample-packing patch." + ) + return self + @model_validator(mode="after") def check_gradient_checkpointing_w_offload(self): if self.gradient_checkpointing == "offload":