Compare commits

..

6 Commits

Author        SHA1        Message  Date
Dan Saunders  954b989e88  log warning re: logged losses / gradient scaling per rank  2025-04-07 18:47:43 +00:00
Dan Saunders  c64c881460  using existing packed seqlens util  2025-04-07 18:47:43 +00:00
Dan Saunders  cefd57cecb  adding smoke test  2025-04-07 18:47:43 +00:00
Dan Saunders  2f3c52ea2f  pre-commit fix  2025-04-07 18:47:43 +00:00
Dan Saunders  741015b3cf  refactor and fix multipack seqlens  2025-04-07 18:47:43 +00:00
Dan Saunders  4188700b7b  working on masking fix  2025-04-07 18:47:43 +00:00
15 changed files with 64 additions and 816 deletions
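
The commit messages above center on the multipack (sample packing) sequence-length and masking fix, including reuse of an existing packed-seqlens utility. As rough orientation only, the sketch below shows one common way to recover flash-attention-style cumulative sequence lengths (cu_seqlens) from the position_ids of a packed row; the function name and the assumption that each packed sub-sequence restarts its position ids at 0 (with no padding inside the row) are illustrative, not the utility this change set wires in.

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """position_ids: (total_tokens,) for one packed row.
    Returns cumulative sequence lengths, e.g. [0, len_1, len_1 + len_2, ...]."""
    # Each packed sub-sequence restarts its position ids at 0, so every 0 marks a start.
    starts = (position_ids == 0).nonzero(as_tuple=True)[0]
    ends = torch.cat(
        [starts[1:], torch.tensor([position_ids.numel()], device=position_ids.device)]
    )
    seqlens = ends - starts
    zero = torch.zeros(1, dtype=seqlens.dtype, device=seqlens.device)
    return torch.cat([zero, seqlens.cumsum(0)])

# Two packed sequences of lengths 3 and 2:
# cu_seqlens_from_position_ids(torch.tensor([0, 1, 2, 0, 1])) -> tensor([0, 3, 5])
```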

View File

@@ -164,7 +164,7 @@ Here is an example of a multi-modal dataset:
{
"role": "user",
"content": [
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "text", "text": "Describe this image in detail."}
]
},

View File

@@ -1,93 +0,0 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
strict: false
# torch_compile: true
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
- lm_head
- embed_tokens
chat_template: llama4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_steps: 100
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- auto_wrap
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -4,5 +4,3 @@ mypy
types-requests
quartodoc
jupyter
blobfile
tiktoken

View File

@@ -32,9 +32,6 @@ cut_cross_entropy: true
## Supported Models
- llama
- llama4_text
- llama4
- mllama
- phi3
- gemma
- gemma2

View File

@@ -1,414 +0,0 @@
"""Llama4 CCE patch. Adapted from transformers 4.51.0."""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama4.modeling_llama4 import (
_CONFIG_FOR_DOC,
LLAMA4_INPUTS_DOCSTRING,
Llama4CausalLMOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
_PATCH_OPTS: PatchOptions | None = None
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
defer_logits_calculation: bool = False,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
defer_logits_calculation (`bool`, *optional*, defaults to `False`):
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Llama4ForCausalLM
>>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
elif _PATCH_OPTS is not None and defer_logits_calculation:
# defer logits calculation to the ConditionalGeneration forward
logits = hidden_states[:, slice_indices, :]
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@replace_return_docstrings(
output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, list[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor | None = None,
**lm_kwargs,
) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration
>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
>>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
vision_feature_layer = (
vision_feature_layer
if vision_feature_layer is not None
else self.config.vision_config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
original_inputs_embeds_shape = inputs_embeds.shape
vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
final_mask = special_image_mask.to(inputs_embeds.device)
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore
final_mask_1d = final_mask[..., 0].reshape(-1)
num_tokens_to_fill = final_mask_1d.sum()
if num_tokens_to_fill != projected_vision_flat.size(0):
raise ValueError(
f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
)
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
inputs_embeds = inputs_embeds.masked_scatter(
expanded_mask, projected_vision_flat
) # type: ignore
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) # type: ignore
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
defer_logits_calculation=True, # enable deferred logits calculation
**lm_kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
# TODO: check if need to handle attention_mask
loss = apply_lce(
hidden_states,
self.language_model.lm_head.weight,
labels,
_PATCH_OPTS,
**lm_kwargs,
)
else:
logits = hidden_states
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
logits.device
)
shift_logits = logits[..., :-1, :][
shift_attention_mask.to(logits.device) != 0
].contiguous()
shift_labels = labels[..., 1:][
shift_attention_mask.to(labels.device) != 0
].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1).to(shift_logits.device),
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Llama4CausalLMOutputWithPast(
loss=loss,
logits=logits, # type: ignore # TODO: check if need to create dummy logits
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
def patch_llama4_text(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.llama4 import modeling_llama4
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_llama4.Llama4ForCausalLM
), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
setattr(
modeling_llama4.Llama4ForCausalLM,
"forward",
cce_forward,
)
return None
def patch_llama4(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.llama4 import modeling_llama4
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_llama4.Llama4ForConditionalGeneration
), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
# patch the language model
maybe_model.language_model.forward = MethodType(
cce_forward, maybe_model.language_model
)
return maybe_model
setattr(
modeling_llama4.Llama4ForConditionalGeneration,
"forward",
cce_forward_multimodal,
)
# patch the causal language model
setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
return None
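
The docstrings in the removed patch describe the defer_logits_calculation path: the text model hands back hidden states so the cut-cross-entropy loss can be computed against lm_head.weight without materializing the full (batch, seq, vocab) logits tensor. The snippet below is only a sketch of that memory argument, chunking the projection along the sequence dimension; it is not the CCE/LCE kernel, and chunked_causal_lm_loss, the chunk size, and the already-shifted-labels assumption are illustrative.

```python
import torch
import torch.nn.functional as F

def chunked_causal_lm_loss(hidden, lm_head_weight, labels, chunk=1024):
    """hidden: (B, T, H); lm_head_weight: (V, H); labels: (B, T), already shifted,
    with -100 marking ignored positions."""
    _, seq_len, _ = hidden.shape
    total_loss = hidden.new_zeros(())
    total_count = 0
    for start in range(0, seq_len, chunk):
        h = hidden[:, start : start + chunk, :]          # (B, c, H)
        y = labels[:, start : start + chunk]             # (B, c)
        logits = h @ lm_head_weight.T                    # (B, c, V) for one chunk only
        total_loss = total_loss + F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            y.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        total_count += int((y != -100).sum())
    return total_loss / max(total_count, 1)
```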

View File

@@ -20,10 +20,6 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma3,
patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
patch_llama4,
patch_llama4_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
patch_mistral,
patch_mistral3,
@@ -32,8 +28,6 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mlla
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"llama": patch_llama,
"llama4": patch_llama4,
"llama4_text": patch_llama4_text,
"mllama": patch_mllama,
"phi3": patch_phi3,
"gemma": patch_gemma,
@@ -66,14 +60,7 @@ def cce_patch(
raise ValueError(f"Unknown {impl=}")
if isinstance(model_type_or_model, transformers.PreTrainedModel):
if hasattr(model_type_or_model, "config"):
model_type = getattr(
getattr(model_type_or_model, "config", None), "model_type", None
)
else:
raise ValueError(
"model_type_or_model is a PreTrainedModel but does not have a config attribute"
)
model_type = model_type_or_model.config.model_type
elif isinstance(model_type_or_model, transformers.PretrainedConfig):
model_type = model_type_or_model.model_type
else:

View File

@@ -29,7 +29,6 @@ liger_fused_linear_cross_entropy: true
- granite
- jamba
- llama
- llama4 (partial support, no support for FLCE yet)
- mistral
- mixtral
- mllama

View File

@@ -1,63 +0,0 @@
"""
monkeypatch for accelerate fsdp2 fix when modifying an OrderedDict during iteration
"""
import logging
import sys
import torch
LOG = logging.getLogger(__name__)
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
"""
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
parameters from rank 0 to all other ranks. This function modifies the model in-place.
Args:
accelerator (`Accelerator`): The accelerator instance
model (`torch.nn.Module`): The model to load the state dict into
full_sd (`dict`): The full state dict to load, can only be on rank 0
"""
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor
LOG.info("Broadcasting full state dict to all ranks...")
sharded_sd = model.state_dict()
param_names = sorted(sharded_sd.keys())
for param_name in param_names:
mesh = sharded_sd[param_name].device_mesh
if accelerator.is_main_process:
# Use the corresponding tensor from full_sd (assuming the key exists in full_sd)
full_param = full_sd[param_name].detach().cuda()
dist.broadcast(full_param, src=0, group=mesh.get_group())
sharded_tensor = distribute_tensor(
full_param, mesh, sharded_sd[param_name].placements
)
sharded_sd[param_name] = sharded_tensor
else:
# Prepare a tensor of matching shape and dtype
full_tensor = torch.empty(
sharded_sd[param_name].size(),
device="cuda",
dtype=sharded_sd[param_name].dtype,
)
dist.broadcast(full_tensor, src=0, group=mesh.get_group())
sharded_tensor = distribute_tensor(
full_tensor, mesh, sharded_sd[param_name].placements
)
sharded_sd[param_name] = sharded_tensor
model.load_state_dict(sharded_sd)
def patch_accelerate_fsdp_utils():
from accelerate.utils import fsdp_utils
fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
setattr(
sys.modules["accelerate.utils.fsdp_utils"],
"fsdp2_load_full_state_dict",
fsdp2_load_full_state_dict,
)

View File

@@ -4,7 +4,7 @@ import importlib
import inspect
import logging
import types
from typing import Generator, Tuple, Type
from typing import Type
import torch
from accelerate.logging import get_logger
@@ -200,46 +200,6 @@ def patch_self_attn_lora(cfg: DictDefault):
)
def find_self_attn_in_layer(
layer: nn.Module,
) -> Generator[Tuple[nn.Module], None, None]:
# general case of most models
if hasattr(layer, "self_attn"):
if all(
hasattr(layer.self_attn, proj)
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
):
yield layer.self_attn
def find_mlp_in_layer(
layer: nn.Module,
) -> Generator[Tuple[nn.Module, nn.Module, nn.Module, nn.Module], None, None]:
# general case of most models
if hasattr(layer, "mlp"):
if all(
hasattr(layer.mlp, proj) for proj in ["gate_proj", "up_proj", "down_proj"]
):
yield layer.mlp.gate_proj, layer.mlp.up_proj, layer.mlp.down_proj, layer.mlp
# llama4 linearized experts
if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "shared_expert"):
mlp = layer.feedforward.shared_expert
yield mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp
if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "experts"):
if all(
hasattr(layer.feedforward.experts, proj)
for proj in ["gate_projs", "up_projs", "down_projs"]
):
for gate_proj, up_proj, down_proj in zip(
layer.feedforward.experts.gate_projs,
layer.feedforward.experts.up_projs,
layer.feedforward.experts.down_projs,
):
yield gate_proj, up_proj, down_proj, FakeMLP(
gate_proj, up_proj, down_proj
)
def apply_lora_kernel_patches(
model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM:
@@ -326,82 +286,74 @@ def apply_lora_kernel_patches(
for layer in layers:
# Add QKV, O fallback implementations to start
# These will be overwritten later (if some conditions apply)
for self_attn in find_self_attn_in_layer(layer):
self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn)
self_attn.apply_o = types.MethodType(original_apply_o, self_attn)
layer.self_attn.apply_qkv = types.MethodType(
original_apply_qkv, layer.self_attn
)
layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
if cfg.lora_qkv_kernel:
# Query, key, value patching
layer_modules = [
getattr(self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
if cfg.lora_mlp_kernel:
# MLP patching
gate_proj = layer.mlp.gate_proj
up_proj = layer.mlp.up_proj
down_proj = layer.mlp.down_proj
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
if can_patch_mlp:
apply_fn = APPLY_FN_MAPPING[activation]
layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
)
if cfg.lora_qkv_kernel:
# Query, key, value patching
layer_modules = [
getattr(layer.self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_qkv:
# Add optimized implementation
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
if can_patch_qkv:
# Add optimized implementation
layer.self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, layer.self_attn
)
if can_patch_o:
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_mlp:
apply_fn = APPLY_FN_MAPPING[activation]
layer.mlp.forward = types.MethodType(apply_fn, mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
)
if can_patch_o:
layer.self_attn.apply_o = types.MethodType(
apply_lora_o, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
LOG.setLevel(original_level)
return model
class FakeMLP(nn.Module):
"""
placeholder MLP for triton patching
"""
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
def __init__(self, gate_proj, up_proj, down_proj):
super().__init__()
self.gate_proj = gate_proj
self.up_proj = up_proj
self.down_proj = down_proj

View File

@@ -1,101 +0,0 @@
"""
Modified Llama-4 text experts modeling with linearized experts for improved LoRA support
"""
import sys
import torch
from torch import nn
from transformers import Llama4Config
from transformers.activations import ACT2FN
class Llama4TextExperts(nn.Module):
"""
Modified Llama-4 text experts modeling for linearized experts
"""
def __init__(self, config: Llama4Config):
super().__init__()
self.num_experts = config.num_local_experts
self.intermediate_size = config.intermediate_size
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
# Replace fused gate_up_proj with separate Linear modules
self.gate_projs = nn.ModuleList(
[
nn.Linear(self.hidden_size, self.expert_dim, bias=False)
for _ in range(self.num_experts)
]
)
self.up_projs = nn.ModuleList(
[
nn.Linear(self.hidden_size, self.expert_dim, bias=False)
for _ in range(self.num_experts)
]
)
# Replace down_proj Parameter with Linear modules
self.down_projs = nn.ModuleList(
[
nn.Linear(self.expert_dim, self.hidden_size, bias=False)
for _ in range(self.num_experts)
]
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Forward method using separate Linear layers for each expert.
Args:
hidden_states (torch.Tensor): (num_experts * batch_size, hidden_size)
The input should be organized by expert
Returns:
torch.Tensor: (num_experts * batch_size, hidden_size)
"""
# Reshape to separate by expert
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
# batch_size_per_expert = hidden_states.size(1)
# Initialize output tensor
next_states = torch.zeros_like(hidden_states)
# Process each expert separately
for i in range(self.num_experts):
# Get input for this expert
expert_input = hidden_states[
i
] # Shape: (batch_size_per_expert, hidden_size)
# Apply gate and up projections
gate = self.gate_projs[i](
expert_input
) # Shape: (batch_size_per_expert, expert_dim)
up = self.up_projs[i](
expert_input
) # Shape: (batch_size_per_expert, expert_dim)
# Apply activation and down projection
next_states[i] = self.down_projs[i](up * self.act_fn(gate))
# Flatten back to original shape
return next_states.view(-1, self.hidden_size)
def patch_llama4_linearized_modeling():
"""
Patch Llama4TextExperts to use separate Linear layers for each expert.
"""
from transformers.models.llama4 import modeling_llama4
modeling_llama4.Llama4TextExperts = Llama4TextExperts
setattr(
sys.modules["transformers.models.llama4"],
"Llama4TextExperts",
Llama4TextExperts,
)
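
The deleted module swaps the stock fused expert parameters for per-expert nn.Linear layers, which is what allows LoRA to target patterns like experts.gate_projs.[0-9]+$ in the (also deleted) example config. The weight migration itself is not shown in the diff; the sketch below assumes the upstream fused layout is gate_up_proj of shape (num_experts, hidden_size, 2 * expert_dim) and down_proj of shape (num_experts, expert_dim, hidden_size) — verify against your transformers version — and copy_fused_experts_to_linear is an illustrative helper, not part of the patch.

```python
import torch

@torch.no_grad()
def copy_fused_experts_to_linear(fused, linear):
    """Copy weights from a fused Llama4TextExperts module into the linearized one."""
    for i in range(linear.num_experts):
        # Split the fused gate/up projection for expert i: (hidden_size, 2 * expert_dim)
        gate_w, up_w = fused.gate_up_proj[i].chunk(2, dim=-1)
        # nn.Linear stores weight as (out_features, in_features), hence the transposes.
        linear.gate_projs[i].weight.copy_(gate_w.T)
        linear.up_projs[i].weight.copy_(up_w.T)
        linear.down_projs[i].weight.copy_(fused.down_proj[i].T)
```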

View File

@@ -544,20 +544,8 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils()
# patch gemma3 conditional generation forward before loading plugins
# as it could be overridden by plugins
if self.cfg.model_config_type == "llama4":
if self.cfg.llama4_linearized_experts:
from axolotl.monkeypatch.models.llama4.modeling import (
patch_llama4_linearized_modeling,
)
patch_llama4_linearized_modeling()
if self.cfg.model_config_type == "gemma3":
from axolotl.monkeypatch.gemma3 import (
patch_gemma3conditionalgeneration_forward,

View File

@@ -245,8 +245,6 @@ class AxolotlInputConfig(
lora_qkv_kernel: bool | None = None
lora_o_kernel: bool | None = None
llama4_linearized_experts: bool | None = None
deepspeed: str | dict[str, Any] | None = None
fsdp: list[str] | None = None
fsdp_config: dict[str, Any] | None = None