super nemo support (#3508)
* nemo support
* config
* rename, config
* nemotron packing
* config fix
* readme + configs
* gc compat bug
* config changes for qwen and pad token nemo
* patch nemotron_h weight renaming so it doesn't get reversed to embedding (singular noun) on checkpoint save
* lint
* revert qwen3.5 config changes, not needed in this pr
* lint
* Update examples/nemotron-h/120b-a12b-qlora.yaml
  Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
* Update examples/nemotron-h/nano-30b-a3b-qlora.yaml
  Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
* readme + validation
* lazy load comment
* Update examples/nemotron-h/120b-a12b-qlora.yaml
  Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
* val fix
* add nemo to multi packing

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
74
examples/nemotron-h/120b-a12b-qlora.yaml
Normal file
@@ -0,0 +1,74 @@
base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16

# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

chat_template: tokenizer_default
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared

sequence_len: 4096
sample_packing: true

use_cut_cross_entropy: true

load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
  # Attention projection layers (present in ~12 attention layers out of 88)
  - q_proj
  - k_proj
  - v_proj
  - o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
#   - up_proj
#   - down_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

special_tokens:
48
examples/nemotron-h/README.md
Normal file
@@ -0,0 +1,48 @@
# Nemotron-H (nvidia/NVIDIA-Nemotron-3-*)

Hybrid Mamba2 / Attention / MoE architecture (`model_type: nemotron_h`).

| Model | Total params | Active params | Layers |
|---|---|---|---|
| NVIDIA-Nemotron-3-Super-120B-A12B-BF16 | 120B | ~12B | 88 |
| NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 | 30B | ~3B | — |

## Requirements

```bash
pip install mamba-ssm causal-conv1d  # fast Mamba2 CUDA kernels
```

## Architecture notes

- Each layer is one of three block types: **Mamba2** (selective SSM), **Attention** (sparse), or **MoE** (mixture-of-experts).
- Only ~12 out of 88 blocks are attention layers (120B variant).
- MLP activation is `relu2` via `mlp_hidden_act` (not the usual `hidden_act`).

## LoRA kernel patches

All three LoRA Triton kernel patches must be disabled:

```yaml
lora_qkv_kernel: false  # attention lives in NemotronHBlock.mixer, not layer.self_attn
lora_o_kernel: false    # same reason
lora_mlp_kernel: false  # relu2 (mlp_hidden_act) is not supported by lora_mlp_kernel
```
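
For readers unfamiliar with `relu2`, the following is a minimal plain-Python sketch, assuming the usual squared-ReLU meaning of the name. The helper `relu2` below is illustrative only and is not an axolotl or transformers API.

```python
def relu2(x: float) -> float:
    """Squared ReLU: zero for negative inputs, x squared otherwise."""
    rectified = max(x, 0.0)
    return rectified * rectified


# Negative inputs are clipped to zero before squaring.
print([relu2(v) for v in (-2.0, 0.0, 1.5)])
```

Because this is not one of the gated activations the MLP kernel is written for, the config keeps `lora_mlp_kernel: false`.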

## MoE expert weights

NemotronH experts store `up_proj` and `down_proj` as 3D `nn.Parameter` tensors
(shape `[num_experts, out_dim, in_dim]`), **not** `nn.Linear` modules — there is no
`gate_proj`. To fine-tune them alongside attention, use `lora_target_parameters`
instead of `lora_target_modules`:

```yaml
lora_target_parameters:
  - up_proj
  - down_proj
```

## Limitations

- **MoE Triton kernels**: `lora_mlp_kernel` is not supported for NemotronH's MoE expert layers. The expert weights are 3D `nn.Parameter` tensors (not `nn.Linear`), which the Triton kernel does not support. Keep `lora_mlp_kernel: false`.
- **Gradient checkpointing**: Only supported when `sample_packing: true`. Without sample packing the upstream model marks `supports_gradient_checkpointing = False`.
74
examples/nemotron-h/nano-30b-a3b-qlora.yaml
Normal file
@@ -0,0 +1,74 @@
# See examples/nemotron-h/README.md for architecture notes and requirements.
base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

chat_template: tokenizer_default
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared

sequence_len: 4096
sample_packing: true

use_cut_cross_entropy: true

load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
#   - up_proj
#   - down_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

special_tokens:
@@ -23,4 +23,5 @@ MOE_ARCH_BLOCK = {
     "glm4_moe": "Glm4MoeDecoderLayer",
     "glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
     "glm_moe_dsa": "GlmMoeDsaDecoderLayer",
+    "nemotron_h": "NemotronHMoE",
 }
@@ -590,9 +590,11 @@ class ModelLoader:
                 "bnb_4bit_quant_type": "nf4",
                 "bnb_4bit_quant_storage": torch.bfloat16,
             }
-            if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
-                self.cfg.deepspeed or self.is_fsdp_enabled
-            ):
+            if self.cfg.model_config_type in [
+                "jamba",
+                "qwen2_moe",
+                "nemotron_h",
+            ] and not (self.cfg.deepspeed or self.is_fsdp_enabled):
                 # for some reason, this causes the loss to be off by an order of magnitude
                 # but deepspeed needs this still in bfloat16
                 bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -142,6 +142,12 @@ class PatchManager:

     def apply_post_model_build_patches(self, model: PreTrainedModel):
         """Apply patches right after model build, before post-load setup."""
+        if self.cfg.model_config_type == "nemotron_h":
+            # Must run after model build because NemotronHForCausalLM.__init__
+            # calls register_nemotron_h_conversion_mapping() with overwrite=True,
+            # which would clobber any earlier fix.
+            self._fix_nemotron_h_conversion_mapping()
+
         self._finalize_moe_expert_quantization(model)

     def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -291,6 +297,66 @@ class PatchManager:

             patch_kimi_model()

+        if self.cfg.model_config_type == "nemotron_h":
+            if self.cfg.sample_packing:
+                from transformers.models.nemotron_h.modeling_nemotron_h import (
+                    NemotronHPreTrainedModel,
+                )
+
+                from axolotl.monkeypatch.models.nemotron_h.modeling import (
+                    patch_nemotron_h_modeling_packing,
+                )
+
+                patch_nemotron_h_modeling_packing()
+                # supports_gradient_checkpointing is only enabled after
+                # patch_nemotron_h_modeling_packing() installs the GC-compatible
+                # NemotronHBlock.forward. Without the patch, upstream marks this
+                # False because the original block forward is not GC-safe.
+                NemotronHPreTrainedModel.supports_gradient_checkpointing = True
+
+    @staticmethod
+    def _fix_nemotron_h_conversion_mapping():
+        """Remove the spurious embedding→embeddings WeightRenaming from the
+        nemotron_h checkpoint conversion mapping.
+
+        The nvidia Hub model registers:
+            WeightRenaming("embedding.weight", "embeddings.weight")
+        to handle a legacy checkpoint variant. Its reverse (applied on save)
+        converts ``embeddings`` back to ``embedding``, which silently renames
+        ``backbone.embeddings.weight`` → ``backbone.embedding.weight`` when
+        merging LoRA adapters back into the base model.
+        """
+        try:
+            from transformers.conversion_mapping import (
+                WeightRenaming,
+                get_checkpoint_conversion_mapping,
+                register_checkpoint_conversion_mapping,
+            )
+        except ImportError:
+            return
+
+        mapping = get_checkpoint_conversion_mapping("nemotron_h")
+        if mapping is None:
+            return
+
+        filtered = [
+            entry
+            for entry in mapping
+            if not (
+                isinstance(entry, WeightRenaming)
+                and entry.source_patterns == ["embedding.weight"]
+                and entry.target_patterns == ["embeddings.weight"]
+            )
+        ]
+        if len(filtered) != len(mapping):
+            register_checkpoint_conversion_mapping(
+                "nemotron_h", filtered, overwrite=True
+            )
+            LOG.info(
+                "Removed embedding→embeddings WeightRenaming from nemotron_h "
+                "checkpoint conversion mapping"
+            )
+
     def _apply_fp8_patches(self):
         """Apply patches for FP8 support."""
         if self.cfg.fp8:
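
The filtering step in `_fix_nemotron_h_conversion_mapping` can be tried in isolation. The sketch below assumes a small dataclass, `Rename`, as a hypothetical stand-in for transformers' `WeightRenaming` (only the two attributes the hunk compares are modelled), and the second mapping entry is made up for illustration.

```python
from dataclasses import dataclass


@dataclass
class Rename:
    """Hypothetical stand-in for WeightRenaming (not the real class)."""

    source_patterns: list
    target_patterns: list


def drop_embedding_rename(mapping):
    """Return the mapping without the embedding.weight -> embeddings.weight entry."""
    return [
        entry
        for entry in mapping
        if not (
            isinstance(entry, Rename)
            and entry.source_patterns == ["embedding.weight"]
            and entry.target_patterns == ["embeddings.weight"]
        )
    ]


mapping = [
    Rename(["embedding.weight"], ["embeddings.weight"]),
    Rename(["foo.weight"], ["bar.weight"]),  # made-up entry for illustration
]
filtered = drop_embedding_rename(mapping)
print(len(mapping), "->", len(filtered))
```

Comparing lengths afterwards, as the hunk does, means the mapping is re-registered only when an entry was actually removed.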
@@ -234,4 +234,6 @@ def get_linear_embedding_layers(model_type: str) -> list[str]:
         return ["embed_in", "embed_out"]
     if model_type == "falcon":
         return ["word_embeddings", "lm_head"]
+    if model_type == "nemotron_h":
+        return ["embeddings", "lm_head"]
     return ["embed_tokens", "lm_head"]
@@ -394,15 +394,15 @@ def apply_lora_kernel_patches(
         activation = text_config.hidden_act
     elif hasattr(text_config, "hidden_activation"):
         activation = text_config.hidden_activation
+    elif hasattr(text_config, "mlp_hidden_act"):
+        # Hybrid models (e.g. nemotron_h) use mlp_hidden_act instead of hidden_act
+        activation = text_config.mlp_hidden_act

     # map activation to supported activation
-    if "gelu" in activation:
+    if activation and "gelu" in activation:
         # gemma3 uses gelu_pytorch_tanh
         activation = "gelu"

-    if activation not in SUPPORTED_ACTIVATIONS:
-        raise NotImplementedError(f"Activation {activation} is not supported")
-
     layers = get_layers(model)

     # Patch each layer
@@ -444,6 +444,15 @@ def apply_lora_kernel_patches(
                 )
         for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
             if cfg.lora_mlp_kernel:
+                # Check is inside lora_mlp_kernel guard so models with an
+                # unsupported activation (e.g. nemotron_h uses relu2) can set
+                # lora_mlp_kernel: false without hitting an error here.
+                if activation not in SUPPORTED_ACTIVATIONS:
+                    raise NotImplementedError(
+                        f"Activation {activation!r} is not supported by lora_mlp_kernel. "
+                        f"Set `lora_mlp_kernel: false` in your config or use a model with "
+                        f"a supported activation ({SUPPORTED_ACTIVATIONS})."
+                    )
                 # MLP patching
                 can_patch_mlp = all(
                     hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
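
The two hunks above first resolve the activation name from whichever config attribute exists, then reject an unsupported activation only when the MLP kernel is requested. The following is a simplified sketch of that flow. The contents of `SUPPORTED_ACTIVATIONS` are an assumption (the real set is not shown in this diff), the helper names are hypothetical, and the real check additionally runs only for layers where an MLP is found.

```python
from types import SimpleNamespace

# Assumed contents; the real SUPPORTED_ACTIVATIONS is not shown in this diff.
SUPPORTED_ACTIVATIONS = {"silu", "gelu"}


def resolve_activation(text_config):
    """Pick the first activation attribute present, in the order the hunk checks."""
    activation = None
    for attr in ("hidden_act", "hidden_activation", "mlp_hidden_act"):
        if hasattr(text_config, attr):
            activation = getattr(text_config, attr)
            break
    # Collapse gelu variants (e.g. gelu_pytorch_tanh) to plain "gelu".
    if activation and "gelu" in activation:
        activation = "gelu"
    return activation


def check_mlp_kernel(activation, lora_mlp_kernel):
    """Raise only when the MLP kernel is enabled and the activation is unsupported."""
    if lora_mlp_kernel and activation not in SUPPORTED_ACTIVATIONS:
        raise NotImplementedError(
            f"Activation {activation!r} is not supported by lora_mlp_kernel."
        )


nemotron_like = SimpleNamespace(mlp_hidden_act="relu2")
act = resolve_activation(nemotron_like)
check_mlp_kernel(act, lora_mlp_kernel=False)  # passes: kernel disabled
print(act)
```

With `lora_mlp_kernel=True` the same call would raise, which is the behaviour the error message in the hunk describes.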
315
src/axolotl/monkeypatch/models/nemotron_h/modeling.py
Normal file
@@ -0,0 +1,315 @@
"""Sample-packing patch for NemotronH (Mamba2/Attention/MoE hybrid).

Threads seq_idx (derived from position_ids) into the Mamba2 SSM kernels so
packed-sequence boundaries reset SSM state. Upstream hard-codes seq_idx=None,
which leaks hidden state across boundaries. Attention and MoE blocks need no
changes — only the Mamba2 mixer is patched.
"""

import importlib

import torch

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def get_seq_idx(position_ids: torch.Tensor) -> torch.Tensor:
    """Convert position_ids [B, T] → seq_idx [B, T] int32 for mamba-ssm kernels.

    Example: position_ids [[0,1,2,3,0,1,2]] → seq_idx [[0,0,0,0,1,1,1]]
    """
    return (torch.cumsum((position_ids == 0).int(), dim=-1) - 1).to(torch.int32)


def patch_nemotron_h_modeling_packing():
    """Patch NemotronH for sample packing: seq_idx threading into Mamba2 SSM kernels.

    _get_unpad_data is handled by SUPPORTED_MULTIPACK_MODEL_TYPES / patch_for_multipack().
    This function only applies the seq_idx patches that are unique to nemotron_h.
    """
    try:
        mod = importlib.import_module(
            "transformers.models.nemotron_h.modeling_nemotron_h"
        )
    except ImportError:
        LOG.warning("nemotron_h not found in transformers, skipping packing patches")
        return

    NemotronHMamba2Mixer = mod.NemotronHMamba2Mixer
    NemotronHBlock = mod.NemotronHBlock

    # Patch 1: cuda_kernels_forward — add seq_idx param and thread it to
    # causal_conv1d_fn and mamba_chunk_scan_combined. Fused fast path is
    # bypassed when seq_idx is set (requires causal_conv1d_cuda C extension).
    def patched_cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params=None,
        attention_mask=None,
        seq_idx=None,
    ):
        batch_size, seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size
        d_to_remove = (
            2 * self.intermediate_size
            + 2 * self.n_groups * self.ssm_state_size
            + self.num_heads
        )

        if cache_params is not None and cache_params.has_previous_state:
            in_projected_states = self.in_proj(hidden_states.squeeze(1))
            d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
            split_projection_dim = [
                d_mlp,
                d_mlp,
                self.intermediate_size,
                self.conv_dim,
                self.num_heads,
            ]
            _, _, gate, hidden_states_B_C, dt = torch.split(
                in_projected_states, split_projection_dim, dim=-1
            )
            hidden_states_B_C = mod.causal_conv1d_update(
                hidden_states_B_C,
                cache_params.conv_states[self.layer_idx],
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )
            hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [
                    self.intermediate_size,
                    groups_time_state_size,
                    groups_time_state_size,
                ],
                dim=-1,
            )
            A = -torch.exp(self.A_log.float())
            A = (
                A[:, None, ...][:, :, None]
                .expand(-1, self.head_dim, self.ssm_state_size)
                .to(dtype=torch.float32)
            )
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
            C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
            hidden_states_reshaped = hidden_states.view(
                batch_size, self.num_heads, self.head_dim
            )
            hidden_states = mod.selective_state_update(
                cache_params.ssm_states[self.layer_idx],
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            hidden_states = hidden_states.view(
                batch_size, self.num_heads * self.head_dim
            )
            hidden_states = self.norm(hidden_states, gate)
            out = self.out_proj(hidden_states)[:, None, ...]

        else:
            if attention_mask is not None and not torch.all(attention_mask == 1):
                dtype = hidden_states.dtype
                hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

            projected_states = self.in_proj(hidden_states)
            A = -torch.exp(self.A_log.float())
            dt_limit_kwargs = (
                {}
                if self.time_step_limit is None
                else {"dt_limit": self.time_step_limit}
            )
            if attention_mask is not None:
                input_not_masked = torch.all(attention_mask == 1)
            else:
                input_not_masked = True

            if (
                self.use_mem_eff_path
                and self.training
                and cache_params is None
                and input_not_masked
                and seq_idx is None
            ):
                out, ssm_state = mod.mamba_split_conv1d_scan_combined(
                    projected_states,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.chunk_size,
                    seq_idx=seq_idx,
                    activation=self.activation,
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.variance_epsilon,
                    outproj_weight=self.out_proj.weight,
                    outproj_bias=self.out_proj.bias,
                    headdim=self.head_dim,
                    ngroups=self.n_groups,
                    norm_before_gate=False,
                    return_final_states=True,
                    **dt_limit_kwargs,
                )
            else:
                gate, hidden_states_B_C, time_step = torch.split(
                    projected_states,
                    [self.intermediate_size, self.conv_dim, self.num_heads],
                    dim=-1,
                )

                if cache_params is not None:
                    hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2)
                    conv_state = torch.nn.functional.pad(
                        hidden_states_B_C_t,
                        (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0),
                    )
                    cache_params.conv_states[self.layer_idx].copy_(conv_state)

                if mod.causal_conv1d_fn is None or self.activation not in [
                    "silu",
                    "swish",
                ]:
                    hidden_states_B_C = self.act(
                        self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[
                            :, :seq_len
                        ]
                    )
                else:
                    hidden_states_B_C = mod.causal_conv1d_fn(
                        x=hidden_states_B_C.transpose(1, 2),
                        weight=self.conv1d.weight.squeeze(1),
                        bias=self.conv1d.bias,
                        activation=self.activation,
                        seq_idx=seq_idx,
                    ).transpose(1, 2)[:, :seq_len]

                hidden_states, B, C = torch.split(
                    hidden_states_B_C,
                    [
                        self.intermediate_size,
                        groups_time_state_size,
                        groups_time_state_size,
                    ],
                    dim=-1,
                )

                if attention_mask is not None and not torch.all(attention_mask == 1):
                    dtype = hidden_states.dtype
                    hidden_states = (hidden_states * attention_mask[:, :, None]).to(
                        dtype
                    )

                scan_output, ssm_state = mod.mamba_chunk_scan_combined(
                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                    time_step,
                    A,
                    B.view(batch_size, seq_len, self.n_groups, -1),
                    C.view(batch_size, seq_len, self.n_groups, -1),
                    chunk_size=self.chunk_size,
                    D=self.D,
                    z=None,
                    seq_idx=seq_idx,
                    return_final_states=True,
                    dt_bias=self.dt_bias,
                    dt_softplus=True,
                    **dt_limit_kwargs,
                )
                if ssm_state is not None and cache_params is not None:
                    cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
                scan_output = scan_output.view(batch_size, seq_len, -1)
                scan_output = self.norm(scan_output, gate)
                out = self.out_proj(scan_output)

        return out

    NemotronHMamba2Mixer.cuda_kernels_forward = patched_cuda_kernels_forward

    # Patch 2: Mamba2Mixer.forward — add seq_idx, guard on causal_conv1d_fn,
    # restore the cuda stream context (matches upstream; avoids NaN on multi-GPU).
    def patched_mixer_forward(
        self,
        hidden_states,
        cache_params=None,
        attention_mask=None,
        seq_idx=None,
    ):
        if seq_idx is not None and mod.causal_conv1d_fn is None:
            raise RuntimeError(
                "Nemotron-H sample packing requires causal_conv1d_fn. "
                "Install with: pip install mamba-ssm causal-conv1d"
            )
        if (
            mod.is_fast_path_available
            and "cuda" in self.in_proj.weight.device.type
            and not mod.is_torchdynamo_compiling()
        ):
            with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
                return self.cuda_kernels_forward(
                    hidden_states, cache_params, attention_mask, seq_idx=seq_idx
                )
        return self.torch_forward(hidden_states, cache_params, attention_mask)

    NemotronHMamba2Mixer.forward = patched_mixer_forward

    # Patch 3: NemotronHBlock.forward — compute seq_idx from position_ids and
    # pass it to the Mamba2 mixer. Skipped during decode (has_previous_state).
    def patched_block_forward(
        self,
        hidden_states,
        past_key_values=None,
        cache_position=None,
        attention_mask=None,
        position_ids=None,
        use_cache=False,
        **kwargs,
    ):
        residual = hidden_states
        hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))

        if self.block_type == "mamba":
            is_decoding = (
                past_key_values is not None and past_key_values.has_previous_state
            )
            seq_idx = (
                get_seq_idx(position_ids)
                if position_ids is not None and not is_decoding
                else None
            )
            hidden_states = self.mixer(
                hidden_states,
                cache_params=past_key_values,
                attention_mask=attention_mask,
                seq_idx=seq_idx,
            )
        elif self.block_type == "attention":
            hidden_states, _ = self.mixer(
                hidden_states=hidden_states,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                position_ids=position_ids,
                user_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )
        else:
            hidden_states = self.mixer(hidden_states)

        hidden_states = residual + hidden_states
        return hidden_states

    NemotronHBlock.forward = patched_block_forward

    LOG.info("Applied NemotronH sample packing patch (seq_idx threading into Mamba2)")
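
The `seq_idx` derivation in `get_seq_idx` above can be restated without tensors: every position id of 0 starts a new packed sequence, so a running count of zeros, minus one, labels each token with its sequence index. The sketch below is a dependency-free analogue for a single row; the function name `seq_idx_from_position_ids` is hypothetical and is not part of the patch.

```python
from itertools import accumulate


def seq_idx_from_position_ids(position_ids):
    """Plain-Python analogue of get_seq_idx for one row of position ids."""
    # 1 wherever a packed sequence starts (position id 0), else 0.
    starts = [1 if p == 0 else 0 for p in position_ids]
    # Running count of starts, shifted so the first sequence is index 0.
    return [count - 1 for count in accumulate(starts)]


row = [0, 1, 2, 3, 0, 1, 2]
print(seq_idx_from_position_ids(row))  # [0, 0, 0, 0, 1, 1, 1]
```

This matches the example in the docstring: two packed sequences of lengths 4 and 3 receive indices 0 and 1.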
@@ -154,6 +154,8 @@ def patch_peft_target_parameters_matching():
     1. Expands short suffixes to full module paths for parametrized modules.
     2. Iterates params in definition order (not alphabetical order) so saved
        adapters are compatible with standard PEFT, vLLM, etc.
+    3. Skips ParametrizationList synthetic paths to prevent PEFT from mistakenly
+       targeting quantized expert params via name-suffix matching.
     """
     if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
         return
@@ -293,5 +295,23 @@ def patch_peft_target_parameters_matching():
                 self.targeted_parameter_names.append(key)

     BaseTuner._inject_parameters = _patched_inject_parameters
+
+    # Skip ParametrizationList synthetic paths (e.g. "...parametrizations.up_proj")
+    # so PEFT suffix-matching doesn't try to wrap quantized expert params in LoRA.
+    # Previous MoE models (Mixtral, DeepSeek, etc.) stored experts as nn.Linear
+    # modules, so PEFT's normal target_modules path worked fine. NemotronH uses
+    # 3D nn.Parameter tensors via our quantize_moe_experts parametrization, which
+    # exposes synthetic ".parametrizations.<name>" paths that PEFT's suffix match
+    # would otherwise treat as target_modules candidates.
+    _original_check = BaseTuner._check_target_module_exists
+
+    @staticmethod
+    def _patched_check_target_module_exists(config, key):
+        if ".parametrizations." in key:
+            return False
+        return _original_check(config, key)
+
+    BaseTuner._check_target_module_exists = _patched_check_target_module_exists
+
     patch_peft_target_parameters_matching._axolotl_patched = True
     LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering")
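
The wrapper added in this hunk is a plain "short-circuit, then delegate" pattern. The sketch below shows the same idea outside PEFT, assuming a toy suffix matcher (`suffix_check`) as a hypothetical stand-in for the original `_check_target_module_exists`; the module paths are made up for illustration.

```python
def make_patched_check(original_check):
    """Wrap a matcher so synthetic parametrization paths never match."""

    def patched(config, key):
        if ".parametrizations." in key:
            return False
        return original_check(config, key)

    return patched


def suffix_check(config, key):
    """Hypothetical stand-in for PEFT's target-module suffix matching."""
    return any(key.endswith(target) for target in config["target_modules"])


check = make_patched_check(suffix_check)
cfg = {"target_modules": ["up_proj"]}
print(check(cfg, "layers.0.mixer.experts.parametrizations.up_proj"))  # False
print(check(cfg, "layers.0.mlp.up_proj"))  # True
```

Without the wrapper, the first path would match on its `up_proj` suffix, which is the situation the comment in the hunk describes.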
@@ -62,6 +62,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "mistral4",
     "afmoe",
     "nemotron",
+    "nemotron_h",
 ]


16
src/axolotl/utils/chat_templates/templates/nemotron_h.jinja
Normal file
@@ -0,0 +1,16 @@
{%- if messages and messages[0].role == 'system' %}
    {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
    {%- set messages = messages[1:] %}
{%- endif %}
{%- for message in messages %}
    {%- if message.role == 'user' %}
        {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}
    {%- elif message.role == 'assistant' %}
        {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }}
    {%- else %}
        {{- raise_exception('Unexpected role: ' + message.role) }}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- endif %}
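
To see what the template above produces, here is a plain-Python sketch of the same formatting rules, written without Jinja so it runs anywhere. The function name `render_chat` is hypothetical; it mirrors the template's behaviour of accepting a system message only in first position and rejecting any other role.

```python
def render_chat(messages, add_generation_prompt=False):
    """Format messages in the <|im_start|>/<|im_end|> layout used by the template."""
    parts = []
    for index, message in enumerate(messages):
        role = message["role"]
        # A system message is only accepted as the very first message.
        allowed = role in ("user", "assistant") or (role == "system" and index == 0)
        if not allowed:
            raise ValueError("Unexpected role: " + role)
        parts.append(
            "<|im_start|>" + role + "\n" + message["content"] + "<|im_end|>\n"
        )
    if add_generation_prompt:
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


chat = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
]
print(render_chat(chat, add_generation_prompt=True))
```

A system message anywhere after the first position falls through to the error branch, just as it would hit `raise_exception` in the template's loop.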
@@ -62,6 +62,7 @@ class ChatTemplate(str, Enum):
     qwen3 = "qwen3"
     qwen3_5 = "qwen3_5"
     falcon_h1 = "falcon_h1"
+    nemotron_h = "nemotron_h"
     tokenizer_default = "tokenizer_default"
     exaone = "exaone"
     exaone4 = "exaone4"
@@ -1258,6 +1258,21 @@ class ModelCompatibilityValidationMixin:
             raise ValueError("gradient_checkpointing is not supported for MPT models")
         return self

+    @model_validator(mode="after")
+    def check_nemotron_h_gradient_checkpointing(self):
+        if (
+            self.base_model
+            and "nemotron-h" in self.base_model.lower()
+            and self.gradient_checkpointing
+            and not self.sample_packing
+        ):
+            raise ValueError(
+                "gradient_checkpointing for nemotron_h requires sample_packing: true. "
+                "The upstream model marks supports_gradient_checkpointing=False; "
+                "axolotl only enables it after applying the sample-packing patch."
+            )
+        return self
+
     @model_validator(mode="after")
     def check_gradient_checkpointing_w_offload(self):
         if self.gradient_checkpointing == "offload":
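
The condition in `check_nemotron_h_gradient_checkpointing` can be exercised as a plain function. This is a sketch only: the function name `needs_packing_error` is hypothetical, and the first model name in the usage lines is invented so that it contains the `nemotron-h` substring.

```python
def needs_packing_error(base_model, gradient_checkpointing, sample_packing):
    """Return True when the validator's condition for raising is met."""
    return bool(
        base_model
        and "nemotron-h" in base_model.lower()
        and gradient_checkpointing
        and not sample_packing
    )


# Invented model name containing the substring the validator looks for.
print(needs_packing_error("example-org/Nemotron-H-demo", True, False))
# base_model string taken from the example configs in this commit.
print(needs_packing_error("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", True, False))
```

One observation from running the sketch: a substring test on `nemotron-h` does not match the `base_model` strings used in the example configs, which contain `nemotron-3`, so the check as written would appear not to fire for those names.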