From a6c03217f57b90efd2374500bc88a19fbeeb655f Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Apr 2025 04:12:28 +0700 Subject: [PATCH 1/6] feat: add llama4 CCE (#2498) * feat: add llama4 CCE * fix: update model support list doc * feat: include llama4_text --- .../integrations/cut_cross_entropy/README.md | 3 + .../cut_cross_entropy/monkeypatch/llama4.py | 414 ++++++++++++++++++ .../cut_cross_entropy/monkeypatch/patch.py | 15 +- 3 files changed, 431 insertions(+), 1 deletion(-) create mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 7b428eb58..29b91bc8d 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -32,6 +32,9 @@ cut_cross_entropy: true ## Supported Models - llama +- llama4_text +- llama4 +- mllama - phi3 - gemma - gemma2 diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py new file mode 100644 index 000000000..f08663f99 --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py @@ -0,0 +1,414 @@ +"""Llama4 CCE patch. Adapted from transformers 4.51.0.""" + +# pylint: disable=duplicate-code + +from types import MethodType +from typing import Optional, Tuple, Union + +import torch +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, + apply_lce, +) +from torch import nn +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama4.modeling_llama4 import ( + _CONFIG_FOR_DOC, + LLAMA4_INPUTS_DOCSTRING, + Llama4CausalLMOutputWithPast, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) + +_PATCH_OPTS: PatchOptions | None = None + + +@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def cce_forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + defer_logits_calculation: bool = False, + **kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + defer_logits_calculation (`bool`, *optional*, defaults to `False`): + If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the + memory overhead of calculating logits using regular lm_head forward pass and to use CCE. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Llama4ForCausalLM + + >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + loss = None + logits = None + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + loss = apply_lce( + hidden_states[:, slice_indices, :], + self.lm_head.weight, + labels, + _PATCH_OPTS, + **kwargs, + ) + elif _PATCH_OPTS is not None and defer_logits_calculation: + # defer logits calculation to the ConditionalGeneration forward + logits = hidden_states[:, slice_indices, :] + else: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@replace_return_docstrings( + output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def cce_forward_multimodal( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + 
attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: torch.Tensor | None = None, + **lm_kwargs, +) -> Union[Tuple, Llama4CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, LlavaForConditionalGeneration + + >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") + >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + + >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "USER: \nWhat's the content of the image? 
ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + vision_feature_layer = ( + vision_feature_layer + if vision_feature_layer is not None + else self.config.vision_config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + image_sizes=image_sizes, + ) + original_inputs_embeds_shape = inputs_embeds.shape + + vision_flat = image_features.view(-1, image_features.size(-1)) + projected_vision_flat = self.multi_modal_projector(vision_flat) + + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + final_mask = special_image_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore + + final_mask_1d = final_mask[..., 0].reshape(-1) + num_tokens_to_fill = final_mask_1d.sum() + + if num_tokens_to_fill != projected_vision_flat.size(0): + raise ValueError( + f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, " + f"but multi_modal_projector returned {projected_vision_flat.size(0)}" + ) + + expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1)) + inputs_embeds = inputs_embeds.masked_scatter( + expanded_mask, projected_vision_flat + ) # type: ignore + inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) # type: ignore + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + defer_logits_calculation=True, # enable deferred logits calculation + **lm_kwargs, + ) + + hidden_states = outputs[0] + loss = None + logits = None + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + # TODO: check if need to handle attention_mask + loss = apply_lce( + hidden_states, + self.language_model.lm_head.weight, + labels, + _PATCH_OPTS, + **lm_kwargs, + ) + else: + logits = hidden_states + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to( + logits.device + ) + shift_logits = logits[..., :-1, :][ + shift_attention_mask.to(logits.device) != 0 + ].contiguous() + shift_labels = labels[..., 1:][ + shift_attention_mask.to(labels.device) != 0 + ].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1).to(shift_logits.device), + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Llama4CausalLMOutputWithPast( + loss=loss, + logits=logits, # type: ignore # TODO: check if need to create dummy logits + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +def patch_llama4_text( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.llama4 import modeling_llama4 + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_llama4.Llama4ForCausalLM + ), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}." + maybe_model.forward = MethodType(cce_forward, maybe_model) + + return maybe_model + + setattr( + modeling_llama4.Llama4ForCausalLM, + "forward", + cce_forward, + ) + return None + + +def patch_llama4( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.llama4 import modeling_llama4 + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_llama4.Llama4ForConditionalGeneration + ), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}." 
+ maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) + + # patch the language model + maybe_model.language_model.forward = MethodType( + cce_forward, maybe_model.language_model + ) + return maybe_model + + setattr( + modeling_llama4.Llama4ForConditionalGeneration, + "forward", + cce_forward_multimodal, + ) + + # patch the causal language model + setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward) + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py index b9c83ff02..5263956ce 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py @@ -20,6 +20,10 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import ( patch_gemma3, patch_gemma3_text, ) +from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import ( + patch_llama4, + patch_llama4_text, +) from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import ( patch_mistral, patch_mistral3, @@ -28,6 +32,8 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mlla CUT_CROSS_ENTROPY_MODEL_MAPPING = { "llama": patch_llama, + "llama4": patch_llama4, + "llama4_text": patch_llama4_text, "mllama": patch_mllama, "phi3": patch_phi3, "gemma": patch_gemma, @@ -60,7 +66,14 @@ def cce_patch( raise ValueError(f"Unknown {impl=}") if isinstance(model_type_or_model, transformers.PreTrainedModel): - model_type = model_type_or_model.config.model_type + if hasattr(model_type_or_model, "config"): + model_type = getattr( + getattr(model_type_or_model, "config", None), "model_type", None + ) + else: + raise ValueError( + "model_type_or_model is a PreTrainedModel but does not have a config attribute" + ) elif isinstance(model_type_or_model, transformers.PretrainedConfig): model_type = model_type_or_model.model_type else: From 0dac2ddeacf06dac8d4fbadcdde3d02d6ef5e2b0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 7 Apr 2025 20:47:00 -0400 Subject: [PATCH 2/6] Llama4 linearized (#2502) * llama4 support for linearized experts * clean up fsdp2 sharding to prevent hang * add yaml config * cleanup example [skip ci] --- examples/{llama4 => llama-4}/scout-lora.yaml | 0 examples/llama-4/scout-qlora-fsdp1.yaml | 93 ++++++++++ requirements-dev.txt | 2 + .../monkeypatch/accelerate/__init__.py | 0 src/axolotl/monkeypatch/accelerate/fsdp2.py | 63 +++++++ src/axolotl/monkeypatch/lora_kernels.py | 174 +++++++++++------- .../monkeypatch/models/llama4/__init__.py | 0 .../monkeypatch/models/llama4/modeling.py | 101 ++++++++++ src/axolotl/utils/models.py | 12 ++ src/axolotl/utils/schemas/config.py | 2 + 10 files changed, 384 insertions(+), 63 deletions(-) rename examples/{llama4 => llama-4}/scout-lora.yaml (100%) create mode 100644 examples/llama-4/scout-qlora-fsdp1.yaml create mode 100644 src/axolotl/monkeypatch/accelerate/__init__.py create mode 100644 src/axolotl/monkeypatch/accelerate/fsdp2.py create mode 100644 src/axolotl/monkeypatch/models/llama4/__init__.py create mode 100644 src/axolotl/monkeypatch/models/llama4/modeling.py diff --git a/examples/llama4/scout-lora.yaml b/examples/llama-4/scout-lora.yaml similarity index 100% rename from examples/llama4/scout-lora.yaml rename to examples/llama-4/scout-lora.yaml diff --git a/examples/llama-4/scout-qlora-fsdp1.yaml b/examples/llama-4/scout-qlora-fsdp1.yaml new file mode 100644 index 000000000..ad2e46786 --- /dev/null +++ 
b/examples/llama-4/scout-qlora-fsdp1.yaml @@ -0,0 +1,93 @@ +base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16 +model_type: Llama4ForConditionalGeneration + # Automatically upload checkpoint and final model to HF + # hub_model_id: username/custom_model_name + +strict: false + +# torch_compile: true +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_glu_activation: true +liger_rms_norm: true +liger_layer_norm: true + +llama4_linearized_experts: true +load_in_4bit: true +adapter: qlora +lora_r: 32 +lora_alpha: 64 +lora_target_modules: + - self_attn.q_proj + - self_attn.k_proj + - self_attn.v_proj + - self_attn.o_proj + - shared_expert.gate_proj + - shared_expert.up_proj + - shared_expert.down_proj + # - experts.gate_projs.[0-9]+$ + # - experts.up_projs.[0-9]+$ + # - experts.down_projs.[0-9]+$ +lora_modules_to_save: + - lm_head + - embed_tokens + +chat_template: llama4 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:20%] + field_messages: conversations + message_property_mappings: + role: from + content: value + +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: true +tf32: true + +logging_steps: 1 +flash_attention: true + +warmup_steps: 100 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +fsdp: + - auto_wrap + - full_shard +fsdp_config: + fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: true + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD + fsdp_activation_checkpointing: true +special_tokens: + pad_token: <|finetune_right_pad_id|> + eos_token: <|eot|> diff --git a/requirements-dev.txt b/requirements-dev.txt index 9f523de54..1dce5df5f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,3 +4,5 @@ mypy types-requests quartodoc jupyter +blobfile +tiktoken diff --git a/src/axolotl/monkeypatch/accelerate/__init__.py b/src/axolotl/monkeypatch/accelerate/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py new file mode 100644 index 000000000..2a5d2151d --- /dev/null +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -0,0 +1,63 @@ +""" +monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation +""" + +import logging +import sys + +import torch + +LOG = logging.getLogger(__name__) + + +def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict): + """ + Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the + parameters from rank 0 to all other ranks. This function modifies the model in-place. 
+ + Args: + accelerator (`Accelerator`): The accelerator instance + model (`torch.nn.Module`): The model to load the state dict into + full_sd (`dict`): The full state dict to load, can only be on rank 0 + """ + import torch.distributed as dist + from torch.distributed.tensor import distribute_tensor + + LOG.info("Broadcasting full state dict to all ranks...") + sharded_sd = model.state_dict() + param_names = sorted(sharded_sd.keys()) + for param_name in param_names: + mesh = sharded_sd[param_name].device_mesh + if accelerator.is_main_process: + # Use the corresponding tensor from full_sd (assuming the key exists in full_sd) + full_param = full_sd[param_name].detach().cuda() + dist.broadcast(full_param, src=0, group=mesh.get_group()) + sharded_tensor = distribute_tensor( + full_param, mesh, sharded_sd[param_name].placements + ) + sharded_sd[param_name] = sharded_tensor + else: + # Prepare a tensor of matching shape and dtype + full_tensor = torch.empty( + sharded_sd[param_name].size(), + device="cuda", + dtype=sharded_sd[param_name].dtype, + ) + dist.broadcast(full_tensor, src=0, group=mesh.get_group()) + sharded_tensor = distribute_tensor( + full_tensor, mesh, sharded_sd[param_name].placements + ) + sharded_sd[param_name] = sharded_tensor + + model.load_state_dict(sharded_sd) + + +def patch_accelerate_fsdp_utils(): + from accelerate.utils import fsdp_utils + + fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict + setattr( + sys.modules["accelerate.utils.fsdp_utils"], + "fsdp2_load_full_state_dict", + fsdp2_load_full_state_dict, + ) diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 96cfb1b69..0036fe003 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -4,7 +4,7 @@ import importlib import inspect import logging import types -from typing import Type +from typing import Generator, Tuple, Type import torch from accelerate.logging import get_logger @@ -200,6 +200,46 @@ def patch_self_attn_lora(cfg: DictDefault): ) +def find_self_attn_in_layer( + layer: nn.Module, +) -> Generator[Tuple[nn.Module], None, None]: + # general case of most models + if hasattr(layer, "self_attn"): + if all( + hasattr(layer.self_attn, proj) + for proj in ["q_proj", "k_proj", "v_proj", "o_proj"] + ): + yield layer.self_attn + + +def find_mlp_in_layer( + layer: nn.Module, +) -> Generator[Tuple[nn.Module, nn.Module, nn.Module, nn.Module], None, None]: + # general case of most models + if hasattr(layer, "mlp"): + if all( + hasattr(layer.mlp, proj) for proj in ["gate_proj", "up_proj", "down_proj"] + ): + yield layer.mlp.gate_proj, layer.mlp.up_proj, layer.mlp.down_proj, layer.mlp + # llama4 linearized experts + if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "shared_expert"): + mlp = layer.feedforward.shared_expert + yield mlp.gate_proj, mlp.up_proj, mlp.down_proj, mlp + if hasattr(layer, "feedforward") and hasattr(layer.feedforward, "experts"): + if all( + hasattr(layer.feedforward.experts, proj) + for proj in ["gate_projs", "up_projs", "down_projs"] + ): + for gate_proj, up_proj, down_proj in zip( + layer.feedforward.experts.gate_projs, + layer.feedforward.experts.up_projs, + layer.feedforward.experts.down_projs, + ): + yield gate_proj, up_proj, down_proj, FakeMLP( + gate_proj, up_proj, down_proj + ) + + def apply_lora_kernel_patches( model: PeftModelForCausalLM, cfg: DictDefault ) -> PeftModelForCausalLM: @@ -286,74 +326,82 @@ def apply_lora_kernel_patches( for layer in layers: # Add QKV, 
O fallback implementations to start # These will be overwritten later (if some conditions apply) - layer.self_attn.apply_qkv = types.MethodType( - original_apply_qkv, layer.self_attn - ) - layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn) + for self_attn in find_self_attn_in_layer(layer): + self_attn.apply_qkv = types.MethodType(original_apply_qkv, self_attn) + self_attn.apply_o = types.MethodType(original_apply_o, self_attn) - if cfg.lora_mlp_kernel: - # MLP patching - gate_proj = layer.mlp.gate_proj - up_proj = layer.mlp.up_proj - down_proj = layer.mlp.down_proj + if cfg.lora_qkv_kernel: + # Query, key, value patching + layer_modules = [ + getattr(self_attn, linear_proj) + for linear_proj in ["q_proj", "k_proj", "v_proj"] + ] + can_patch_qkv = all( + hasattr(module, "lora_A") + and getattr(module, "base_layer", module).bias is None + and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 + for module in layer_modules + ) - can_patch_mlp = all( - hasattr(proj, "lora_A") - and getattr(proj, "base_layer", proj).bias is None - and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0 - for proj in (gate_proj, up_proj, down_proj) - ) + if can_patch_qkv: + # Add optimized implementation + self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn) + else: + LOG.warning_once( + "Cannot patch some attention QKV projections - requires LoRA adapters with no bias" + ) + if cfg.lora_o_kernel: + # Output patching + layer_modules = [ + getattr(self_attn, linear_proj) for linear_proj in ["o_proj"] + ] + can_patch_o = all( + hasattr(module, "lora_A") + and getattr(module, "base_layer", module).bias is None + and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 + for module in layer_modules + ) - if can_patch_mlp: - apply_fn = APPLY_FN_MAPPING[activation] - layer.mlp.forward = types.MethodType(apply_fn, layer.mlp) - else: - LOG.warning_once( - "Cannot patch some MLP layers - requires LoRA adapters with no bias" + if can_patch_o: + self_attn.apply_o = types.MethodType(apply_lora_o, self_attn) + else: + LOG.warning_once( + "Cannot patch some attention output projection - requires LoRA adapters with no bias" + ) + for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer): + if cfg.lora_mlp_kernel: + # MLP patching + can_patch_mlp = all( + hasattr(proj, "lora_A") + and getattr(proj, "base_layer", proj).bias is None + and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0 + for proj in (gate_proj, up_proj, down_proj) ) - if cfg.lora_qkv_kernel: - # Query, key, value patching - layer_modules = [ - getattr(layer.self_attn, linear_proj) - for linear_proj in ["q_proj", "k_proj", "v_proj"] - ] - can_patch_qkv = all( - hasattr(module, "lora_A") - and getattr(module, "base_layer", module).bias is None - and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 - for module in layer_modules - ) - if can_patch_qkv: - # Add optimized implementation - layer.self_attn.apply_qkv = types.MethodType( - apply_lora_qkv, layer.self_attn - ) - else: - LOG.warning_once( - "Cannot patch some attention QKV projections - requires LoRA adapters with no bias" - ) - if cfg.lora_o_kernel: - # Output patching - layer_modules = [ - getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"] - ] - can_patch_o = all( - hasattr(module, "lora_A") - and getattr(module, "base_layer", module).bias is None - and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 - for module in layer_modules - ) - - if can_patch_o: - layer.self_attn.apply_o = 
types.MethodType( - apply_lora_o, layer.self_attn - ) - else: - LOG.warning_once( - "Cannot patch some attention output projection - requires LoRA adapters with no bias" - ) + if can_patch_mlp: + apply_fn = APPLY_FN_MAPPING[activation] + layer.mlp.forward = types.MethodType(apply_fn, mlp) + else: + LOG.warning_once( + "Cannot patch some MLP layers - requires LoRA adapters with no bias" + ) LOG.setLevel(original_level) return model + + +class FakeMLP(nn.Module): + """ + placeholder MLP for triton patching + """ + + gate_proj: nn.Linear + up_proj: nn.Linear + down_proj: nn.Linear + + def __init__(self, gate_proj, up_proj, down_proj): + super().__init__() + self.gate_proj = gate_proj + self.up_proj = up_proj + self.down_proj = down_proj diff --git a/src/axolotl/monkeypatch/models/llama4/__init__.py b/src/axolotl/monkeypatch/models/llama4/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/llama4/modeling.py b/src/axolotl/monkeypatch/models/llama4/modeling.py new file mode 100644 index 000000000..b2a46ab86 --- /dev/null +++ b/src/axolotl/monkeypatch/models/llama4/modeling.py @@ -0,0 +1,101 @@ +""" +Modified Llama-4 text experts modeling for linearized experts for improved LoRA support +""" + +import sys + +import torch +from torch import nn +from transformers import Llama4Config +from transformers.activations import ACT2FN + + +class Llama4TextExperts(nn.Module): + """ + Modified Llama-4 text experts modeling for linearized experts + """ + + def __init__(self, config: Llama4Config): + super().__init__() + self.num_experts = config.num_local_experts + self.intermediate_size = config.intermediate_size + self.hidden_size = config.hidden_size + self.expert_dim = self.intermediate_size + + # Replace fused gate_up_proj with separate Linear modules + self.gate_projs = nn.ModuleList( + [ + nn.Linear(self.hidden_size, self.expert_dim, bias=False) + for _ in range(self.num_experts) + ] + ) + + self.up_projs = nn.ModuleList( + [ + nn.Linear(self.hidden_size, self.expert_dim, bias=False) + for _ in range(self.num_experts) + ] + ) + + # Replace down_proj Parameter with Linear modules + self.down_projs = nn.ModuleList( + [ + nn.Linear(self.expert_dim, self.hidden_size, bias=False) + for _ in range(self.num_experts) + ] + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward method using separate Linear layers for each expert. 
+ + Args: + hidden_states (torch.Tensor): (num_experts * batch_size, hidden_size) + The input should be organized by expert + + Returns: + torch.Tensor: (num_experts * batch_size, hidden_size) + """ + # Reshape to separate by expert + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + # batch_size_per_expert = hidden_states.size(1) + + # Initialize output tensor + next_states = torch.zeros_like(hidden_states) + + # Process each expert separately + for i in range(self.num_experts): + # Get input for this expert + expert_input = hidden_states[ + i + ] # Shape: (batch_size_per_expert, hidden_size) + + # Apply gate and up projections + gate = self.gate_projs[i]( + expert_input + ) # Shape: (batch_size_per_expert, expert_dim) + up = self.up_projs[i]( + expert_input + ) # Shape: (batch_size_per_expert, expert_dim) + + # Apply activation and down projection + next_states[i] = self.down_projs[i](up * self.act_fn(gate)) + + # Flatten back to original shape + return next_states.view(-1, self.hidden_size) + + +def patch_llama4_linearized_modeling(): + """ + Patch Llama4TextExperts to use separate Linear layers for each expert. + """ + from transformers.models.llama4 import modeling_llama4 + + modeling_llama4.Llama4TextExperts = Llama4TextExperts + setattr( + sys.modules["transformers.models.llama4"], + "Llama4TextExperts", + Llama4TextExperts, + ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 024673b8e..f808f4bdd 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -544,8 +544,20 @@ class ModelLoader: self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name def apply_patches(self) -> None: + if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2": + from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils + + patch_accelerate_fsdp_utils() # patch gemma3 conditional generation forward before loading plugins # as it could be overridden by plugins + if self.cfg.model_config_type == "llama4": + if self.cfg.llama4_linearized_experts: + from axolotl.monkeypatch.models.llama4.modeling import ( + patch_llama4_linearized_modeling, + ) + + patch_llama4_linearized_modeling() + if self.cfg.model_config_type == "gemma3": from axolotl.monkeypatch.gemma3 import ( patch_gemma3conditionalgeneration_forward, diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 4083fcc22..882c9a248 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -245,6 +245,8 @@ class AxolotlInputConfig( lora_qkv_kernel: bool | None = None lora_o_kernel: bool | None = None + llama4_linearized_experts: bool | None = None + deepspeed: str | dict[str, Any] | None = None fsdp: list[str] | None = None fsdp_config: dict[str, Any] | None = None From 04624c5a8d3f4b6c798fb9de8cfacbade0861bdc Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Mon, 7 Apr 2025 15:12:45 -0400 Subject: [PATCH 3/6] bump flex patching transformers to v4.51, update torch compile kwargs to be in line with transformers v4.51 --- src/axolotl/monkeypatch/attention/flex_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py index d65ee706f..c643e2fd2 100644 --- a/src/axolotl/monkeypatch/attention/flex_attn.py +++ b/src/axolotl/monkeypatch/attention/flex_attn.py @@ -10,9 +10,9 @@ import transformers def patch_flex_wrapper(): # TODO remove this patch when 
transformers#37285 is merged and in a release is_torch_2_6 = torch.__version__.startswith("2.6") - is_transformers_below_4_51 = transformers.__version__ < "4.51.0" + is_transformers_below_4_52 = transformers.__version__ < "4.52.0" - if not (is_torch_2_6 and is_transformers_below_4_51): + if not (is_torch_2_6 and is_transformers_below_4_52): return from torch.nn.attention.flex_attention import flex_attention @@ -40,7 +40,7 @@ def patch_flex_wrapper(): if not self._is_flex_compiled: self._compiled_flex_attention = torch.compile( flex_attention, - dynamic=False, + backend="inductor", mode="max-autotune-no-cudagraphs", fullgraph=True, ) From bdaaba2784f002898a1d771497d385a4e214f7a7 Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Mon, 7 Apr 2025 17:05:08 -0400 Subject: [PATCH 4/6] remove backend='inductor' in local patch --- src/axolotl/monkeypatch/attention/flex_attn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py index c643e2fd2..58e0c8e89 100644 --- a/src/axolotl/monkeypatch/attention/flex_attn.py +++ b/src/axolotl/monkeypatch/attention/flex_attn.py @@ -40,7 +40,6 @@ def patch_flex_wrapper(): if not self._is_flex_compiled: self._compiled_flex_attention = torch.compile( flex_attention, - backend="inductor", mode="max-autotune-no-cudagraphs", fullgraph=True, ) From 75c565d476aa381f6627120b60ca7a1860cfc0de Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Mon, 7 Apr 2025 17:06:51 -0400 Subject: [PATCH 5/6] add back dynamic=False --- src/axolotl/monkeypatch/attention/flex_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py index 58e0c8e89..babecfc7a 100644 --- a/src/axolotl/monkeypatch/attention/flex_attn.py +++ b/src/axolotl/monkeypatch/attention/flex_attn.py @@ -40,6 +40,7 @@ def patch_flex_wrapper(): if not self._is_flex_compiled: self._compiled_flex_attention = torch.compile( flex_attention, + dynamic=False, mode="max-autotune-no-cudagraphs", fullgraph=True, ) From cdb16069afad3d02d0727092e37d1683b25db296 Mon Sep 17 00:00:00 2001 From: salman Date: Tue, 8 Apr 2025 11:28:52 +0100 Subject: [PATCH 6/6] fixing transformers version --- src/axolotl/monkeypatch/attention/flex_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py index babecfc7a..18d195f17 100644 --- a/src/axolotl/monkeypatch/attention/flex_attn.py +++ b/src/axolotl/monkeypatch/attention/flex_attn.py @@ -10,9 +10,9 @@ import transformers def patch_flex_wrapper(): # TODO remove this patch when transformers#37285 is merged and in a release is_torch_2_6 = torch.__version__.startswith("2.6") - is_transformers_below_4_52 = transformers.__version__ < "4.52.0" + is_transformers_below_4_51_1 = transformers.__version__ < "4.51.1" - if not (is_torch_2_6 and is_transformers_below_4_52): + if not (is_torch_2_6 and is_transformers_below_4_51_1): return from torch.nn.attention.flex_attention import flex_attention
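
Usage sketch (not part of the diffs above): patch 1/6 adds `llama4` and `llama4_text` to the Cut Cross Entropy support list, and patch 2/6 introduces the `llama4_linearized_experts` flag together with the Scout QLoRA example config. The YAML below is a minimal, assumed combination of the two features. The base model, LoRA settings, and dataset are copied from `examples/llama-4/scout-qlora-fsdp1.yaml`; the `plugins:` class path and the premise that CCE and linearized experts compose in a single run are assumptions to verify against the released integration README, not something this series states.

```yaml
# Sketch only: combines llama4 CCE support (patch 1/6) with the
# linearized-experts flag (patch 2/6). Base model and LoRA settings come from
# examples/llama-4/scout-qlora-fsdp1.yaml; the plugin path below is an
# assumption and should be checked against the cut_cross_entropy README.
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration

plugins:
  # assumed class path for the CCE plugin; confirm before use
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true          # dispatches to patch_llama4 / patch_llama4_text (patch 1/6)

llama4_linearized_experts: true  # swaps fused experts for per-expert nn.Linear modules (patch 2/6)

load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64

datasets:
  - path: mlabonne/FineTome-100k  # placeholder dataset, as in the bundled example
    type: chat_template
    split: train[:20%]
sequence_len: 4096
micro_batch_size: 1
output_dir: ./outputs/out
```

Note that the FSDP2 state-dict broadcast fix added in patch 2/6 only engages when `fsdp_config.fsdp_version` is set to `2` (see the `apply_patches` hunk in `src/axolotl/utils/models.py`); the bundled example config stays on FSDP1.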