super nemo support (#3508)
* nemo support
* config
* rename, config
* nemotron packing
* config fix
* readme + configs
* gc compat bug
* config changes for qwen and pad token nemo
* patch nemotron_h weight renaming so it doesn't get reversed to embedding (singular noun) on checkpoint save
* lint
* revert qwen3.5 config changes, not needed in this pr
* lint
* Update examples/nemotron-h/120b-a12b-qlora.yaml (Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>)
* Update examples/nemotron-h/nano-30b-a3b-qlora.yaml (Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>)
* readme + validation
* lazy load comment
* Update examples/nemotron-h/120b-a12b-qlora.yaml (Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>)
* val fix
* add nemo to multi packing

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
examples/nemotron-h/120b-a12b-qlora.yaml (new file, 74 lines)
@@ -0,0 +1,74 @@
base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16

# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

chat_template: tokenizer_default
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared

sequence_len: 4096
sample_packing: true

use_cut_cross_entropy: true

load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
  # Attention projection layers (present in ~12 attention layers out of 88)
  - q_proj
  - k_proj
  - v_proj
  - o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
#   - up_proj
#   - down_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

special_tokens:
examples/nemotron-h/README.md (new file, 48 lines)
@@ -0,0 +1,48 @@
# Nemotron-H (nvidia/NVIDIA-Nemotron-3-*)

Hybrid Mamba2 / Attention / MoE architecture (`model_type: nemotron_h`).

| Model | Total params | Active params | Layers |
|---|---|---|---|
| NVIDIA-Nemotron-3-Super-120B-A12B-BF16 | 120B | ~12B | 88 |
| NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 | 30B | ~3B | — |

## Requirements

```bash
pip install mamba-ssm causal-conv1d  # fast Mamba2 CUDA kernels
```
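A quick sanity check, given as a hedged sketch that is not part of this PR, that the packages installed above can actually be imported; the sample-packing path later in this diff also hard-requires `causal_conv1d_fn`:

```python
# Hedged sketch (not from this PR): confirm the optional fast-kernel packages
# are importable before launching a run. Module names follow the pip packages.
import importlib

for pkg in ("mamba_ssm", "causal_conv1d"):
    try:
        importlib.import_module(pkg)
        print(f"{pkg}: available")
    except ImportError:
        print(f"{pkg}: missing (pip install {pkg.replace('_', '-')})")
```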
## Architecture notes

- Three block types, one per layer: **Mamba2** (selective SSM), **Attention** (sparse), **MoE** (mixture-of-experts).
- Only ~12 out of 88 blocks are attention layers (120B variant); the sketch below shows one quick way to check the mix on a loaded model.
- MLP activation is `relu2` via `mlp_hidden_act` (not the usual `hidden_act`).
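A minimal, hedged sketch for verifying that layer mix: the `model.backbone.layers` path is an assumption, while `block_type` is the attribute the sample-packing patch in this PR switches on.

```python
# Hedged sketch (not from this PR): tally block types on an already-loaded
# NemotronH model. Assumes decoder blocks expose `block_type` ("mamba",
# "attention", or the MoE type) and live under model.backbone.layers.
from collections import Counter

def count_block_types(model) -> Counter:
    return Counter(
        getattr(layer, "block_type", "unknown") for layer in model.backbone.layers
    )

# For the 120B variant this should report roughly 12 attention blocks out of 88.
```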
## LoRA kernel patches

All three LoRA Triton kernel patches must be disabled:

```yaml
lora_qkv_kernel: false  # attention lives in NemotronHBlock.mixer, not layer.self_attn
lora_o_kernel: false    # same reason
lora_mlp_kernel: false  # relu2 (mlp_hidden_act) is not supported by lora_mlp_kernel
```

## MoE expert weights

NemotronH experts store `up_proj` and `down_proj` as 3D `nn.Parameter` tensors
(shape `[num_experts, out_dim, in_dim]`), **not** `nn.Linear` modules — there is no
`gate_proj`. To fine-tune them alongside attention, use `lora_target_parameters`
instead of `lora_target_modules`:

```yaml
lora_target_parameters:
  - up_proj
  - down_proj
```

## Limitations

- **MoE Triton kernels**: `lora_mlp_kernel` is not supported for NemotronH's MoE expert layers. The expert weights are 3D `nn.Parameter` tensors (not `nn.Linear`), which the Triton kernel does not support. Keep `lora_mlp_kernel: false`.
- **Gradient checkpointing**: Only supported when `sample_packing: true`. Without sample packing the upstream model marks `supports_gradient_checkpointing = False`.
examples/nemotron-h/nano-30b-a3b-qlora.yaml (new file, 74 lines)
@@ -0,0 +1,74 @@
# See examples/nemotron-h/README.md for architecture notes and requirements.
base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

chat_template: tokenizer_default
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared

sequence_len: 4096
sample_packing: true

use_cut_cross_entropy: true

load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
#   - up_proj
#   - down_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

special_tokens:
@@ -23,4 +23,5 @@ MOE_ARCH_BLOCK = {
     "glm4_moe": "Glm4MoeDecoderLayer",
     "glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
     "glm_moe_dsa": "GlmMoeDsaDecoderLayer",
+    "nemotron_h": "NemotronHMoE",
 }
@@ -590,9 +590,11 @@ class ModelLoader:
                 "bnb_4bit_quant_type": "nf4",
                 "bnb_4bit_quant_storage": torch.bfloat16,
             }
-            if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
-                self.cfg.deepspeed or self.is_fsdp_enabled
-            ):
+            if self.cfg.model_config_type in [
+                "jamba",
+                "qwen2_moe",
+                "nemotron_h",
+            ] and not (self.cfg.deepspeed or self.is_fsdp_enabled):
                 # for some reason, this causes the loss to be off by an order of magnitude
                 # but deepspeed needs this still in bfloat16
                 bnb_config["bnb_4bit_quant_storage"] = torch.float32
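For reference, a minimal sketch (not taken from this PR) of the effective 4-bit settings this change produces for `nemotron_h` QLoRA without DeepSpeed/FSDP, expressed with the standard `BitsAndBytesConfig`:

```python
# Hedged sketch mirroring the bnb_config dict above for nemotron_h QLoRA
# without deepspeed/FSDP; other architectures keep bfloat16 quant storage.
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_quant_storage=torch.float32,
)
```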
@@ -142,6 +142,12 @@ class PatchManager:
 
     def apply_post_model_build_patches(self, model: PreTrainedModel):
         """Apply patches right after model build, before post-load setup."""
+        if self.cfg.model_config_type == "nemotron_h":
+            # Must run after model build because NemotronHForCausalLM.__init__
+            # calls register_nemotron_h_conversion_mapping() with overwrite=True,
+            # which would clobber any earlier fix.
+            self._fix_nemotron_h_conversion_mapping()
+
         self._finalize_moe_expert_quantization(model)
 
     def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -291,6 +297,66 @@ class PatchManager:
 
             patch_kimi_model()
 
+        if self.cfg.model_config_type == "nemotron_h":
+            if self.cfg.sample_packing:
+                from transformers.models.nemotron_h.modeling_nemotron_h import (
+                    NemotronHPreTrainedModel,
+                )
+
+                from axolotl.monkeypatch.models.nemotron_h.modeling import (
+                    patch_nemotron_h_modeling_packing,
+                )
+
+                patch_nemotron_h_modeling_packing()
+                # supports_gradient_checkpointing is only enabled after
+                # patch_nemotron_h_modeling_packing() installs the GC-compatible
+                # NemotronHBlock.forward. Without the patch, upstream marks this
+                # False because the original block forward is not GC-safe.
+                NemotronHPreTrainedModel.supports_gradient_checkpointing = True
+
+    @staticmethod
+    def _fix_nemotron_h_conversion_mapping():
+        """Remove the spurious embedding→embeddings WeightRenaming from the
+        nemotron_h checkpoint conversion mapping.
+
+        The nvidia Hub model registers:
+            WeightRenaming("embedding.weight", "embeddings.weight")
+        to handle a legacy checkpoint variant. Its reverse (applied on save)
+        converts ``embeddings`` back to ``embedding``, which silently renames
+        ``backbone.embeddings.weight`` → ``backbone.embedding.weight`` when
+        merging LoRA adapters back into the base model.
+        """
+        try:
+            from transformers.conversion_mapping import (
+                WeightRenaming,
+                get_checkpoint_conversion_mapping,
+                register_checkpoint_conversion_mapping,
+            )
+        except ImportError:
+            return
+
+        mapping = get_checkpoint_conversion_mapping("nemotron_h")
+        if mapping is None:
+            return
+
+        filtered = [
+            entry
+            for entry in mapping
+            if not (
+                isinstance(entry, WeightRenaming)
+                and entry.source_patterns == ["embedding.weight"]
+                and entry.target_patterns == ["embeddings.weight"]
+            )
+        ]
+        if len(filtered) != len(mapping):
+            register_checkpoint_conversion_mapping(
+                "nemotron_h", filtered, overwrite=True
+            )
+            LOG.info(
+                "Removed embedding→embeddings WeightRenaming from nemotron_h "
+                "checkpoint conversion mapping"
+            )
+
     def _apply_fp8_patches(self):
         """Apply patches for FP8 support."""
         if self.cfg.fp8:
@@ -234,4 +234,6 @@ def get_linear_embedding_layers(model_type: str) -> list[str]:
         return ["embed_in", "embed_out"]
     if model_type == "falcon":
         return ["word_embeddings", "lm_head"]
+    if model_type == "nemotron_h":
+        return ["embeddings", "lm_head"]
     return ["embed_tokens", "lm_head"]
@@ -394,15 +394,15 @@ def apply_lora_kernel_patches(
         activation = text_config.hidden_act
     elif hasattr(text_config, "hidden_activation"):
         activation = text_config.hidden_activation
+    elif hasattr(text_config, "mlp_hidden_act"):
+        # Hybrid models (e.g. nemotron_h) use mlp_hidden_act instead of hidden_act
+        activation = text_config.mlp_hidden_act
 
     # map activation to supported activation
-    if "gelu" in activation:
+    if activation and "gelu" in activation:
         # gemma3 uses gelu_pytorch_tanh
         activation = "gelu"
 
-    if activation not in SUPPORTED_ACTIVATIONS:
-        raise NotImplementedError(f"Activation {activation} is not supported")
-
     layers = get_layers(model)
 
     # Patch each layer
@@ -444,6 +444,15 @@
             )
         for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
             if cfg.lora_mlp_kernel:
+                # Check is inside lora_mlp_kernel guard so models with an
+                # unsupported activation (e.g. nemotron_h uses relu2) can set
+                # lora_mlp_kernel: false without hitting an error here.
+                if activation not in SUPPORTED_ACTIVATIONS:
+                    raise NotImplementedError(
+                        f"Activation {activation!r} is not supported by lora_mlp_kernel. "
+                        f"Set `lora_mlp_kernel: false` in your config or use a model with "
+                        f"a supported activation ({SUPPORTED_ACTIVATIONS})."
+                    )
                 # MLP patching
                 can_patch_mlp = all(
                     hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
src/axolotl/monkeypatch/models/nemotron_h/modeling.py (new file, 315 lines)
@@ -0,0 +1,315 @@
"""Sample-packing patch for NemotronH (Mamba2/Attention/MoE hybrid).

Threads seq_idx (derived from position_ids) into the Mamba2 SSM kernels so
packed-sequence boundaries reset SSM state. Upstream hard-codes seq_idx=None,
which leaks hidden state across boundaries. Attention and MoE blocks need no
changes — only the Mamba2 mixer is patched.
"""

import importlib

import torch

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def get_seq_idx(position_ids: torch.Tensor) -> torch.Tensor:
    """Convert position_ids [B, T] → seq_idx [B, T] int32 for mamba-ssm kernels.

    Example: position_ids [[0,1,2,3,0,1,2]] → seq_idx [[0,0,0,0,1,1,1]]
    """
    return (torch.cumsum((position_ids == 0).int(), dim=-1) - 1).to(torch.int32)


def patch_nemotron_h_modeling_packing():
    """Patch NemotronH for sample packing: seq_idx threading into Mamba2 SSM kernels.

    _get_unpad_data is handled by SUPPORTED_MULTIPACK_MODEL_TYPES / patch_for_multipack().
    This function only applies the seq_idx patches that are unique to nemotron_h.
    """
    try:
        mod = importlib.import_module(
            "transformers.models.nemotron_h.modeling_nemotron_h"
        )
    except ImportError:
        LOG.warning("nemotron_h not found in transformers, skipping packing patches")
        return

    NemotronHMamba2Mixer = mod.NemotronHMamba2Mixer
    NemotronHBlock = mod.NemotronHBlock

    # Patch 1: cuda_kernels_forward — add seq_idx param and thread it to
    # causal_conv1d_fn and mamba_chunk_scan_combined. Fused fast path is
    # bypassed when seq_idx is set (requires causal_conv1d_cuda C extension).
    def patched_cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params=None,
        attention_mask=None,
        seq_idx=None,
    ):
        batch_size, seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size
        d_to_remove = (
            2 * self.intermediate_size
            + 2 * self.n_groups * self.ssm_state_size
            + self.num_heads
        )

        if cache_params is not None and cache_params.has_previous_state:
            in_projected_states = self.in_proj(hidden_states.squeeze(1))
            d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
            split_projection_dim = [
                d_mlp,
                d_mlp,
                self.intermediate_size,
                self.conv_dim,
                self.num_heads,
            ]
            _, _, gate, hidden_states_B_C, dt = torch.split(
                in_projected_states, split_projection_dim, dim=-1
            )
            hidden_states_B_C = mod.causal_conv1d_update(
                hidden_states_B_C,
                cache_params.conv_states[self.layer_idx],
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )
            hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [
                    self.intermediate_size,
                    groups_time_state_size,
                    groups_time_state_size,
                ],
                dim=-1,
            )
            A = -torch.exp(self.A_log.float())
            A = (
                A[:, None, ...][:, :, None]
                .expand(-1, self.head_dim, self.ssm_state_size)
                .to(dtype=torch.float32)
            )
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
            C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
            hidden_states_reshaped = hidden_states.view(
                batch_size, self.num_heads, self.head_dim
            )
            hidden_states = mod.selective_state_update(
                cache_params.ssm_states[self.layer_idx],
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            hidden_states = hidden_states.view(
                batch_size, self.num_heads * self.head_dim
            )
            hidden_states = self.norm(hidden_states, gate)
            out = self.out_proj(hidden_states)[:, None, ...]

        else:
            if attention_mask is not None and not torch.all(attention_mask == 1):
                dtype = hidden_states.dtype
                hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

            projected_states = self.in_proj(hidden_states)
            A = -torch.exp(self.A_log.float())
            dt_limit_kwargs = (
                {}
                if self.time_step_limit is None
                else {"dt_limit": self.time_step_limit}
            )
            if attention_mask is not None:
                input_not_masked = torch.all(attention_mask == 1)
            else:
                input_not_masked = True

            if (
                self.use_mem_eff_path
                and self.training
                and cache_params is None
                and input_not_masked
                and seq_idx is None
            ):
                out, ssm_state = mod.mamba_split_conv1d_scan_combined(
                    projected_states,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.chunk_size,
                    seq_idx=seq_idx,
                    activation=self.activation,
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.variance_epsilon,
                    outproj_weight=self.out_proj.weight,
                    outproj_bias=self.out_proj.bias,
                    headdim=self.head_dim,
                    ngroups=self.n_groups,
                    norm_before_gate=False,
                    return_final_states=True,
                    **dt_limit_kwargs,
                )
            else:
                gate, hidden_states_B_C, time_step = torch.split(
                    projected_states,
                    [self.intermediate_size, self.conv_dim, self.num_heads],
                    dim=-1,
                )

                if cache_params is not None:
                    hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2)
                    conv_state = torch.nn.functional.pad(
                        hidden_states_B_C_t,
                        (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0),
                    )
                    cache_params.conv_states[self.layer_idx].copy_(conv_state)

                if mod.causal_conv1d_fn is None or self.activation not in [
                    "silu",
                    "swish",
                ]:
                    hidden_states_B_C = self.act(
                        self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[
                            :, :seq_len
                        ]
                    )
                else:
                    hidden_states_B_C = mod.causal_conv1d_fn(
                        x=hidden_states_B_C.transpose(1, 2),
                        weight=self.conv1d.weight.squeeze(1),
                        bias=self.conv1d.bias,
                        activation=self.activation,
                        seq_idx=seq_idx,
                    ).transpose(1, 2)[:, :seq_len]

                hidden_states, B, C = torch.split(
                    hidden_states_B_C,
                    [
                        self.intermediate_size,
                        groups_time_state_size,
                        groups_time_state_size,
                    ],
                    dim=-1,
                )

                if attention_mask is not None and not torch.all(attention_mask == 1):
                    dtype = hidden_states.dtype
                    hidden_states = (hidden_states * attention_mask[:, :, None]).to(
                        dtype
                    )

                scan_output, ssm_state = mod.mamba_chunk_scan_combined(
                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                    time_step,
                    A,
                    B.view(batch_size, seq_len, self.n_groups, -1),
                    C.view(batch_size, seq_len, self.n_groups, -1),
                    chunk_size=self.chunk_size,
                    D=self.D,
                    z=None,
                    seq_idx=seq_idx,
                    return_final_states=True,
                    dt_bias=self.dt_bias,
                    dt_softplus=True,
                    **dt_limit_kwargs,
                )
                if ssm_state is not None and cache_params is not None:
                    cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
                scan_output = scan_output.view(batch_size, seq_len, -1)
                scan_output = self.norm(scan_output, gate)
                out = self.out_proj(scan_output)

        return out

    NemotronHMamba2Mixer.cuda_kernels_forward = patched_cuda_kernels_forward

    # Patch 2: Mamba2Mixer.forward — add seq_idx, guard on causal_conv1d_fn,
    # restore the cuda stream context (matches upstream; avoids NaN on multi-GPU).
    def patched_mixer_forward(
        self,
        hidden_states,
        cache_params=None,
        attention_mask=None,
        seq_idx=None,
    ):
        if seq_idx is not None and mod.causal_conv1d_fn is None:
            raise RuntimeError(
                "Nemotron-H sample packing requires causal_conv1d_fn. "
                "Install with: pip install mamba-ssm causal-conv1d"
            )
        if (
            mod.is_fast_path_available
            and "cuda" in self.in_proj.weight.device.type
            and not mod.is_torchdynamo_compiling()
        ):
            with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
                return self.cuda_kernels_forward(
                    hidden_states, cache_params, attention_mask, seq_idx=seq_idx
                )
        return self.torch_forward(hidden_states, cache_params, attention_mask)

    NemotronHMamba2Mixer.forward = patched_mixer_forward

    # Patch 3: NemotronHBlock.forward — compute seq_idx from position_ids and
    # pass it to the Mamba2 mixer. Skipped during decode (has_previous_state).
    def patched_block_forward(
        self,
        hidden_states,
        past_key_values=None,
        cache_position=None,
        attention_mask=None,
        position_ids=None,
        use_cache=False,
        **kwargs,
    ):
        residual = hidden_states
        hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))

        if self.block_type == "mamba":
            is_decoding = (
                past_key_values is not None and past_key_values.has_previous_state
            )
            seq_idx = (
                get_seq_idx(position_ids)
                if position_ids is not None and not is_decoding
                else None
            )
            hidden_states = self.mixer(
                hidden_states,
                cache_params=past_key_values,
                attention_mask=attention_mask,
                seq_idx=seq_idx,
            )
        elif self.block_type == "attention":
            hidden_states, _ = self.mixer(
                hidden_states=hidden_states,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                position_ids=position_ids,
                user_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )
        else:
            hidden_states = self.mixer(hidden_states)

        hidden_states = residual + hidden_states
        return hidden_states

    NemotronHBlock.forward = patched_block_forward

    LOG.info("Applied NemotronH sample packing patch (seq_idx threading into Mamba2)")
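A quick, hedged self-check of the `get_seq_idx` helper above (illustrative only, not part of the patch): packed position_ids that restart at 0 should yield an int32 seq_idx that increments at each boundary.

```python
# Hedged sketch: verify get_seq_idx on a single packed row holding two
# sequences of lengths 4 and 3.
import torch

from axolotl.monkeypatch.models.nemotron_h.modeling import get_seq_idx

position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2]])
seq_idx = get_seq_idx(position_ids)
assert seq_idx.dtype == torch.int32
assert seq_idx.tolist() == [[0, 0, 0, 0, 1, 1, 1]]
```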
@@ -154,6 +154,8 @@ def patch_peft_target_parameters_matching():
     1. Expands short suffixes to full module paths for parametrized modules.
     2. Iterates params in definition order (not alphabetical order) so saved
        adapters are compatible with standard PEFT, vLLM, etc.
+    3. Skips ParametrizationList synthetic paths to prevent PEFT from mistakenly
+       targeting quantized expert params via name-suffix matching.
     """
     if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
         return
@@ -293,5 +295,23 @@ def patch_peft_target_parameters_matching():
                 self.targeted_parameter_names.append(key)
 
     BaseTuner._inject_parameters = _patched_inject_parameters
+
+    # Skip ParametrizationList synthetic paths (e.g. "...parametrizations.up_proj")
+    # so PEFT suffix-matching doesn't try to wrap quantized expert params in LoRA.
+    # Previous MoE models (Mixtral, DeepSeek, etc.) stored experts as nn.Linear
+    # modules, so PEFT's normal target_modules path worked fine. NemotronH uses
+    # 3D nn.Parameter tensors via our quantize_moe_experts parametrization, which
+    # exposes synthetic ".parametrizations.<name>" paths that PEFT's suffix match
+    # would otherwise treat as target_modules candidates.
+    _original_check = BaseTuner._check_target_module_exists
+
+    @staticmethod
+    def _patched_check_target_module_exists(config, key):
+        if ".parametrizations." in key:
+            return False
+        return _original_check(config, key)
+
+    BaseTuner._check_target_module_exists = _patched_check_target_module_exists
 
     patch_peft_target_parameters_matching._axolotl_patched = True
     LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering")
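To make the skip concrete, a hedged illustration (the module paths below are hypothetical, not taken from a real checkpoint) of which candidate keys the patched check rejects:

```python
# Hedged illustration: with quantize_moe_experts, expert tensors surface under
# synthetic ".parametrizations.<name>" paths; the patched check above rejects
# those so suffix matching only targets real modules. Paths are made up.
candidate_keys = [
    "backbone.layers.11.mixer.q_proj",                           # real module: kept
    "backbone.layers.5.mixer.experts.parametrizations.up_proj",  # synthetic: skipped
]
skipped = [key for key in candidate_keys if ".parametrizations." in key]
print(skipped)  # -> ['backbone.layers.5.mixer.experts.parametrizations.up_proj']
```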
@@ -62,6 +62,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "mistral4",
     "afmoe",
     "nemotron",
+    "nemotron_h",
 ]
src/axolotl/utils/chat_templates/templates/nemotron_h.jinja (new file, 16 lines)
@@ -0,0 +1,16 @@
{%- if messages and messages[0].role == 'system' %}
    {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
    {%- set messages = messages[1:] %}
{%- endif %}
{%- for message in messages %}
    {%- if message.role == 'user' %}
        {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}
    {%- elif message.role == 'assistant' %}
        {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }}
    {%- else %}
        {{- raise_exception('Unexpected role: ' + message.role) }}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- endif %}
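A hedged way to eyeball the rendered framing (a standalone jinja2 render for illustration only; in axolotl the template goes through the tokenizer's chat-template machinery, and the file path here is an assumption):

```python
# Hedged sketch: render the template above with plain jinja2 on a tiny
# conversation to see the <|im_start|>/<|im_end|> framing it produces.
from jinja2 import Template

with open("src/axolotl/utils/chat_templates/templates/nemotron_h.jinja") as f:
    template = Template(f.read())

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
print(template.render(messages=messages, add_generation_prompt=True))
```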
@@ -62,6 +62,7 @@ class ChatTemplate(str, Enum):
     qwen3 = "qwen3"
     qwen3_5 = "qwen3_5"
     falcon_h1 = "falcon_h1"
+    nemotron_h = "nemotron_h"
     tokenizer_default = "tokenizer_default"
     exaone = "exaone"
     exaone4 = "exaone4"
@@ -1258,6 +1258,21 @@ class ModelCompatibilityValidationMixin:
             raise ValueError("gradient_checkpointing is not supported for MPT models")
         return self
 
+    @model_validator(mode="after")
+    def check_nemotron_h_gradient_checkpointing(self):
+        if (
+            self.base_model
+            and "nemotron-h" in self.base_model.lower()
+            and self.gradient_checkpointing
+            and not self.sample_packing
+        ):
+            raise ValueError(
+                "gradient_checkpointing for nemotron_h requires sample_packing: true. "
+                "The upstream model marks supports_gradient_checkpointing=False; "
+                "axolotl only enables it after applying the sample-packing patch."
+            )
+        return self
+
     @model_validator(mode="after")
     def check_gradient_checkpointing_w_offload(self):
         if self.gradient_checkpointing == "offload":