super nemo support (#3508)

* nemo support

* config

* rename, config

* nemotron packing

* config fix

* readme + configs

* gc compat bug

* config changes for qwen and pad token nemo

* patch nemotron_h weight renaming so it doesn't get reversed to embedding (singular noun) on checkpoint save

* lint

* revert qwen3.5 config changes, not needed in this pr

* lint

* Update examples/nemotron-h/120b-a12b-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/nemotron-h/nano-30b-a3b-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* readme + validation

* lazy load comment

* Update examples/nemotron-h/120b-a12b-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* val fix

* add nemo to multi packing

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
VED
2026-03-31 03:42:50 +05:30
committed by GitHub
parent 00dee05fc6
commit bb622b83de
15 changed files with 651 additions and 7 deletions

View File

@@ -0,0 +1,74 @@
base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16
# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
chat_template: tokenizer_default
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 4096
sample_packing: true
use_cut_cross_entropy: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
# Attention projection layers (present in ~12 attention layers out of 88)
- q_proj
- k_proj
- v_proj
- o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,48 @@
# Nemotron-H (nvidia/NVIDIA-Nemotron-3-*)
Hybrid Mamba2 / Attention / MoE architecture (`model_type: nemotron_h`).
| Model | Total params | Active params | Layers |
|---|---|---|---|
| NVIDIA-Nemotron-3-Super-120B-A12B-BF16 | 120B | ~12B | 88 |
| NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 | 30B | ~3B | — |
## Requirements
```bash
pip install mamba-ssm causal-conv1d # fast Mamba2 CUDA kernels
```
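A quick way to confirm the fast kernels are importable (a minimal sketch; the module paths assume the current `mamba-ssm` and `causal-conv1d` package layouts):
```python
# Both imports should succeed once the packages above are installed.
# If causal_conv1d_fn is missing, the sample-packing patch raises at runtime.
from causal_conv1d import causal_conv1d_fn  # noqa: F401
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined  # noqa: F401

print("fast Mamba2 / causal-conv1d kernels available")
```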
## Architecture notes
- Three block types per layer: **Mamba2** (selective SSM), **Attention** (sparse), **MoE** (mixture-of-experts).
- Only ~12 out of 88 blocks are attention layers (120B variant).
- MLP activation is `relu2` via `mlp_hidden_act` (not the usual `hidden_act`); see the config check below.
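A minimal sketch to confirm those fields from the Hub config (nothing beyond `model_type` and `mlp_hidden_act` is assumed):
```python
from transformers import AutoConfig

# Read the fields called out above straight from the published config.
cfg = AutoConfig.from_pretrained("nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16")
print(cfg.model_type)      # expected: nemotron_h
print(cfg.mlp_hidden_act)  # expected: relu2 (note: not the usual hidden_act)
```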
## LoRA kernel patches
All three LoRA Triton kernel patches must be disabled:
```yaml
lora_qkv_kernel: false # attention lives in NemotronHBlock.mixer, not layer.self_attn
lora_o_kernel: false # same reason
lora_mlp_kernel: false # relu2 (mlp_hidden_act) is not supported by lora_mlp_kernel
```
## MoE expert weights
NemotronH experts store `up_proj` and `down_proj` as 3D `nn.Parameter` tensors
(shape `[num_experts, out_dim, in_dim]`), **not** `nn.Linear` modules — there is no
`gate_proj`. To fine-tune them alongside attention, use `lora_target_parameters`
instead of `lora_target_modules`:
```yaml
lora_target_parameters:
- up_proj
- down_proj
```
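For illustration, a hedged sketch of how to verify that the expert weights really are 3D parameters rather than `nn.Linear` modules (the name filter is an assumption based on the naming above; loading the full checkpoint is for inspection only):
```python
import torch
from transformers import AutoModelForCausalLM

# Expert up_proj/down_proj appear as 3D nn.Parameter tensors, which is why they
# must be targeted via lora_target_parameters rather than lora_target_modules.
model = AutoModelForCausalLM.from_pretrained(
    "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", torch_dtype=torch.bfloat16
)
for name, param in model.named_parameters():
    if name.endswith(("up_proj", "down_proj")) and param.dim() == 3:
        print(name, tuple(param.shape))  # (num_experts, out_dim, in_dim)
        break
```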
## Limitations
- **MoE Triton kernels**: `lora_mlp_kernel` is not supported for NemotronH's MoE expert layers. The expert weights are 3D `nn.Parameter` tensors (not `nn.Linear`), which the Triton kernel does not support. Keep `lora_mlp_kernel: false`.
- **Gradient checkpointing**: Only supported when `sample_packing: true`. Upstream ships the model with `supports_gradient_checkpointing = False`; axolotl flips it to `True` only after applying the sample-packing patch (see the sketch below).
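The flag in question can be checked directly (illustrative only; this is the attribute the axolotl sample-packing patch flips to `True`):
```python
from transformers.models.nemotron_h.modeling_nemotron_h import NemotronHPreTrainedModel

# Upstream ships this as False because the original block forward is not
# gradient-checkpointing safe; axolotl only enables it alongside sample packing.
print(NemotronHPreTrainedModel.supports_gradient_checkpointing)
```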

View File

@@ -0,0 +1,74 @@
# See examples/nemotron-h/README.md for architecture notes and requirements.
base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
chat_template: tokenizer_default
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 4096
sample_packing: true
use_cut_cross_entropy: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -23,4 +23,5 @@ MOE_ARCH_BLOCK = {
"glm4_moe": "Glm4MoeDecoderLayer",
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
"nemotron_h": "NemotronHMoE",
}

View File

@@ -590,9 +590,11 @@ class ModelLoader:
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
self.cfg.deepspeed or self.is_fsdp_enabled
):
if self.cfg.model_config_type in [
"jamba",
"qwen2_moe",
"nemotron_h",
] and not (self.cfg.deepspeed or self.is_fsdp_enabled):
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32

View File

@@ -142,6 +142,12 @@ class PatchManager:
def apply_post_model_build_patches(self, model: PreTrainedModel):
"""Apply patches right after model build, before post-load setup."""
if self.cfg.model_config_type == "nemotron_h":
# Must run after model build because NemotronHForCausalLM.__init__
# calls register_nemotron_h_conversion_mapping() with overwrite=True,
# which would clobber any earlier fix.
self._fix_nemotron_h_conversion_mapping()
self._finalize_moe_expert_quantization(model)
def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -291,6 +297,66 @@ class PatchManager:
patch_kimi_model()
if self.cfg.model_config_type == "nemotron_h":
if self.cfg.sample_packing:
from transformers.models.nemotron_h.modeling_nemotron_h import (
NemotronHPreTrainedModel,
)
from axolotl.monkeypatch.models.nemotron_h.modeling import (
patch_nemotron_h_modeling_packing,
)
patch_nemotron_h_modeling_packing()
# supports_gradient_checkpointing is only enabled after
# patch_nemotron_h_modeling_packing() installs the GC-compatible
# NemotronHBlock.forward. Without the patch, upstream marks this
# False because the original block forward is not GC-safe.
NemotronHPreTrainedModel.supports_gradient_checkpointing = True
@staticmethod
def _fix_nemotron_h_conversion_mapping():
"""Remove the spurious embedding→embeddings WeightRenaming from the
nemotron_h checkpoint conversion mapping.
The nvidia Hub model registers:
WeightRenaming("embedding.weight", "embeddings.weight")
to handle a legacy checkpoint variant. Its reverse (applied on save)
converts ``embeddings`` back to ``embedding``, which silently renames
``backbone.embeddings.weight`` → ``backbone.embedding.weight`` when
merging LoRA adapters back into the base model.
"""
try:
from transformers.conversion_mapping import (
WeightRenaming,
get_checkpoint_conversion_mapping,
register_checkpoint_conversion_mapping,
)
except ImportError:
return
mapping = get_checkpoint_conversion_mapping("nemotron_h")
if mapping is None:
return
filtered = [
entry
for entry in mapping
if not (
isinstance(entry, WeightRenaming)
and entry.source_patterns == ["embedding.weight"]
and entry.target_patterns == ["embeddings.weight"]
)
]
if len(filtered) != len(mapping):
register_checkpoint_conversion_mapping(
"nemotron_h", filtered, overwrite=True
)
LOG.info(
"Removed embedding→embeddings WeightRenaming from nemotron_h "
"checkpoint conversion mapping"
)
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:

View File

@@ -234,4 +234,6 @@ def get_linear_embedding_layers(model_type: str) -> list[str]:
return ["embed_in", "embed_out"]
if model_type == "falcon":
return ["word_embeddings", "lm_head"]
if model_type == "nemotron_h":
return ["embeddings", "lm_head"]
return ["embed_tokens", "lm_head"]

View File

@@ -394,15 +394,15 @@ def apply_lora_kernel_patches(
activation = text_config.hidden_act
elif hasattr(text_config, "hidden_activation"):
activation = text_config.hidden_activation
elif hasattr(text_config, "mlp_hidden_act"):
# Hybrid models (e.g. nemotron_h) use mlp_hidden_act instead of hidden_act
activation = text_config.mlp_hidden_act
# map activation to supported activation
if "gelu" in activation:
if activation and "gelu" in activation:
# gemma3 uses gelu_pytorch_tanh
activation = "gelu"
if activation not in SUPPORTED_ACTIVATIONS:
raise NotImplementedError(f"Activation {activation} is not supported")
layers = get_layers(model)
# Patch each layer
@@ -444,6 +444,15 @@ def apply_lora_kernel_patches(
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# Check is inside lora_mlp_kernel guard so models with an
# unsupported activation (e.g. nemotron_h uses relu2) can set
# lora_mlp_kernel: false without hitting an error here.
if activation not in SUPPORTED_ACTIVATIONS:
raise NotImplementedError(
f"Activation {activation!r} is not supported by lora_mlp_kernel. "
f"Set `lora_mlp_kernel: false` in your config or use a model with "
f"a supported activation ({SUPPORTED_ACTIVATIONS})."
)
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)

View File

@@ -0,0 +1,315 @@
"""Sample-packing patch for NemotronH (Mamba2/Attention/MoE hybrid).
Threads seq_idx (derived from position_ids) into the Mamba2 SSM kernels so
packed-sequence boundaries reset SSM state. Upstream hard-codes seq_idx=None,
which leaks hidden state across boundaries. Attention and MoE blocks need no
changes — only the Mamba2 mixer is patched.
"""
import importlib
import torch
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def get_seq_idx(position_ids: torch.Tensor) -> torch.Tensor:
"""Convert position_ids [B, T] → seq_idx [B, T] int32 for mamba-ssm kernels.
Example: position_ids [[0,1,2,3,0,1,2]] → seq_idx [[0,0,0,0,1,1,1]]
"""
return (torch.cumsum((position_ids == 0).int(), dim=-1) - 1).to(torch.int32)
def patch_nemotron_h_modeling_packing():
"""Patch NemotronH for sample packing: seq_idx threading into Mamba2 SSM kernels.
_get_unpad_data is handled by SUPPORTED_MULTIPACK_MODEL_TYPES / patch_for_multipack().
This function only applies the seq_idx patches that are unique to nemotron_h.
"""
try:
mod = importlib.import_module(
"transformers.models.nemotron_h.modeling_nemotron_h"
)
except ImportError:
LOG.warning("nemotron_h not found in transformers, skipping packing patches")
return
NemotronHMamba2Mixer = mod.NemotronHMamba2Mixer
NemotronHBlock = mod.NemotronHBlock
# Patch 1: cuda_kernels_forward — add seq_idx param and thread it to
# causal_conv1d_fn and mamba_chunk_scan_combined. Fused fast path is
# bypassed when seq_idx is set (requires causal_conv1d_cuda C extension).
def patched_cuda_kernels_forward(
self,
hidden_states: torch.Tensor,
cache_params=None,
attention_mask=None,
seq_idx=None,
):
batch_size, seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size
d_to_remove = (
2 * self.intermediate_size
+ 2 * self.n_groups * self.ssm_state_size
+ self.num_heads
)
if cache_params is not None and cache_params.has_previous_state:
in_projected_states = self.in_proj(hidden_states.squeeze(1))
d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
split_projection_dim = [
d_mlp,
d_mlp,
self.intermediate_size,
self.conv_dim,
self.num_heads,
]
_, _, gate, hidden_states_B_C, dt = torch.split(
in_projected_states, split_projection_dim, dim=-1
)
hidden_states_B_C = mod.causal_conv1d_update(
hidden_states_B_C,
cache_params.conv_states[self.layer_idx],
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
hidden_states, B, C = torch.split(
hidden_states_B_C,
[
self.intermediate_size,
groups_time_state_size,
groups_time_state_size,
],
dim=-1,
)
A = -torch.exp(self.A_log.float())
A = (
A[:, None, ...][:, :, None]
.expand(-1, self.head_dim, self.ssm_state_size)
.to(dtype=torch.float32)
)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
hidden_states_reshaped = hidden_states.view(
batch_size, self.num_heads, self.head_dim
)
hidden_states = mod.selective_state_update(
cache_params.ssm_states[self.layer_idx],
hidden_states_reshaped,
dt,
A,
B,
C,
D,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
)
hidden_states = hidden_states.view(
batch_size, self.num_heads * self.head_dim
)
hidden_states = self.norm(hidden_states, gate)
out = self.out_proj(hidden_states)[:, None, ...]
else:
if attention_mask is not None and not torch.all(attention_mask == 1):
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
projected_states = self.in_proj(hidden_states)
A = -torch.exp(self.A_log.float())
dt_limit_kwargs = (
{}
if self.time_step_limit is None
else {"dt_limit": self.time_step_limit}
)
if attention_mask is not None:
input_not_masked = torch.all(attention_mask == 1)
else:
input_not_masked = True
if (
self.use_mem_eff_path
and self.training
and cache_params is None
and input_not_masked
and seq_idx is None
):
out, ssm_state = mod.mamba_split_conv1d_scan_combined(
projected_states,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.dt_bias,
A,
D=self.D,
chunk_size=self.chunk_size,
seq_idx=seq_idx,
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon,
outproj_weight=self.out_proj.weight,
outproj_bias=self.out_proj.bias,
headdim=self.head_dim,
ngroups=self.n_groups,
norm_before_gate=False,
return_final_states=True,
**dt_limit_kwargs,
)
else:
gate, hidden_states_B_C, time_step = torch.split(
projected_states,
[self.intermediate_size, self.conv_dim, self.num_heads],
dim=-1,
)
if cache_params is not None:
hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2)
conv_state = torch.nn.functional.pad(
hidden_states_B_C_t,
(self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0),
)
cache_params.conv_states[self.layer_idx].copy_(conv_state)
if mod.causal_conv1d_fn is None or self.activation not in [
"silu",
"swish",
]:
hidden_states_B_C = self.act(
self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[
:, :seq_len
]
)
else:
hidden_states_B_C = mod.causal_conv1d_fn(
x=hidden_states_B_C.transpose(1, 2),
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=seq_idx,
).transpose(1, 2)[:, :seq_len]
hidden_states, B, C = torch.split(
hidden_states_B_C,
[
self.intermediate_size,
groups_time_state_size,
groups_time_state_size,
],
dim=-1,
)
if attention_mask is not None and not torch.all(attention_mask == 1):
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(
dtype
)
scan_output, ssm_state = mod.mamba_chunk_scan_combined(
hidden_states.view(batch_size, seq_len, -1, self.head_dim),
time_step,
A,
B.view(batch_size, seq_len, self.n_groups, -1),
C.view(batch_size, seq_len, self.n_groups, -1),
chunk_size=self.chunk_size,
D=self.D,
z=None,
seq_idx=seq_idx,
return_final_states=True,
dt_bias=self.dt_bias,
dt_softplus=True,
**dt_limit_kwargs,
)
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
scan_output = scan_output.view(batch_size, seq_len, -1)
scan_output = self.norm(scan_output, gate)
out = self.out_proj(scan_output)
return out
NemotronHMamba2Mixer.cuda_kernels_forward = patched_cuda_kernels_forward
# Patch 2: Mamba2Mixer.forward — add seq_idx, guard on causal_conv1d_fn,
# restore the cuda stream context (matches upstream; avoids NaN on multi-GPU).
def patched_mixer_forward(
self,
hidden_states,
cache_params=None,
attention_mask=None,
seq_idx=None,
):
if seq_idx is not None and mod.causal_conv1d_fn is None:
raise RuntimeError(
"Nemotron-H sample packing requires causal_conv1d_fn. "
"Install with: pip install mamba-ssm causal-conv1d"
)
if (
mod.is_fast_path_available
and "cuda" in self.in_proj.weight.device.type
and not mod.is_torchdynamo_compiling()
):
with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
return self.cuda_kernels_forward(
hidden_states, cache_params, attention_mask, seq_idx=seq_idx
)
return self.torch_forward(hidden_states, cache_params, attention_mask)
NemotronHMamba2Mixer.forward = patched_mixer_forward
# Patch 3: NemotronHBlock.forward — compute seq_idx from position_ids and
# pass it to the Mamba2 mixer. Skipped during decode (has_previous_state).
def patched_block_forward(
self,
hidden_states,
past_key_values=None,
cache_position=None,
attention_mask=None,
position_ids=None,
use_cache=False,
**kwargs,
):
residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.block_type == "mamba":
is_decoding = (
past_key_values is not None and past_key_values.has_previous_state
)
seq_idx = (
get_seq_idx(position_ids)
if position_ids is not None and not is_decoding
else None
)
hidden_states = self.mixer(
hidden_states,
cache_params=past_key_values,
attention_mask=attention_mask,
seq_idx=seq_idx,
)
elif self.block_type == "attention":
hidden_states, _ = self.mixer(
hidden_states=hidden_states,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
else:
hidden_states = self.mixer(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
NemotronHBlock.forward = patched_block_forward
LOG.info("Applied NemotronH sample packing patch (seq_idx threading into Mamba2)")

View File

@@ -154,6 +154,8 @@ def patch_peft_target_parameters_matching():
1. Expands short suffixes to full module paths for parametrized modules.
2. Iterates params in definition order (not alphabetical order) so saved
adapters are compatible with standard PEFT, vLLM, etc.
3. Skips ParametrizationList synthetic paths to prevent PEFT from mistakenly
targeting quantized expert params via name-suffix matching.
"""
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
return
@@ -293,5 +295,23 @@ def patch_peft_target_parameters_matching():
self.targeted_parameter_names.append(key)
BaseTuner._inject_parameters = _patched_inject_parameters
# Skip ParametrizationList synthetic paths (e.g. "...parametrizations.up_proj")
# so PEFT suffix-matching doesn't try to wrap quantized expert params in LoRA.
# Previous MoE models (Mixtral, DeepSeek, etc.) stored experts as nn.Linear
# modules, so PEFT's normal target_modules path worked fine. NemotronH uses
# 3D nn.Parameter tensors via our quantize_moe_experts parametrization, which
# exposes synthetic ".parametrizations.<name>" paths that PEFT's suffix match
# would otherwise treat as target_modules candidates.
_original_check = BaseTuner._check_target_module_exists
@staticmethod
def _patched_check_target_module_exists(config, key):
if ".parametrizations." in key:
return False
return _original_check(config, key)
BaseTuner._check_target_module_exists = _patched_check_target_module_exists
patch_peft_target_parameters_matching._axolotl_patched = True
LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering")

View File

@@ -62,6 +62,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mistral4",
"afmoe",
"nemotron",
"nemotron_h",
]

View File

@@ -0,0 +1,16 @@
{%- if messages and messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- set messages = messages[1:] %}
{%- endif %}
{%- for message in messages %}
{%- if message.role == 'user' %}
{{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}
{%- elif message.role == 'assistant' %}
{{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }}
{%- else %}
{{- raise_exception('Unexpected role: ' + message.role) }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}

View File

@@ -62,6 +62,7 @@ class ChatTemplate(str, Enum):
qwen3 = "qwen3"
qwen3_5 = "qwen3_5"
falcon_h1 = "falcon_h1"
nemotron_h = "nemotron_h"
tokenizer_default = "tokenizer_default"
exaone = "exaone"
exaone4 = "exaone4"

View File

@@ -1258,6 +1258,21 @@ class ModelCompatibilityValidationMixin:
raise ValueError("gradient_checkpointing is not supported for MPT models")
return self
@model_validator(mode="after")
def check_nemotron_h_gradient_checkpointing(self):
if (
self.base_model
and "nemotron-h" in self.base_model.lower()
and self.gradient_checkpointing
and not self.sample_packing
):
raise ValueError(
"gradient_checkpointing for nemotron_h requires sample_packing: true. "
"The upstream model marks supports_gradient_checkpointing=False; "
"axolotl only enables it after applying the sample-packing patch."
)
return self
@model_validator(mode="after")
def check_gradient_checkpointing_w_offload(self):
if self.gradient_checkpointing == "offload":