Llama4 linearized (#2502)

* llama4 support for linearized experts

* clean up fsdp2 sharding to prevent hang

* add yaml config

* cleanup example [skip ci]
This commit is contained in:
Wing Lian
2025-04-07 20:47:00 -04:00
committed by GitHub
parent a6c03217f5
commit 0dac2ddeac
10 changed files with 384 additions and 63 deletions

View File

@@ -0,0 +1,93 @@
# Example QLoRA fine-tuning config for a Llama-4 Scout checkpoint with linearized experts, sharded with FSDP.
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
strict: false
# torch_compile: true
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
# Enables the patch that swaps Llama4TextExperts for the variant with separate Linear layers per expert.
llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
# LoRA targets the attention and shared-expert projections; the per-expert
# projection patterns at the end of the list are left commented out.
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
- lm_head
- embed_tokens
chat_template: llama4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_steps: 100
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# FSDP settings: full sharding with transformer-based auto wrapping.
fsdp:
- auto_wrap
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -4,3 +4,5 @@ mypy
types-requests types-requests
quartodoc quartodoc
jupyter jupyter
blobfile
tiktoken

View File

@@ -0,0 +1,63 @@
"""
monkeypatch for accelerate's fsdp2 state dict loading to avoid modifying an OrderedDict during iteration
"""
import logging
import sys
import torch
LOG = logging.getLogger(__name__)
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
    """
    Load a full state dict into the sharded model by broadcasting from rank 0.

    For every entry of the model's state dict, the main process takes the tensor
    from `full_sd` while every other process allocates an empty CUDA buffer of the
    same size and dtype. The tensor is broadcast from rank 0 and re-distributed
    with the entry's existing device mesh and placements. The model is modified
    in-place.

    Args:
        accelerator (`Accelerator`): The accelerator instance
        model (`torch.nn.Module`): The model to load the state dict into
        full_sd (`dict`): The full state dict to load, can only be on rank 0
    """
    import torch.distributed as dist
    from torch.distributed.tensor import distribute_tensor

    LOG.info("Broadcasting full state dict to all ranks...")
    sharded_sd = model.state_dict()

    # Walk a sorted snapshot of the keys rather than the dict itself, because
    # entries are reassigned inside the loop.
    for param_name in sorted(sharded_sd.keys()):
        local_param = sharded_sd[param_name]
        mesh = local_param.device_mesh

        if accelerator.is_main_process:
            # Use the corresponding tensor from full_sd (assuming the key exists in full_sd)
            payload = full_sd[param_name].detach().cuda()
        else:
            # Receive buffer of matching shape and dtype
            payload = torch.empty(
                local_param.size(),
                device="cuda",
                dtype=local_param.dtype,
            )

        dist.broadcast(payload, src=0, group=mesh.get_group())
        sharded_sd[param_name] = distribute_tensor(
            payload, mesh, local_param.placements
        )

    model.load_state_dict(sharded_sd)
def patch_accelerate_fsdp_utils():
    """
    Point accelerate's `fsdp2_load_full_state_dict` at the implementation in this module.

    The attribute is set both on the imported `fsdp_utils` module object and on
    the entry registered in `sys.modules` under its dotted path.
    """
    from accelerate.utils import fsdp_utils

    patch_targets = (
        fsdp_utils,
        sys.modules["accelerate.utils.fsdp_utils"],
    )
    for target in patch_targets:
        setattr(target, "fsdp2_load_full_state_dict", fsdp2_load_full_state_dict)

View File

@@ -4,7 +4,7 @@ import importlib
import inspect import inspect
import logging import logging
import types import types
from typing import Type from typing import Generator, Tuple, Type
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
@@ -200,6 +200,46 @@ def patch_self_attn_lora(cfg: DictDefault):
) )
def find_self_attn_in_layer(
    layer: nn.Module,
) -> Generator[nn.Module, None, None]:
    """
    Yield the attention module of `layer` when it exposes the expected projections.

    Args:
        layer: a decoder layer to inspect.

    Yields:
        `layer.self_attn`, only if it has all of `q_proj`, `k_proj`, `v_proj`
        and `o_proj`. Nothing is yielded otherwise.
    """
    # general case of most models
    self_attn = getattr(layer, "self_attn", None)
    if self_attn is None:
        return

    required_projs = ("q_proj", "k_proj", "v_proj", "o_proj")
    if all(hasattr(self_attn, proj) for proj in required_projs):
        yield self_attn
def find_mlp_in_layer(
    layer: nn.Module,
) -> Generator[Tuple[nn.Module, nn.Module, nn.Module, nn.Module], None, None]:
    """
    Yield every gated MLP found in `layer` as `(gate_proj, up_proj, down_proj, mlp)`.

    Three layouts are recognised, each only when all of its projections exist:

    * `layer.mlp` with `gate_proj` / `up_proj` / `down_proj`.
    * `layer.feedforward.shared_expert` with the same three projections.
    * `layer.feedforward.experts` with `gate_projs` / `up_projs` / `down_projs`;
      one tuple is yielded per zipped triple, with a `FakeMLP` wrapping the three
      projections as the fourth element.

    Args:
        layer: a decoder layer to inspect.

    Yields:
        Tuples of the three projection modules and the module that owns them.
    """
    mlp_projs = ("gate_proj", "up_proj", "down_proj")

    # general case of most models
    mlp = getattr(layer, "mlp", None)
    if mlp is not None and all(hasattr(mlp, proj) for proj in mlp_projs):
        yield mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp

    feedforward = getattr(layer, "feedforward", None)
    if feedforward is None:
        return

    # llama4 shared expert; guarded like the other layouts so a shared expert
    # without the expected projections is skipped instead of raising
    shared_expert = getattr(feedforward, "shared_expert", None)
    if shared_expert is not None and all(
        hasattr(shared_expert, proj) for proj in mlp_projs
    ):
        yield (
            shared_expert.gate_proj,
            shared_expert.up_proj,
            shared_expert.down_proj,
            shared_expert,
        )

    # llama4 linearized experts
    experts = getattr(feedforward, "experts", None)
    if experts is not None and all(
        hasattr(experts, proj) for proj in ("gate_projs", "up_projs", "down_projs")
    ):
        for gate_proj, up_proj, down_proj in zip(
            experts.gate_projs,
            experts.up_projs,
            experts.down_projs,
        ):
            yield gate_proj, up_proj, down_proj, FakeMLP(
                gate_proj, up_proj, down_proj
            )
def apply_lora_kernel_patches( def apply_lora_kernel_patches(
model: PeftModelForCausalLM, cfg: DictDefault model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM: ) -> PeftModelForCausalLM:
@@ -286,74 +326,82 @@ def apply_lora_kernel_patches(
for layer in layers: for layer in layers:
# Add QKV, O fallback implementations to start # Add QKV, O fallback implementations to start
# These will be overwritten later (if some conditions apply) # These will be overwritten later (if some conditions apply)
layer.self_attn.apply_qkv = types.MethodType( for self_attn in find_self_attn_in_layer(layer):
original_apply_qkv, layer.self_attn self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
) self_attn.apply_o = types.MethodType(original_apply_o, self_attn)
layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
if cfg.lora_mlp_kernel: if cfg.lora_qkv_kernel:
# MLP patching # Query, key, value patching
gate_proj = layer.mlp.gate_proj layer_modules = [
up_proj = layer.mlp.up_proj getattr(self_attn, linear_proj)
down_proj = layer.mlp.down_proj for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
can_patch_mlp = all( if can_patch_qkv:
hasattr(proj, "lora_A") # Add optimized implementation
and getattr(proj, "base_layer", proj).bias is None self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0 else:
for proj in (gate_proj, up_proj, down_proj) LOG.warning_once(
) "Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_mlp: if can_patch_o:
apply_fn = APPLY_FN_MAPPING[activation] self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
layer.mlp.forward = types.MethodType(apply_fn, layer.mlp) else:
else: LOG.warning_once(
LOG.warning_once( "Cannot patch some attention output projection - requires LoRA adapters with no bias"
"Cannot patch some MLP layers - requires LoRA adapters with no bias" )
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
) )
if cfg.lora_qkv_kernel:
# Query, key, value patching
layer_modules = [
getattr(layer.self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_qkv: if can_patch_mlp:
# Add optimized implementation apply_fn = APPLY_FN_MAPPING[activation]
layer.self_attn.apply_qkv = types.MethodType( layer.mlp.forward = types.MethodType(apply_fn, mlp)
apply_lora_qkv, layer.self_attn else:
) LOG.warning_once(
else: "Cannot patch some MLP layers - requires LoRA adapters with no bias"
LOG.warning_once( )
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_o:
layer.self_attn.apply_o = types.MethodType(
apply_lora_o, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
LOG.setLevel(original_level) LOG.setLevel(original_level)
return model return model
class FakeMLP(nn.Module):
    """
    Stand-in module that bundles three projections under MLP-style attribute
    names; used as a placeholder MLP for triton patching.
    """

    gate_proj: nn.Linear
    up_proj: nn.Linear
    down_proj: nn.Linear

    def __init__(self, gate_proj, up_proj, down_proj):
        super().__init__()
        # Assign in gate / up / down order so the attributes are registered
        # in that sequence.
        named_projs = (
            ("gate_proj", gate_proj),
            ("up_proj", up_proj),
            ("down_proj", down_proj),
        )
        for attr_name, proj in named_projs:
            setattr(self, attr_name, proj)

View File

@@ -0,0 +1,101 @@
"""
Modified Llama-4 text experts modeling for linearized experts for improved LoRA support
"""
import sys
import torch
from torch import nn
from transformers import Llama4Config
from transformers.activations import ACT2FN
class Llama4TextExperts(nn.Module):
    """
    Llama-4 text experts where every expert has its own gate, up and down
    `nn.Linear` modules, held in three `nn.ModuleList`s.
    """

    def __init__(self, config: Llama4Config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size

        def linear_bank(in_features, out_features):
            # One bias-free Linear per expert.
            return nn.ModuleList(
                [
                    nn.Linear(in_features, out_features, bias=False)
                    for _ in range(self.num_experts)
                ]
            )

        # Separate Linear modules in place of the fused gate_up_proj
        self.gate_projs = linear_bank(self.hidden_size, self.expert_dim)
        self.up_projs = linear_bank(self.hidden_size, self.expert_dim)

        # Linear modules in place of the down_proj Parameter
        self.down_projs = linear_bank(self.expert_dim, self.hidden_size)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Run each expert's gated MLP over its own slice of the input.

        Args:
            hidden_states (torch.Tensor): (num_experts * batch_size, hidden_size)
                The input should be organized by expert

        Returns:
            torch.Tensor: (num_experts * batch_size, hidden_size)
        """
        # View as (num_experts, batch_size_per_expert, hidden_size)
        per_expert = hidden_states.view(self.num_experts, -1, self.hidden_size)

        # Output buffer, filled one expert slice at a time
        outputs = torch.zeros_like(per_expert)

        expert_projs = zip(self.gate_projs, self.up_projs, self.down_projs)
        for idx, (gate_proj, up_proj, down_proj) in enumerate(expert_projs):
            # Shape: (batch_size_per_expert, hidden_size)
            expert_input = per_expert[idx]

            # Both projections give (batch_size_per_expert, expert_dim)
            gate = gate_proj(expert_input)
            up = up_proj(expert_input)

            # Gate activation scales the up projection before projecting down
            outputs[idx] = down_proj(up * self.act_fn(gate))

        # Flatten back to the original 2-D layout
        return outputs.view(-1, self.hidden_size)
def patch_llama4_linearized_modeling():
    """
    Replace transformers' `Llama4TextExperts` with the class defined in this
    module, which uses separate Linear layers for each expert.

    The name is set on the `modeling_llama4` module and on the
    `transformers.models.llama4` entry in `sys.modules`.
    """
    from transformers.models.llama4 import modeling_llama4

    patch_targets = (
        modeling_llama4,
        sys.modules["transformers.models.llama4"],
    )
    for target in patch_targets:
        setattr(target, "Llama4TextExperts", Llama4TextExperts)

View File

@@ -544,8 +544,20 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None: def apply_patches(self) -> None:
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils()
# patch gemma3 conditional generation forward before loading plugins # patch gemma3 conditional generation forward before loading plugins
# as it could be overridden by plugins # as it could be overridden by plugins
if self.cfg.model_config_type == "llama4":
if self.cfg.llama4_linearized_experts:
from axolotl.monkeypatch.models.llama4.modeling import (
patch_llama4_linearized_modeling,
)
patch_llama4_linearized_modeling()
if self.cfg.model_config_type == "gemma3": if self.cfg.model_config_type == "gemma3":
from axolotl.monkeypatch.gemma3 import ( from axolotl.monkeypatch.gemma3 import (
patch_gemma3conditionalgeneration_forward, patch_gemma3conditionalgeneration_forward,

View File

@@ -245,6 +245,8 @@ class AxolotlInputConfig(
lora_qkv_kernel: bool | None = None lora_qkv_kernel: bool | None = None
lora_o_kernel: bool | None = None lora_o_kernel: bool | None = None
llama4_linearized_experts: bool | None = None
deepspeed: str | dict[str, Any] | None = None deepspeed: str | dict[str, Any] | None = None
fsdp: list[str] | None = None fsdp: list[str] | None = None
fsdp_config: dict[str, Any] | None = None fsdp_config: dict[str, Any] | None = None