Llama4 linearized (#2502)
* llama4 support for linearized experts * clean up fsdp2 sharding to prevent hang * add yaml config * cleanup example [skip ci]
This commit is contained in:
93
examples/llama-4/scout-qlora-fsdp1.yaml
Normal file
93
examples/llama-4/scout-qlora-fsdp1.yaml
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
|
||||||
|
model_type: Llama4ForConditionalGeneration
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# torch_compile: true
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
|
|
||||||
|
liger_glu_activation: true
|
||||||
|
liger_rms_norm: true
|
||||||
|
liger_layer_norm: true
|
||||||
|
|
||||||
|
llama4_linearized_experts: true
|
||||||
|
load_in_4bit: true
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 64
|
||||||
|
lora_target_modules:
|
||||||
|
- self_attn.q_proj
|
||||||
|
- self_attn.k_proj
|
||||||
|
- self_attn.v_proj
|
||||||
|
- self_attn.o_proj
|
||||||
|
- shared_expert.gate_proj
|
||||||
|
- shared_expert.up_proj
|
||||||
|
- shared_expert.down_proj
|
||||||
|
# - experts.gate_projs.[0-9]+$
|
||||||
|
# - experts.up_projs.[0-9]+$
|
||||||
|
# - experts.down_projs.[0-9]+$
|
||||||
|
lora_modules_to_save:
|
||||||
|
- lm_head
|
||||||
|
- embed_tokens
|
||||||
|
|
||||||
|
chat_template: llama4
|
||||||
|
datasets:
|
||||||
|
- path: mlabonne/FineTome-100k
|
||||||
|
type: chat_template
|
||||||
|
split: train[:20%]
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 2e-5
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 100
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- auto_wrap
|
||||||
|
- full_shard
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: true
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
fsdp_activation_checkpointing: true
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|finetune_right_pad_id|>
|
||||||
|
eos_token: <|eot|>
|
||||||
@@ -4,3 +4,5 @@ mypy
|
|||||||
types-requests
|
types-requests
|
||||||
quartodoc
|
quartodoc
|
||||||
jupyter
|
jupyter
|
||||||
|
blobfile
|
||||||
|
tiktoken
|
||||||
|
|||||||
0
src/axolotl/monkeypatch/accelerate/__init__.py
Normal file
0
src/axolotl/monkeypatch/accelerate/__init__.py
Normal file
63
src/axolotl/monkeypatch/accelerate/fsdp2.py
Normal file
63
src/axolotl/monkeypatch/accelerate/fsdp2.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
|
||||||
|
"""
|
||||||
|
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
|
||||||
|
parameters from rank 0 to all other ranks. This function modifies the model in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
accelerator (`Accelerator`): The accelerator instance
|
||||||
|
model (`torch.nn.Module`): The model to load the state dict into
|
||||||
|
full_sd (`dict`): The full state dict to load, can only be on rank 0
|
||||||
|
"""
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed.tensor import distribute_tensor
|
||||||
|
|
||||||
|
LOG.info("Broadcasting full state dict to all ranks...")
|
||||||
|
sharded_sd = model.state_dict()
|
||||||
|
param_names = sorted(sharded_sd.keys())
|
||||||
|
for param_name in param_names:
|
||||||
|
mesh = sharded_sd[param_name].device_mesh
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
# Use the corresponding tensor from full_sd (assuming the key exists in full_sd)
|
||||||
|
full_param = full_sd[param_name].detach().cuda()
|
||||||
|
dist.broadcast(full_param, src=0, group=mesh.get_group())
|
||||||
|
sharded_tensor = distribute_tensor(
|
||||||
|
full_param, mesh, sharded_sd[param_name].placements
|
||||||
|
)
|
||||||
|
sharded_sd[param_name] = sharded_tensor
|
||||||
|
else:
|
||||||
|
# Prepare a tensor of matching shape and dtype
|
||||||
|
full_tensor = torch.empty(
|
||||||
|
sharded_sd[param_name].size(),
|
||||||
|
device="cuda",
|
||||||
|
dtype=sharded_sd[param_name].dtype,
|
||||||
|
)
|
||||||
|
dist.broadcast(full_tensor, src=0, group=mesh.get_group())
|
||||||
|
sharded_tensor = distribute_tensor(
|
||||||
|
full_tensor, mesh, sharded_sd[param_name].placements
|
||||||
|
)
|
||||||
|
sharded_sd[param_name] = sharded_tensor
|
||||||
|
|
||||||
|
model.load_state_dict(sharded_sd)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_accelerate_fsdp_utils():
|
||||||
|
from accelerate.utils import fsdp_utils
|
||||||
|
|
||||||
|
fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
|
||||||
|
setattr(
|
||||||
|
sys.modules["accelerate.utils.fsdp_utils"],
|
||||||
|
"fsdp2_load_full_state_dict",
|
||||||
|
fsdp2_load_full_state_dict,
|
||||||
|
)
|
||||||
@@ -4,7 +4,7 @@ import importlib
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import types
|
import types
|
||||||
from typing import Type
|
from typing import Generator, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
@@ -200,6 +200,46 @@ def patch_self_attn_lora(cfg: DictDefault):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def find_self_attn_in_layer(
|
||||||
|
layer: nn.Module,
|
||||||
|
) -> Generator[Tuple[nn.Module], None, None]:
|
||||||
|
# general case of most models
|
||||||
|
if hasattr(layer, "self_attn"):
|
||||||
|
if all(
|
||||||
|
hasattr(layer.self_attn, proj)
|
||||||
|
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||||
|
):
|
||||||
|
yield layer.self_attn
|
||||||
|
|
||||||
|
|
||||||
|
def find_mlp_in_layer(
|
||||||
|
layer: nn.Module,
|
||||||
|
) -> Generator[Tuple[nn.Module, nn.Module, nn.Module, nn.Module], None, None]:
|
||||||
|
# general case of most models
|
||||||
|
if hasattr(layer, "mlp"):
|
||||||
|
if all(
|
||||||
|
hasattr(layer.mlp, proj) for proj in ["gate_proj", "up_proj", "down_proj"]
|
||||||
|
):
|
||||||
|
yield layer.mlp.gate_proj, layer.mlp.up_proj, layer.mlp.down_proj, layer.mlp
|
||||||
|
# llama4 linearized experts
|
||||||
|
if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "shared_expert"):
|
||||||
|
mlp = layer.feedforward.shared_expert
|
||||||
|
yield mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp
|
||||||
|
if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "experts"):
|
||||||
|
if all(
|
||||||
|
hasattr(layer.feedforward.experts, proj)
|
||||||
|
for proj in ["gate_projs", "up_projs", "down_projs"]
|
||||||
|
):
|
||||||
|
for gate_proj, up_proj, down_proj in zip(
|
||||||
|
layer.feedforward.experts.gate_projs,
|
||||||
|
layer.feedforward.experts.up_projs,
|
||||||
|
layer.feedforward.experts.down_projs,
|
||||||
|
):
|
||||||
|
yield gate_proj, up_proj, down_proj, FakeMLP(
|
||||||
|
gate_proj, up_proj, down_proj
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def apply_lora_kernel_patches(
|
def apply_lora_kernel_patches(
|
||||||
model: PeftModelForCausalLM, cfg: DictDefault
|
model: PeftModelForCausalLM, cfg: DictDefault
|
||||||
) -> PeftModelForCausalLM:
|
) -> PeftModelForCausalLM:
|
||||||
@@ -286,74 +326,82 @@ def apply_lora_kernel_patches(
|
|||||||
for layer in layers:
|
for layer in layers:
|
||||||
# Add QKV, O fallback implementations to start
|
# Add QKV, O fallback implementations to start
|
||||||
# These will be overwritten later (if some conditions apply)
|
# These will be overwritten later (if some conditions apply)
|
||||||
layer.self_attn.apply_qkv = types.MethodType(
|
for self_attn in find_self_attn_in_layer(layer):
|
||||||
original_apply_qkv, layer.self_attn
|
self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
|
||||||
)
|
self_attn.apply_o = types.MethodType(original_apply_o, self_attn)
|
||||||
layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
|
|
||||||
|
|
||||||
if cfg.lora_mlp_kernel:
|
if cfg.lora_qkv_kernel:
|
||||||
# MLP patching
|
# Query, key, value patching
|
||||||
gate_proj = layer.mlp.gate_proj
|
layer_modules = [
|
||||||
up_proj = layer.mlp.up_proj
|
getattr(self_attn, linear_proj)
|
||||||
down_proj = layer.mlp.down_proj
|
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||||
|
]
|
||||||
|
can_patch_qkv = all(
|
||||||
|
hasattr(module, "lora_A")
|
||||||
|
and getattr(module, "base_layer", module).bias is None
|
||||||
|
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||||
|
for module in layer_modules
|
||||||
|
)
|
||||||
|
|
||||||
can_patch_mlp = all(
|
if can_patch_qkv:
|
||||||
hasattr(proj, "lora_A")
|
# Add optimized implementation
|
||||||
and getattr(proj, "base_layer", proj).bias is None
|
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
||||||
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
else:
|
||||||
for proj in (gate_proj, up_proj, down_proj)
|
LOG.warning_once(
|
||||||
)
|
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
|
||||||
|
)
|
||||||
|
if cfg.lora_o_kernel:
|
||||||
|
# Output patching
|
||||||
|
layer_modules = [
|
||||||
|
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||||
|
]
|
||||||
|
can_patch_o = all(
|
||||||
|
hasattr(module, "lora_A")
|
||||||
|
and getattr(module, "base_layer", module).bias is None
|
||||||
|
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||||
|
for module in layer_modules
|
||||||
|
)
|
||||||
|
|
||||||
if can_patch_mlp:
|
if can_patch_o:
|
||||||
apply_fn = APPLY_FN_MAPPING[activation]
|
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
||||||
layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
|
else:
|
||||||
else:
|
LOG.warning_once(
|
||||||
LOG.warning_once(
|
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
|
||||||
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
|
)
|
||||||
|
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
||||||
|
if cfg.lora_mlp_kernel:
|
||||||
|
# MLP patching
|
||||||
|
can_patch_mlp = all(
|
||||||
|
hasattr(proj, "lora_A")
|
||||||
|
and getattr(proj, "base_layer", proj).bias is None
|
||||||
|
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
||||||
|
for proj in (gate_proj, up_proj, down_proj)
|
||||||
)
|
)
|
||||||
if cfg.lora_qkv_kernel:
|
|
||||||
# Query, key, value patching
|
|
||||||
layer_modules = [
|
|
||||||
getattr(layer.self_attn, linear_proj)
|
|
||||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
|
||||||
]
|
|
||||||
can_patch_qkv = all(
|
|
||||||
hasattr(module, "lora_A")
|
|
||||||
and getattr(module, "base_layer", module).bias is None
|
|
||||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
if can_patch_qkv:
|
if can_patch_mlp:
|
||||||
# Add optimized implementation
|
apply_fn = APPLY_FN_MAPPING[activation]
|
||||||
layer.self_attn.apply_qkv = types.MethodType(
|
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
||||||
apply_lora_qkv, layer.self_attn
|
else:
|
||||||
)
|
LOG.warning_once(
|
||||||
else:
|
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
|
||||||
LOG.warning_once(
|
)
|
||||||
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
|
|
||||||
)
|
|
||||||
if cfg.lora_o_kernel:
|
|
||||||
# Output patching
|
|
||||||
layer_modules = [
|
|
||||||
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
|
||||||
]
|
|
||||||
can_patch_o = all(
|
|
||||||
hasattr(module, "lora_A")
|
|
||||||
and getattr(module, "base_layer", module).bias is None
|
|
||||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
if can_patch_o:
|
|
||||||
layer.self_attn.apply_o = types.MethodType(
|
|
||||||
apply_lora_o, layer.self_attn
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
LOG.warning_once(
|
|
||||||
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.setLevel(original_level)
|
LOG.setLevel(original_level)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class FakeMLP(nn.Module):
|
||||||
|
"""
|
||||||
|
placeholder MLP for triton patching
|
||||||
|
"""
|
||||||
|
|
||||||
|
gate_proj: nn.Linear
|
||||||
|
up_proj: nn.Linear
|
||||||
|
down_proj: nn.Linear
|
||||||
|
|
||||||
|
def __init__(self, gate_proj, up_proj, down_proj):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_proj = gate_proj
|
||||||
|
self.up_proj = up_proj
|
||||||
|
self.down_proj = down_proj
|
||||||
|
|||||||
0
src/axolotl/monkeypatch/models/llama4/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/llama4/__init__.py
Normal file
101
src/axolotl/monkeypatch/models/llama4/modeling.py
Normal file
101
src/axolotl/monkeypatch/models/llama4/modeling.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""
|
||||||
|
Modified Llama-4 text experts modeling for linearized experts for improved LoRA support
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import Llama4Config
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
|
||||||
|
|
||||||
|
class Llama4TextExperts(nn.Module):
|
||||||
|
"""
|
||||||
|
Modified Llama-4 text experts modeling for linearized experts
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Llama4Config):
|
||||||
|
super().__init__()
|
||||||
|
self.num_experts = config.num_local_experts
|
||||||
|
self.intermediate_size = config.intermediate_size
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.expert_dim = self.intermediate_size
|
||||||
|
|
||||||
|
# Replace fused gate_up_proj with separate Linear modules
|
||||||
|
self.gate_projs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
nn.Linear(self.hidden_size, self.expert_dim, bias=False)
|
||||||
|
for _ in range(self.num_experts)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.up_projs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
nn.Linear(self.hidden_size, self.expert_dim, bias=False)
|
||||||
|
for _ in range(self.num_experts)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace down_proj Parameter with Linear modules
|
||||||
|
self.down_projs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
nn.Linear(self.expert_dim, self.hidden_size, bias=False)
|
||||||
|
for _ in range(self.num_experts)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward method using separate Linear layers for each expert.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_states (torch.Tensor): (num_experts * batch_size, hidden_size)
|
||||||
|
The input should be organized by expert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: (num_experts * batch_size, hidden_size)
|
||||||
|
"""
|
||||||
|
# Reshape to separate by expert
|
||||||
|
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
|
||||||
|
# batch_size_per_expert = hidden_states.size(1)
|
||||||
|
|
||||||
|
# Initialize output tensor
|
||||||
|
next_states = torch.zeros_like(hidden_states)
|
||||||
|
|
||||||
|
# Process each expert separately
|
||||||
|
for i in range(self.num_experts):
|
||||||
|
# Get input for this expert
|
||||||
|
expert_input = hidden_states[
|
||||||
|
i
|
||||||
|
] # Shape: (batch_size_per_expert, hidden_size)
|
||||||
|
|
||||||
|
# Apply gate and up projections
|
||||||
|
gate = self.gate_projs[i](
|
||||||
|
expert_input
|
||||||
|
) # Shape: (batch_size_per_expert, expert_dim)
|
||||||
|
up = self.up_projs[i](
|
||||||
|
expert_input
|
||||||
|
) # Shape: (batch_size_per_expert, expert_dim)
|
||||||
|
|
||||||
|
# Apply activation and down projection
|
||||||
|
next_states[i] = self.down_projs[i](up * self.act_fn(gate))
|
||||||
|
|
||||||
|
# Flatten back to original shape
|
||||||
|
return next_states.view(-1, self.hidden_size)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_llama4_linearized_modeling():
|
||||||
|
"""
|
||||||
|
Patch Llama4TextExperts to use separate Linear layers for each expert.
|
||||||
|
"""
|
||||||
|
from transformers.models.llama4 import modeling_llama4
|
||||||
|
|
||||||
|
modeling_llama4.Llama4TextExperts = Llama4TextExperts
|
||||||
|
setattr(
|
||||||
|
sys.modules["transformers.models.llama4"],
|
||||||
|
"Llama4TextExperts",
|
||||||
|
Llama4TextExperts,
|
||||||
|
)
|
||||||
@@ -544,8 +544,20 @@ class ModelLoader:
|
|||||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
|
||||||
def apply_patches(self) -> None:
|
def apply_patches(self) -> None:
|
||||||
|
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
||||||
|
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
|
||||||
|
|
||||||
|
patch_accelerate_fsdp_utils()
|
||||||
# patch gemma3 conditional generation forward before loading plugins
|
# patch gemma3 conditional generation forward before loading plugins
|
||||||
# as it could be overridden by plugins
|
# as it could be overridden by plugins
|
||||||
|
if self.cfg.model_config_type == "llama4":
|
||||||
|
if self.cfg.llama4_linearized_experts:
|
||||||
|
from axolotl.monkeypatch.models.llama4.modeling import (
|
||||||
|
patch_llama4_linearized_modeling,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_llama4_linearized_modeling()
|
||||||
|
|
||||||
if self.cfg.model_config_type == "gemma3":
|
if self.cfg.model_config_type == "gemma3":
|
||||||
from axolotl.monkeypatch.gemma3 import (
|
from axolotl.monkeypatch.gemma3 import (
|
||||||
patch_gemma3conditionalgeneration_forward,
|
patch_gemma3conditionalgeneration_forward,
|
||||||
|
|||||||
@@ -245,6 +245,8 @@ class AxolotlInputConfig(
|
|||||||
lora_qkv_kernel: bool | None = None
|
lora_qkv_kernel: bool | None = None
|
||||||
lora_o_kernel: bool | None = None
|
lora_o_kernel: bool | None = None
|
||||||
|
|
||||||
|
llama4_linearized_experts: bool | None = None
|
||||||
|
|
||||||
deepspeed: str | dict[str, Any] | None = None
|
deepspeed: str | dict[str, Any] | None = None
|
||||||
fsdp: list[str] | None = None
|
fsdp: list[str] | None = None
|
||||||
fsdp_config: dict[str, Any] | None = None
|
fsdp_config: dict[str, Any] | None = None
|
||||||
|
|||||||
Reference in New Issue
Block a user