Llama4 linearized (#2502)
* llama4 support for linearized experts
* clean up fsdp2 sharding to prevent hang
* add yaml config
* cleanup example [skip ci]
examples/llama-4/scout-qlora-fsdp1.yaml (new file, 93 lines)
@@ -0,0 +1,93 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

strict: false

# torch_compile: true
plugins:
  - axolotl.integrations.liger.LigerPlugin

liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true

llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
  - self_attn.q_proj
  - self_attn.k_proj
  - self_attn.v_proj
  - self_attn.o_proj
  - shared_expert.gate_proj
  - shared_expert.up_proj
  - shared_expert.down_proj
  # - experts.gate_projs.[0-9]+$
  # - experts.up_projs.[0-9]+$
  # - experts.down_projs.[0-9]+$
lora_modules_to_save:
  - lm_head
  - embed_tokens

chat_template: llama4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5

bf16: true
tf32: true

logging_steps: 1
flash_attention: true

warmup_steps: 100
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - auto_wrap
  - full_shard
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_activation_checkpointing: true
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>
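The commented-out patterns show how the per-expert projections of the linearized layout could also be targeted. A small sketch of how such a regex would match the per-expert module names produced by the nn.ModuleList layout; the full module-path prefix used here is an assumption for illustration:

import re

pattern = r"experts.gate_projs.[0-9]+$"
# hypothetical module name from the linearized model; the "model.layers.0.feedforward" prefix is assumed
name = "model.layers.0.feedforward.experts.gate_projs.7"
print(bool(re.search(pattern, name)))  # True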
@@ -4,3 +4,5 @@ mypy
types-requests
quartodoc
jupyter
blobfile
tiktoken
src/axolotl/monkeypatch/accelerate/__init__.py (new file, empty)
src/axolotl/monkeypatch/accelerate/fsdp2.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""
Monkeypatch for accelerate FSDP2: fix for modifying an OrderedDict during iteration.
"""

import logging
import sys

import torch

LOG = logging.getLogger(__name__)


def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
    """
    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
    parameters from rank 0 to all other ranks. This function modifies the model in-place.

    Args:
        accelerator (`Accelerator`): The accelerator instance
        model (`torch.nn.Module`): The model to load the state dict into
        full_sd (`dict`): The full state dict to load, can only be on rank 0
    """
    import torch.distributed as dist
    from torch.distributed.tensor import distribute_tensor

    LOG.info("Broadcasting full state dict to all ranks...")
    sharded_sd = model.state_dict()
    # Iterate over a sorted snapshot of the keys so the dict is not mutated
    # mid-iteration and every rank issues the broadcasts in the same order.
    param_names = sorted(sharded_sd.keys())
    for param_name in param_names:
        mesh = sharded_sd[param_name].device_mesh
        if accelerator.is_main_process:
            # Use the corresponding tensor from full_sd (assuming the key exists in full_sd)
            full_param = full_sd[param_name].detach().cuda()
            dist.broadcast(full_param, src=0, group=mesh.get_group())
            sharded_tensor = distribute_tensor(
                full_param, mesh, sharded_sd[param_name].placements
            )
            sharded_sd[param_name] = sharded_tensor
        else:
            # Prepare a tensor of matching shape and dtype
            full_tensor = torch.empty(
                sharded_sd[param_name].size(),
                device="cuda",
                dtype=sharded_sd[param_name].dtype,
            )
            dist.broadcast(full_tensor, src=0, group=mesh.get_group())
            sharded_tensor = distribute_tensor(
                full_tensor, mesh, sharded_sd[param_name].placements
            )
            sharded_sd[param_name] = sharded_tensor

    model.load_state_dict(sharded_sd)


def patch_accelerate_fsdp_utils():
    from accelerate.utils import fsdp_utils

    fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
    setattr(
        sys.modules["accelerate.utils.fsdp_utils"],
        "fsdp2_load_full_state_dict",
        fsdp2_load_full_state_dict,
    )
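A minimal sketch of the intended call order, assuming the patch is applied before any FSDP2 full-state-dict loading happens (ModelLoader.apply_patches further down does this when fsdp_version is 2):

from axolotl.monkeypatch.accelerate.fsdp2 import (
    fsdp2_load_full_state_dict,
    patch_accelerate_fsdp_utils,
)
from accelerate.utils import fsdp_utils

patch_accelerate_fsdp_utils()
# the module-level function in accelerate has been swapped for the patched one
assert fsdp_utils.fsdp2_load_full_state_dict is fsdp2_load_full_state_dict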
@@ -4,7 +4,7 @@ import importlib
import inspect
import logging
import types
from typing import Type
from typing import Generator, Tuple, Type

import torch
from accelerate.logging import get_logger
@@ -200,6 +200,46 @@ def patch_self_attn_lora(cfg: DictDefault):
)


def find_self_attn_in_layer(
    layer: nn.Module,
) -> Generator[Tuple[nn.Module], None, None]:
    # general case of most models
    if hasattr(layer, "self_attn"):
        if all(
            hasattr(layer.self_attn, proj)
            for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
        ):
            yield layer.self_attn


def find_mlp_in_layer(
    layer: nn.Module,
) -> Generator[Tuple[nn.Module, nn.Module, nn.Module, nn.Module], None, None]:
    # general case of most models
    if hasattr(layer, "mlp"):
        if all(
            hasattr(layer.mlp, proj) for proj in ["gate_proj", "up_proj", "down_proj"]
        ):
            yield layer.mlp.gate_proj, layer.mlp.up_proj, layer.mlp.down_proj, layer.mlp
    # llama4 linearized experts
    if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "shared_expert"):
        mlp = layer.feedforward.shared_expert
        yield mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp
    if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "experts"):
        if all(
            hasattr(layer.feedforward.experts, proj)
            for proj in ["gate_projs", "up_projs", "down_projs"]
        ):
            for gate_proj, up_proj, down_proj in zip(
                layer.feedforward.experts.gate_projs,
                layer.feedforward.experts.up_projs,
                layer.feedforward.experts.down_projs,
            ):
                yield gate_proj, up_proj, down_proj, FakeMLP(
                    gate_proj, up_proj, down_proj
                )

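As a rough illustration (the layer object and expert count are assumptions based on the 16-expert Scout config), the second generator yields one tuple for the shared expert plus one per linearized expert, while a plain dense layer yields a single tuple:

# hypothetical decoder layer from a linearized Llama-4 Scout model (16 experts)
mlp_tuples = list(find_mlp_in_layer(layer))
print(len(mlp_tuples))  # 17: shared_expert + one (gate, up, down, FakeMLP) tuple per expert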
def apply_lora_kernel_patches(
    model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM:
@@ -286,74 +326,82 @@ def apply_lora_kernel_patches(
for layer in layers:
# Add QKV, O fallback implementations to start
# These will be overwritten later (if some conditions apply)
layer.self_attn.apply_qkv = types.MethodType(
original_apply_qkv, layer.self_attn
)
layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
for self_attn in find_self_attn_in_layer(layer):
self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
self_attn.apply_o = types.MethodType(original_apply_o, self_attn)

if cfg.lora_mlp_kernel:
# MLP patching
gate_proj = layer.mlp.gate_proj
up_proj = layer.mlp.up_proj
down_proj = layer.mlp.down_proj
if cfg.lora_qkv_kernel:
# Query, key, value patching
layer_modules = [
getattr(self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)

can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
if can_patch_qkv:
# Add optimized implementation
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)

if can_patch_mlp:
apply_fn = APPLY_FN_MAPPING[activation]
layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
if can_patch_o:
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
if cfg.lora_qkv_kernel:
# Query, key, value patching
layer_modules = [
getattr(layer.self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)

if can_patch_qkv:
# Add optimized implementation
layer.self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)

if can_patch_o:
layer.self_attn.apply_o = types.MethodType(
apply_lora_o, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
if can_patch_mlp:
apply_fn = APPLY_FN_MAPPING[activation]
layer.mlp.forward = types.MethodType(apply_fn, mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
)

LOG.setLevel(original_level)

return model

class FakeMLP(nn.Module):
    """
    placeholder MLP for triton patching
    """

    gate_proj: nn.Linear
    up_proj: nn.Linear
    down_proj: nn.Linear

    def __init__(self, gate_proj, up_proj, down_proj):
        super().__init__()
        self.gate_proj = gate_proj
        self.up_proj = up_proj
        self.down_proj = down_proj
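FakeMLP only has to expose the gate/up/down attributes that the MLP kernel patch expects. A toy construction, with arbitrary sizes chosen purely for illustration:

from torch import nn

gate = nn.Linear(16, 32, bias=False)
up = nn.Linear(16, 32, bias=False)
down = nn.Linear(32, 16, bias=False)
fake_mlp = FakeMLP(gate, up, down)  # duck-types as a dense MLP for the kernel patch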
src/axolotl/monkeypatch/models/llama4/__init__.py (new file, empty)
src/axolotl/monkeypatch/models/llama4/modeling.py (new file, 101 lines)
@@ -0,0 +1,101 @@
"""
Modified Llama-4 text experts modeling for linearized experts for improved LoRA support
"""

import sys

import torch
from torch import nn
from transformers import Llama4Config
from transformers.activations import ACT2FN


class Llama4TextExperts(nn.Module):
    """
    Modified Llama-4 text experts modeling for linearized experts
    """

    def __init__(self, config: Llama4Config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size

        # Replace fused gate_up_proj with separate Linear modules
        self.gate_projs = nn.ModuleList(
            [
                nn.Linear(self.hidden_size, self.expert_dim, bias=False)
                for _ in range(self.num_experts)
            ]
        )

        self.up_projs = nn.ModuleList(
            [
                nn.Linear(self.hidden_size, self.expert_dim, bias=False)
                for _ in range(self.num_experts)
            ]
        )

        # Replace down_proj Parameter with Linear modules
        self.down_projs = nn.ModuleList(
            [
                nn.Linear(self.expert_dim, self.hidden_size, bias=False)
                for _ in range(self.num_experts)
            ]
        )

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward method using separate Linear layers for each expert.

        Args:
            hidden_states (torch.Tensor): (num_experts * batch_size, hidden_size)
                The input should be organized by expert

        Returns:
            torch.Tensor: (num_experts * batch_size, hidden_size)
        """
        # Reshape to separate by expert
        hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
        # batch_size_per_expert = hidden_states.size(1)

        # Initialize output tensor
        next_states = torch.zeros_like(hidden_states)

        # Process each expert separately
        for i in range(self.num_experts):
            # Get input for this expert
            expert_input = hidden_states[i]  # Shape: (batch_size_per_expert, hidden_size)

            # Apply gate and up projections
            gate = self.gate_projs[i](expert_input)  # Shape: (batch_size_per_expert, expert_dim)
            up = self.up_projs[i](expert_input)  # Shape: (batch_size_per_expert, expert_dim)

            # Apply activation and down projection
            next_states[i] = self.down_projs[i](up * self.act_fn(gate))

        # Flatten back to original shape
        return next_states.view(-1, self.hidden_size)


def patch_llama4_linearized_modeling():
    """
    Patch Llama4TextExperts to use separate Linear layers for each expert.
    """
    from transformers.models.llama4 import modeling_llama4

    modeling_llama4.Llama4TextExperts = Llama4TextExperts
    setattr(
        sys.modules["transformers.models.llama4"],
        "Llama4TextExperts",
        Llama4TextExperts,
    )
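A quick shape check of the linearized experts module, using a SimpleNamespace as a stand-in for Llama4Config and toy sizes; both are assumptions made only for illustration, since the class just reads the four config attributes shown above:

from types import SimpleNamespace
import torch

config = SimpleNamespace(
    num_local_experts=4, intermediate_size=32, hidden_size=16, hidden_act="silu"
)
experts = Llama4TextExperts(config)
x = torch.randn(4 * 3, 16)  # (num_experts * batch_per_expert, hidden_size)
print(experts(x).shape)     # torch.Size([12, 16])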
@@ -544,8 +544,20 @@ class ModelLoader:
        self.auto_model_loader = AutoModelForCausalLM  # pylint: disable=invalid-name

    def apply_patches(self) -> None:
        if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
            from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils

            patch_accelerate_fsdp_utils()
        # patch gemma3 conditional generation forward before loading plugins
        # as it could be overridden by plugins
        if self.cfg.model_config_type == "llama4":
            if self.cfg.llama4_linearized_experts:
                from axolotl.monkeypatch.models.llama4.modeling import (
                    patch_llama4_linearized_modeling,
                )

                patch_llama4_linearized_modeling()

        if self.cfg.model_config_type == "gemma3":
            from axolotl.monkeypatch.gemma3 import (
                patch_gemma3conditionalgeneration_forward,
@@ -245,6 +245,8 @@ class AxolotlInputConfig(
    lora_qkv_kernel: bool | None = None
    lora_o_kernel: bool | None = None

    llama4_linearized_experts: bool | None = None

    deepspeed: str | dict[str, Any] | None = None
    fsdp: list[str] | None = None
    fsdp_config: dict[str, Any] | None = None