Llama4 linearized (#2502)

* llama4 support for linearized experts

* clean up fsdp2 sharding to prevent hang

* add yaml config

* cleanup example [skip ci]
Wing Lian
2025-04-07 20:47:00 -04:00
committed by GitHub
parent a6c03217f5
commit 0dac2ddeac
10 changed files with 384 additions and 63 deletions

View File

@@ -0,0 +1,93 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
strict: false
# torch_compile: true
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
- lm_head
- embed_tokens
chat_template: llama4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_steps: 100
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- auto_wrap
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
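Note that the commented-out per-expert LoRA targets above only exist once the experts have been linearized into indexed nn.Linear modules. A rough sketch of how those regexes resolve (the module names are assumptions based on the patched modeling code further down in this commit):

import re

# hypothetical module names as they would appear after linearization
names = [
    "model.layers.0.feedforward.experts.gate_projs.0",
    "model.layers.0.feedforward.experts.up_projs.15",
    "model.layers.0.feedforward.shared_expert.gate_proj",
]
pattern = re.compile(r"experts.gate_projs.[0-9]+$")
print([n for n in names if pattern.search(n)])  # only the indexed expert gate projection matches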

View File

@@ -4,3 +4,5 @@ mypy
types-requests
quartodoc
jupyter
blobfile
tiktoken

View File

@@ -0,0 +1,63 @@
"""
Monkeypatch for an accelerate FSDP2 fix: avoid modifying the state-dict OrderedDict while iterating over it when loading the full state dict
"""
import logging
import sys
import torch
LOG = logging.getLogger(__name__)
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
"""
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
parameters from rank 0 to all other ranks. This function modifies the model in-place.
Args:
accelerator (`Accelerator`): The accelerator instance
model (`torch.nn.Module`): The model to load the state dict into
full_sd (`dict`): The full state dict to load, can only be on rank 0
"""
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor
LOG.info("Broadcasting full state dict to all ranks...")
sharded_sd = model.state_dict()
param_names = sorted(sharded_sd.keys())  # snapshot the keys in a deterministic order so every rank broadcasts in lockstep and the dict is never mutated mid-iteration
for param_name in param_names:
mesh = sharded_sd[param_name].device_mesh
if accelerator.is_main_process:
# Use the corresponding tensor from full_sd (assuming the key exists in full_sd)
full_param = full_sd[param_name].detach().cuda()
dist.broadcast(full_param, src=0, group=mesh.get_group())
sharded_tensor = distribute_tensor(
full_param, mesh, sharded_sd[param_name].placements
)
sharded_sd[param_name] = sharded_tensor
else:
# Prepare a tensor of matching shape and dtype
full_tensor = torch.empty(
sharded_sd[param_name].size(),
device="cuda",
dtype=sharded_sd[param_name].dtype,
)
dist.broadcast(full_tensor, src=0, group=mesh.get_group())
sharded_tensor = distribute_tensor(
full_tensor, mesh, sharded_sd[param_name].placements
)
sharded_sd[param_name] = sharded_tensor
model.load_state_dict(sharded_sd)
def patch_accelerate_fsdp_utils():
from accelerate.utils import fsdp_utils
fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
setattr(
sys.modules["accelerate.utils.fsdp_utils"],
"fsdp2_load_full_state_dict",
fsdp2_load_full_state_dict,
)
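A minimal usage sketch for this patch: it only takes effect if the replacement is installed before accelerate loads or shards the model, which is exactly what the ModelLoader.apply_patches hunk further down does.

from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils

# install the patched loader before any FSDP2 model loading/sharding happens,
# so accelerate resolves fsdp2_load_full_state_dict to the version above
patch_accelerate_fsdp_utils()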

View File

@@ -4,7 +4,7 @@ import importlib
import inspect
import logging
import types
-from typing import Type
+from typing import Generator, Tuple, Type
import torch
from accelerate.logging import get_logger
@@ -200,6 +200,46 @@ def patch_self_attn_lora(cfg: DictDefault):
)
def find_self_attn_in_layer(
layer: nn.Module,
) -> Generator[nn.Module, None, None]:
# general case of most models
if hasattr(layer, "self_attn"):
if all(
hasattr(layer.self_attn, proj)
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
):
yield layer.self_attn
def find_mlp_in_layer(
layer: nn.Module,
) -> Generator[Tuple[nn.Module, nn.Module, nn.Module, nn.Module], None, None]:
# general case of most models
if hasattr(layer, "mlp"):
if all(
hasattr(layer.mlp, proj) for proj in ["gate_proj", "up_proj", "down_proj"]
):
yield layer.mlp.gate_proj, layer.mlp.up_proj, layer.mlp.down_proj, layer.mlp
# llama4 linearized experts
if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "shared_expert"):
mlp = layer.feedforward.shared_expert
yield mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp
if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "experts"):
if all(
hasattr(layer.feedforward.experts, proj)
for proj in ["gate_projs", "up_projs", "down_projs"]
):
for gate_proj, up_proj, down_proj in zip(
layer.feedforward.experts.gate_projs,
layer.feedforward.experts.up_projs,
layer.feedforward.experts.down_projs,
):
yield gate_proj, up_proj, down_proj, FakeMLP(
gate_proj, up_proj, down_proj
)
def apply_lora_kernel_patches(
model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM:
@@ -286,74 +326,82 @@ def apply_lora_kernel_patches(
for layer in layers:
# Add QKV, O fallback implementations to start
# These will be overwritten later (if some conditions apply)
-layer.self_attn.apply_qkv = types.MethodType(
-    original_apply_qkv, layer.self_attn
-)
-layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
-if cfg.lora_mlp_kernel:
-    # MLP patching
-    gate_proj = layer.mlp.gate_proj
-    up_proj = layer.mlp.up_proj
-    down_proj = layer.mlp.down_proj
-    can_patch_mlp = all(
-        hasattr(proj, "lora_A")
-        and getattr(proj, "base_layer", proj).bias is None
-        and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
-        for proj in (gate_proj, up_proj, down_proj)
-    )
-    if can_patch_mlp:
-        apply_fn = APPLY_FN_MAPPING[activation]
-        layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
-    else:
-        LOG.warning_once(
-            "Cannot patch some MLP layers - requires LoRA adapters with no bias"
-        )
-if cfg.lora_qkv_kernel:
-    # Query, key, value patching
-    layer_modules = [
-        getattr(layer.self_attn, linear_proj)
-        for linear_proj in ["q_proj", "k_proj", "v_proj"]
-    ]
-    can_patch_qkv = all(
-        hasattr(module, "lora_A")
-        and getattr(module, "base_layer", module).bias is None
-        and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
-        for module in layer_modules
-    )
-    if can_patch_qkv:
-        # Add optimized implementation
-        layer.self_attn.apply_qkv = types.MethodType(
-            apply_lora_qkv, layer.self_attn
-        )
-    else:
-        LOG.warning_once(
-            "Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
-        )
-if cfg.lora_o_kernel:
-    # Output patching
-    layer_modules = [
-        getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
-    ]
-    can_patch_o = all(
-        hasattr(module, "lora_A")
-        and getattr(module, "base_layer", module).bias is None
-        and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
-        for module in layer_modules
-    )
-    if can_patch_o:
-        layer.self_attn.apply_o = types.MethodType(
-            apply_lora_o, layer.self_attn
-        )
-    else:
-        LOG.warning_once(
-            "Cannot patch some attention output projection - requires LoRA adapters with no bias"
-        )
+for self_attn in find_self_attn_in_layer(layer):
+    self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
+    self_attn.apply_o = types.MethodType(original_apply_o, self_attn)
+    if cfg.lora_qkv_kernel:
+        # Query, key, value patching
+        layer_modules = [
+            getattr(self_attn, linear_proj)
+            for linear_proj in ["q_proj", "k_proj", "v_proj"]
+        ]
+        can_patch_qkv = all(
+            hasattr(module, "lora_A")
+            and getattr(module, "base_layer", module).bias is None
+            and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
+            for module in layer_modules
+        )
+        if can_patch_qkv:
+            # Add optimized implementation
+            self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
+        else:
+            LOG.warning_once(
+                "Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
+            )
+    if cfg.lora_o_kernel:
+        # Output patching
+        layer_modules = [
+            getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
+        ]
+        can_patch_o = all(
+            hasattr(module, "lora_A")
+            and getattr(module, "base_layer", module).bias is None
+            and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
+            for module in layer_modules
+        )
+        if can_patch_o:
+            self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
+        else:
+            LOG.warning_once(
+                "Cannot patch some attention output projection - requires LoRA adapters with no bias"
+            )
+for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
+    if cfg.lora_mlp_kernel:
+        # MLP patching
+        can_patch_mlp = all(
+            hasattr(proj, "lora_A")
+            and getattr(proj, "base_layer", proj).bias is None
+            and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
+            for proj in (gate_proj, up_proj, down_proj)
+        )
+        if can_patch_mlp:
+            apply_fn = APPLY_FN_MAPPING[activation]
+            mlp.forward = types.MethodType(apply_fn, mlp)
+        else:
+            LOG.warning_once(
+                "Cannot patch some MLP layers - requires LoRA adapters with no bias"
+            )
LOG.setLevel(original_level)
return model
class FakeMLP(nn.Module):
"""
Placeholder MLP wrapper that groups an expert's gate/up/down projections so they can be patched like a standard MLP by the triton kernels
"""
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
def __init__(self, gate_proj, up_proj, down_proj):
super().__init__()
self.gate_proj = gate_proj
self.up_proj = up_proj
self.down_proj = down_proj
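Taken together, a brief sketch of how the generator helpers and FakeMLP above are meant to be consumed (the model.model.layers access is an assumption about the wrapped model layout):

for layer in model.model.layers:  # hypothetical layer container
    for self_attn in find_self_attn_in_layer(layer):
        print("attention:", type(self_attn).__name__)
    for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
        # one triple for a dense MLP or shared expert, one per routed expert for Llama-4,
        # the latter wrapped in FakeMLP so downstream patching sees a uniform interface
        print("mlp:", type(mlp).__name__)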

View File

@@ -0,0 +1,101 @@
"""
Modified Llama-4 text experts modeling that linearizes the fused expert weights into per-expert Linear layers for improved LoRA support
"""
import sys
import torch
from torch import nn
from transformers import Llama4Config
from transformers.activations import ACT2FN
class Llama4TextExperts(nn.Module):
"""
Modified Llama-4 text experts modeling for linearized experts
"""
def __init__(self, config: Llama4Config):
super().__init__()
self.num_experts = config.num_local_experts
self.intermediate_size = config.intermediate_size
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
# Replace fused gate_up_proj with separate Linear modules
self.gate_projs = nn.ModuleList(
[
nn.Linear(self.hidden_size, self.expert_dim, bias=False)
for _ in range(self.num_experts)
]
)
self.up_projs = nn.ModuleList(
[
nn.Linear(self.hidden_size, self.expert_dim, bias=False)
for _ in range(self.num_experts)
]
)
# Replace down_proj Parameter with Linear modules
self.down_projs = nn.ModuleList(
[
nn.Linear(self.expert_dim, self.hidden_size, bias=False)
for _ in range(self.num_experts)
]
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Forward method using separate Linear layers for each expert.
Args:
hidden_states (torch.Tensor): (num_experts * batch_size, hidden_size)
The input should be organized by expert
Returns:
torch.Tensor: (num_experts * batch_size, hidden_size)
"""
# Reshape to separate by expert
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
# batch_size_per_expert = hidden_states.size(1)
# Initialize output tensor
next_states = torch.zeros_like(hidden_states)
# Process each expert separately
for i in range(self.num_experts):
# Get input for this expert
expert_input = hidden_states[
i
] # Shape: (batch_size_per_expert, hidden_size)
# Apply gate and up projections
gate = self.gate_projs[i](
expert_input
) # Shape: (batch_size_per_expert, expert_dim)
up = self.up_projs[i](
expert_input
) # Shape: (batch_size_per_expert, expert_dim)
# Apply activation and down projection
next_states[i] = self.down_projs[i](up * self.act_fn(gate))
# Flatten back to original shape
return next_states.view(-1, self.hidden_size)
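A tiny, hedged shape check for the forward pass above, using a stand-in config object rather than a real Llama4Config (the attribute names match what __init__ reads):

from types import SimpleNamespace

import torch

cfg = SimpleNamespace(
    num_local_experts=4, intermediate_size=16, hidden_size=8, hidden_act="silu"
)
experts = Llama4TextExperts(cfg)  # type: ignore[arg-type]
x = torch.randn(4 * 2, 8)  # (num_experts * tokens_per_expert, hidden_size)
print(experts(x).shape)  # torch.Size([8, 8]) - same shape in, same shape out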
def patch_llama4_linearized_modeling():
"""
Patch Llama4TextExperts to use separate Linear layers for each expert.
"""
from transformers.models.llama4 import modeling_llama4
modeling_llama4.Llama4TextExperts = Llama4TextExperts
setattr(
sys.modules["transformers.models.llama4"],
"Llama4TextExperts",
Llama4TextExperts,
)
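The base model named in the example config already ships converted weights, but as a hedged sketch of the linearization itself (assuming the stock transformers layout of gate_up_proj with shape (num_experts, hidden_size, 2 * expert_dim) and down_proj with shape (num_experts, expert_dim, hidden_size); the helper name is hypothetical):

import torch

@torch.no_grad()
def copy_fused_experts(fused_experts, linear_experts):
    # split the fused gate/up projection into its gate and up halves
    gate_w, up_w = fused_experts.gate_up_proj.chunk(2, dim=-1)
    for i in range(linear_experts.num_experts):
        # nn.Linear stores weight as (out_features, in_features), hence the transpose
        linear_experts.gate_projs[i].weight.copy_(gate_w[i].T)
        linear_experts.up_projs[i].weight.copy_(up_w[i].T)
        linear_experts.down_projs[i].weight.copy_(fused_experts.down_proj[i].T)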

View File

@@ -544,8 +544,20 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils()
# patch gemma3 conditional generation forward before loading plugins
# as it could be overridden by plugins
if self.cfg.model_config_type == "llama4":
if self.cfg.llama4_linearized_experts:
from axolotl.monkeypatch.models.llama4.modeling import (
patch_llama4_linearized_modeling,
)
patch_llama4_linearized_modeling()
if self.cfg.model_config_type == "gemma3":
from axolotl.monkeypatch.gemma3 import (
patch_gemma3conditionalgeneration_forward,

View File

@@ -245,6 +245,8 @@ class AxolotlInputConfig(
lora_qkv_kernel: bool | None = None
lora_o_kernel: bool | None = None
llama4_linearized_experts: bool | None = None
deepspeed: str | dict[str, Any] | None = None
fsdp: list[str] | None = None
fsdp_config: dict[str, Any] | None = None