From 0dac2ddeacf06dac8d4fbadcdde3d02d6ef5e2b0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 7 Apr 2025 20:47:00 -0400 Subject: [PATCH] Llama4 linearized (#2502) * llama4 support for linearized experts * clean up fsdp2 sharding to prevent hang * add yaml config * cleanup example [skip ci] --- examples/{llama4 => llama-4}/scout-lora.yaml | 0 examples/llama-4/scout-qlora-fsdp1.yaml | 93 ++++++++++ requirements-dev.txt | 2 + .../monkeypatch/accelerate/__init__.py | 0 src/axolotl/monkeypatch/accelerate/fsdp2.py | 63 +++++++ src/axolotl/monkeypatch/lora_kernels.py | 174 +++++++++++------- .../monkeypatch/models/llama4/__init__.py | 0 .../monkeypatch/models/llama4/modeling.py | 101 ++++++++++ src/axolotl/utils/models.py | 12 ++ src/axolotl/utils/schemas/config.py | 2 + 10 files changed, 384 insertions(+), 63 deletions(-) rename examples/{llama4 => llama-4}/scout-lora.yaml (100%) create mode 100644 examples/llama-4/scout-qlora-fsdp1.yaml create mode 100644 src/axolotl/monkeypatch/accelerate/__init__.py create mode 100644 src/axolotl/monkeypatch/accelerate/fsdp2.py create mode 100644 src/axolotl/monkeypatch/models/llama4/__init__.py create mode 100644 src/axolotl/monkeypatch/models/llama4/modeling.py diff --git a/examples/llama4/scout-lora.yaml b/examples/llama-4/scout-lora.yaml similarity index 100% rename from examples/llama4/scout-lora.yaml rename to examples/llama-4/scout-lora.yaml diff --git a/examples/llama-4/scout-qlora-fsdp1.yaml b/examples/llama-4/scout-qlora-fsdp1.yaml new file mode 100644 index 000000000..ad2e46786 --- /dev/null +++ b/examples/llama-4/scout-qlora-fsdp1.yaml @@ -0,0 +1,93 @@ +base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16 +model_type: Llama4ForConditionalGeneration + # Automatically upload checkpoint and final model to HF + # hub_model_id: username/custom_model_name + +strict: false + +# torch_compile: true +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_glu_activation: true +liger_rms_norm: true +liger_layer_norm: true + +llama4_linearized_experts: true +load_in_4bit: true +adapter: qlora +lora_r: 32 +lora_alpha: 64 +lora_target_modules: + - self_attn.q_proj + - self_attn.k_proj + - self_attn.v_proj + - self_attn.o_proj + - shared_expert.gate_proj + - shared_expert.up_proj + - shared_expert.down_proj + # - experts.gate_projs.[0-9]+$ + # - experts.up_projs.[0-9]+$ + # - experts.down_projs.[0-9]+$ +lora_modules_to_save: + - lm_head + - embed_tokens + +chat_template: llama4 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:20%] + field_messages: conversations + message_property_mappings: + role: from + content: value + +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: true +tf32: true + +logging_steps: 1 +flash_attention: true + +warmup_steps: 100 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +fsdp: + - auto_wrap + - full_shard +fsdp_config: + fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: true + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD + fsdp_activation_checkpointing: true +special_tokens: + pad_token: <|finetune_right_pad_id|> + eos_token: <|eot|> diff --git a/requirements-dev.txt b/requirements-dev.txt index 9f523de54..1dce5df5f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,3 +4,5 @@ mypy types-requests quartodoc jupyter +blobfile +tiktoken diff --git a/src/axolotl/monkeypatch/accelerate/__init__.py b/src/axolotl/monkeypatch/accelerate/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py new file mode 100644 index 000000000..2a5d2151d --- /dev/null +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -0,0 +1,63 @@ +""" +monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation +""" + +import logging +import sys + +import torch + +LOG = logging.getLogger(__name__) + + +def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict): + """ + Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the + parameters from rank 0 to all other ranks. This function modifies the model in-place. + + Args: + accelerator (`Accelerator`): The accelerator instance + model (`torch.nn.Module`): The model to load the state dict into + full_sd (`dict`): The full state dict to load, can only be on rank 0 + """ + import torch.distributed as dist + from torch.distributed.tensor import distribute_tensor + + LOG.info("Broadcasting full state dict to all ranks...") + sharded_sd = model.state_dict() + param_names = sorted(sharded_sd.keys()) + for param_name in param_names: + mesh = sharded_sd[param_name].device_mesh + if accelerator.is_main_process: + # Use the corresponding tensor from full_sd (assuming the key exists in full_sd) + full_param = full_sd[param_name].detach().cuda() + dist.broadcast(full_param, src=0, group=mesh.get_group()) + sharded_tensor = distribute_tensor( + full_param, mesh, sharded_sd[param_name].placements + ) + sharded_sd[param_name] = sharded_tensor + else: + # Prepare a tensor of matching shape and dtype + full_tensor = torch.empty( + sharded_sd[param_name].size(), + device="cuda", + dtype=sharded_sd[param_name].dtype, + ) + dist.broadcast(full_tensor, src=0, group=mesh.get_group()) + sharded_tensor = distribute_tensor( + full_tensor, mesh, sharded_sd[param_name].placements + ) + sharded_sd[param_name] = sharded_tensor + + model.load_state_dict(sharded_sd) + + +def patch_accelerate_fsdp_utils(): + from accelerate.utils import fsdp_utils + + fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict + setattr( + sys.modules["accelerate.utils.fsdp_utils"], + "fsdp2_load_full_state_dict", + fsdp2_load_full_state_dict, + ) diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 96cfb1b69..0036fe003 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -4,7 +4,7 @@ import importlib import inspect import logging import types -from typing import Type +from typing import Generator, Tuple, Type import torch from accelerate.logging import get_logger @@ -200,6 +200,46 @@ def patch_self_attn_lora(cfg: DictDefault): ) +def find_self_attn_in_layer( + layer: nn.Module, +) -> Generator[Tuple[nn.Module], None, None]: + # general case of most models + if hasattr(layer, "self_attn"): + if all( + hasattr(layer.self_attn, proj) + for proj in ["q_proj", "k_proj", "v_proj", "o_proj"] + ): + yield layer.self_attn + + +def find_mlp_in_layer( + layer: nn.Module, +) -> Generator[Tuple[nn.Module, nn.Module, nn.Module, nn.Module], None, None]: + # general case of most models + if hasattr(layer, "mlp"): + if all( + hasattr(layer.mlp, proj) for proj in ["gate_proj", "up_proj", "down_proj"] + ): + yield layer.mlp.gate_proj, layer.mlp.up_proj, layer.mlp.down_proj, layer.mlp + # llama4 linearized experts + if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "shared_expert"): + mlp = layer.feedforward.shared_expert + yield mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp + if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "experts"): + if all( + hasattr(layer.feedforward.experts, proj) + for proj in ["gate_projs", "up_projs", "down_projs"] + ): + for gate_proj, up_proj, down_proj in zip( + layer.feedforward.experts.gate_projs, + layer.feedforward.experts.up_projs, + layer.feedforward.experts.down_projs, + ): + yield gate_proj, up_proj, down_proj, FakeMLP( + gate_proj, up_proj, down_proj + ) + + def apply_lora_kernel_patches( model: PeftModelForCausalLM, cfg: DictDefault ) -> PeftModelForCausalLM: @@ -286,74 +326,82 @@ def apply_lora_kernel_patches( for layer in layers: # Add QKV, O fallback implementations to start # These will be overwritten later (if some conditions apply) - layer.self_attn.apply_qkv = types.MethodType( - original_apply_qkv, layer.self_attn - ) - layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn) + for self_attn in find_self_attn_in_layer(layer): + self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn) + self_attn.apply_o = types.MethodType(original_apply_o, self_attn) - if cfg.lora_mlp_kernel: - # MLP patching - gate_proj = layer.mlp.gate_proj - up_proj = layer.mlp.up_proj - down_proj = layer.mlp.down_proj + if cfg.lora_qkv_kernel: + # Query, key, value patching + layer_modules = [ + getattr(self_attn, linear_proj) + for linear_proj in ["q_proj", "k_proj", "v_proj"] + ] + can_patch_qkv = all( + hasattr(module, "lora_A") + and getattr(module, "base_layer", module).bias is None + and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 + for module in layer_modules + ) - can_patch_mlp = all( - hasattr(proj, "lora_A") - and getattr(proj, "base_layer", proj).bias is None - and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0 - for proj in (gate_proj, up_proj, down_proj) - ) + if can_patch_qkv: + # Add optimized implementation + self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn) + else: + LOG.warning_once( + "Cannot patch some attention QKV projections - requires LoRA adapters with no bias" + ) + if cfg.lora_o_kernel: + # Output patching + layer_modules = [ + getattr(self_attn, linear_proj) for linear_proj in ["o_proj"] + ] + can_patch_o = all( + hasattr(module, "lora_A") + and getattr(module, "base_layer", module).bias is None + and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 + for module in layer_modules + ) - if can_patch_mlp: - apply_fn = APPLY_FN_MAPPING[activation] - layer.mlp.forward = types.MethodType(apply_fn, layer.mlp) - else: - LOG.warning_once( - "Cannot patch some MLP layers - requires LoRA adapters with no bias" + if can_patch_o: + self_attn.apply_o = types.MethodType(apply_lora_o, self_attn) + else: + LOG.warning_once( + "Cannot patch some attention output projection - requires LoRA adapters with no bias" + ) + for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer): + if cfg.lora_mlp_kernel: + # MLP patching + can_patch_mlp = all( + hasattr(proj, "lora_A") + and getattr(proj, "base_layer", proj).bias is None + and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0 + for proj in (gate_proj, up_proj, down_proj) ) - if cfg.lora_qkv_kernel: - # Query, key, value patching - layer_modules = [ - getattr(layer.self_attn, linear_proj) - for linear_proj in ["q_proj", "k_proj", "v_proj"] - ] - can_patch_qkv = all( - hasattr(module, "lora_A") - and getattr(module, "base_layer", module).bias is None - and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 - for module in layer_modules - ) - if can_patch_qkv: - # Add optimized implementation - layer.self_attn.apply_qkv = types.MethodType( - apply_lora_qkv, layer.self_attn - ) - else: - LOG.warning_once( - "Cannot patch some attention QKV projections - requires LoRA adapters with no bias" - ) - if cfg.lora_o_kernel: - # Output patching - layer_modules = [ - getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"] - ] - can_patch_o = all( - hasattr(module, "lora_A") - and getattr(module, "base_layer", module).bias is None - and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 - for module in layer_modules - ) - - if can_patch_o: - layer.self_attn.apply_o = types.MethodType( - apply_lora_o, layer.self_attn - ) - else: - LOG.warning_once( - "Cannot patch some attention output projection - requires LoRA adapters with no bias" - ) + if can_patch_mlp: + apply_fn = APPLY_FN_MAPPING[activation] + layer.mlp.forward = types.MethodType(apply_fn, mlp) + else: + LOG.warning_once( + "Cannot patch some MLP layers - requires LoRA adapters with no bias" + ) LOG.setLevel(original_level) return model + + +class FakeMLP(nn.Module): + """ + placeholder MLP for triton patching + """ + + gate_proj: nn.Linear + up_proj: nn.Linear + down_proj: nn.Linear + + def __init__(self, gate_proj, up_proj, down_proj): + super().__init__() + self.gate_proj = gate_proj + self.up_proj = up_proj + self.down_proj = down_proj diff --git a/src/axolotl/monkeypatch/models/llama4/__init__.py b/src/axolotl/monkeypatch/models/llama4/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/llama4/modeling.py b/src/axolotl/monkeypatch/models/llama4/modeling.py new file mode 100644 index 000000000..b2a46ab86 --- /dev/null +++ b/src/axolotl/monkeypatch/models/llama4/modeling.py @@ -0,0 +1,101 @@ +""" +Modified Llama-4 text experts modeling for linearized experts for improved LoRA support +""" + +import sys + +import torch +from torch import nn +from transformers import Llama4Config +from transformers.activations import ACT2FN + + +class Llama4TextExperts(nn.Module): + """ + Modified Llama-4 text experts modeling for linearized experts + """ + + def __init__(self, config: Llama4Config): + super().__init__() + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + + # Replace fused gate_up_proj with separate Linear modules + self.gate_projs = nn.ModuleList( + [ + nn.Linear(self.hidden_size, self.expert_dim, bias=False) + for _ in range(self.num_experts) + ] + ) + + self.up_projs = nn.ModuleList( + [ + nn.Linear(self.hidden_size, self.expert_dim, bias=False) + for _ in range(self.num_experts) + ] + ) + + # Replace down_proj Parameter with Linear modules + self.down_projs = nn.ModuleList( + [ + nn.Linear(self.expert_dim, self.hidden_size, bias=False) + for _ in range(self.num_experts) + ] + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward method using separate Linear layers for each expert. + + Args: + hidden_states (torch.Tensor): (num_experts * batch_size, hidden_size) + The input should be organized by expert + + Returns: + torch.Tensor: (num_experts * batch_size, hidden_size) + """ + # Reshape to separate by expert + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + # batch_size_per_expert = hidden_states.size(1) + + # Initialize output tensor + next_states = torch.zeros_like(hidden_states) + + # Process each expert separately + for i in range(self.num_experts): + # Get input for this expert + expert_input = hidden_states[ + i + ] # Shape: (batch_size_per_expert, hidden_size) + + # Apply gate and up projections + gate = self.gate_projs[i]( + expert_input + ) # Shape: (batch_size_per_expert, expert_dim) + up = self.up_projs[i]( + expert_input + ) # Shape: (batch_size_per_expert, expert_dim) + + # Apply activation and down projection + next_states[i] = self.down_projs[i](up * self.act_fn(gate)) + + # Flatten back to original shape + return next_states.view(-1, self.hidden_size) + + +def patch_llama4_linearized_modeling(): + """ + Patch Llama4TextExperts to use separate Linear layers for each expert. + """ + from transformers.models.llama4 import modeling_llama4 + + modeling_llama4.Llama4TextExperts = Llama4TextExperts + setattr( + sys.modules["transformers.models.llama4"], + "Llama4TextExperts", + Llama4TextExperts, + ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 024673b8e..f808f4bdd 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -544,8 +544,20 @@ class ModelLoader: self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name def apply_patches(self) -> None: + if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2": + from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils + + patch_accelerate_fsdp_utils() # patch gemma3 conditional generation forward before loading plugins # as it could be overridden by plugins + if self.cfg.model_config_type == "llama4": + if self.cfg.llama4_linearized_experts: + from axolotl.monkeypatch.models.llama4.modeling import ( + patch_llama4_linearized_modeling, + ) + + patch_llama4_linearized_modeling() + if self.cfg.model_config_type == "gemma3": from axolotl.monkeypatch.gemma3 import ( patch_gemma3conditionalgeneration_forward, diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 4083fcc22..882c9a248 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -245,6 +245,8 @@ class AxolotlInputConfig( lora_qkv_kernel: bool | None = None lora_o_kernel: bool | None = None + llama4_linearized_experts: bool | None = None + deepspeed: str | dict[str, Any] | None = None fsdp: list[str] | None = None fsdp_config: dict[str, Any] | None = None