diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 350b04cca..df12b3c89 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -29,7 +29,7 @@ jobs:
             cuda_version: 12.4.1
             python_version: "3.11"
             pytorch: 2.6.0
-            axolotl_extras:
+            axolotl_extras: vllm
             is_latest: true
     runs-on: axolotl-gpu-runner
     steps:
diff --git a/examples/llama-4/README.md b/examples/llama-4/README.md
index a0ec1c70e..b33f8ae3c 100644
--- a/examples/llama-4/README.md
+++ b/examples/llama-4/README.md
@@ -1,16 +1,28 @@
 # Llama 4 by Meta AI
 
+## Flash Attention vs Flex Attention
+
+While Flash Attention support is "enabled" for Llama-4, the upstream implementation is not correct, so we recommend using Flex Attention instead.
+
 ## Available Examples
 
 ### Llama 4 Scout 17Bx16Experts (109B)
 
-- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
-- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
-- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)
-Our Single H100 implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-sft/runs/zic56rhd)
+Flex Attention
+- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100-flex.yaml)
+- [Text Multi GPU QLoRA w/ FSDP2](./scout-qlora-flexattn-fsdp2.yaml)
+
+[//]: # (Flash Attention (Do not use))
+
+[//]: # (- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml))
+
+[//]: # (- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml))
+
+[//]: # (- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml))
+
+Our single-H100 implementation for Llama 4 Scout uses only 64.5GB VRAM for post-training with 4k context length @ 519 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/wpie7dkj)
+Our multi-GPU (4xH100) implementation for Llama 4 Scout uses 62.8GB VRAM/GPU with 4k context length @ 280 tokens/second per GPU. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/2lkezdj8)
 
 ### Llama 4 Maverick 17Bx128Experts (400B)
 
-- [Text Multi GPU QLoRA w/FSDP1](./maverick-qlora-fsdp1.yaml)
-
-Our 4xH100 implementation for Llama 4 Maverick uses 79.5GB VRAM/GPU for post-training with 4k context length @ 206 tokens/second. [WandB logs here.](https://wandb.ai/axolotl-ai/llama-sft/runs/siyvwuxc?nw=nwuserwinglian)
+Coming Soon
diff --git a/examples/llama-4/maverick-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
similarity index 100%
rename from examples/llama-4/maverick-qlora-fsdp1.yaml
rename to examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
diff --git a/examples/llama-4/scout-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
similarity index 100%
rename from examples/llama-4/scout-qlora-fsdp1.yaml
rename to examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
diff --git a/examples/llama-4/scout-qlora-single-h100.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
similarity index 100%
rename from examples/llama-4/scout-qlora-single-h100.yaml
rename to examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
diff --git a/examples/llama-4/scout-vision-qlora-fsdp.yaml b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
similarity index 100%
rename from examples/llama-4/scout-vision-qlora-fsdp.yaml
rename to examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
diff --git a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
new file mode 100644
index 000000000..9a411883e
--- /dev/null
+++ b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
@@ -0,0 +1,86 @@
+base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
+model_type: Llama4ForConditionalGeneration
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+plugins:
+  - axolotl.integrations.liger.LigerPlugin
+
+liger_glu_activation: true
+liger_rms_norm: true
+liger_layer_norm: true
+
+llama4_linearized_experts: true
+load_in_4bit: true
+adapter: qlora
+lora_r: 32
+lora_alpha: 64
+lora_target_modules:
+  - self_attn.q_proj
+  - self_attn.k_proj
+  - self_attn.v_proj
+  - self_attn.o_proj
+  - shared_expert.gate_proj
+  - shared_expert.up_proj
+  - shared_expert.down_proj
+  # - experts.gate_projs.[0-9]+$
+  # - experts.up_projs.[0-9]+$
+  # - experts.down_projs.[0-9]+$
+lora_modules_to_save:
+  # - lm_head
+  # - embed_tokens
+
+chat_template: llama4
+datasets:
+  - path: mlabonne/FineTome-100k
+    type: chat_template
+    split: train[:20%]
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+gradient_accumulation_steps: 1
+micro_batch_size: 2
+num_epochs: 3
+optimizer: adamw_torch_4bit
+lr_scheduler: cosine
+learning_rate: 1e-4
+
+bf16: true
+tf32: true
+
+logging_steps: 1
+flex_attention: true
+flex_attn_compile_kwargs:
+  dynamic: false
+  mode: max-autotune-no-cudagraphs
+
+warmup_steps: 10
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
+fsdp:
+  - auto_wrap
+  - full_shard
+fsdp_config:
+  fsdp_version: 2
+  fsdp_offload_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_reshard_after_forward: true
+  fsdp_activation_checkpointing: true
+special_tokens:
+  pad_token: <|finetune_right_pad_id|>
+  eos_token: <|eot|>
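The `flex_attn_compile_kwargs` mapping above is handed straight to `torch.compile` when the patched flex attention wrapper (later in this diff) compiles the kernel. A minimal sketch of the resulting call, assuming the PyTorch 2.6 flex attention API:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Roughly what the two YAML keys become inside the patched wrapper:
# torch.compile(flex_attention, **flex_attn_compile_kwargs)
compiled_flex_attention = torch.compile(
    flex_attention,
    dynamic=False,  # fixed shapes; recompile if shapes change
    mode="max-autotune-no-cudagraphs",  # workaround, see pytorch/pytorch#146260
)
```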
diff --git a/examples/llama-4/scout-qlora-single-h100-flex.yaml b/examples/llama-4/scout-qlora-single-h100-flex.yaml
new file mode 100644
index 000000000..c7a3b28d0
--- /dev/null
+++ b/examples/llama-4/scout-qlora-single-h100-flex.yaml
@@ -0,0 +1,85 @@
+base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
+model_type: Llama4ForConditionalGeneration
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+plugins:
+  - axolotl.integrations.liger.LigerPlugin
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+liger_glu_activation: true
+liger_rms_norm: true
+liger_layer_norm: true
+cut_cross_entropy: true
+
+llama4_linearized_experts: true  # needed with custom linearized experts model
+load_in_4bit: true
+adapter: qlora
+lora_r: 32
+lora_alpha: 64
+lora_target_modules:
+  - self_attn.q_proj
+  - self_attn.k_proj
+  - self_attn.v_proj
+  - self_attn.o_proj
+  - shared_expert.gate_proj
+  - shared_expert.up_proj
+  - shared_expert.down_proj
+  # - experts.gate_projs.[0-9]+$  # optionally train the moe experts
+  # - experts.up_projs.[0-9]+$
+  # - experts.down_projs.[0-9]+$
+lora_modules_to_save:
+  # - lm_head  # needed if modifying vocabulary
+  # - embed_tokens
+
+lora_mlp_kernel: true
+lora_qkv_kernel: true
+lora_o_kernel: true
+
+chat_template: llama4
+datasets:
+  - path: mlabonne/FineTome-100k
+    type: chat_template
+    split: train[:20%]
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+sequence_len: 4096  # up to 8k will work on a single H100
+sample_packing: true
+pad_to_sequence_len: true
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch_4bit
+lr_scheduler: cosine
+learning_rate: 1e-4
+
+bf16: true
+tf32: true
+
+torch_compile: true
+flex_attention: true
+flex_attn_compile_kwargs:
+  dynamic: false
+  mode: max-autotune-no-cudagraphs
+
+gradient_checkpointing: offload
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+
+logging_steps: 1
+warmup_steps: 20
+evals_per_epoch: 1
+saves_per_epoch: 1
+
+weight_decay: 0.0
+special_tokens:
+  pad_token: <|finetune_right_pad_id|>
+  eos_token: <|eot|>
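FineTome-100k stores ShareGPT-style turns under `conversations` with `from`/`value` keys, so the `message_property_mappings` block above remaps them onto the `role`/`content` fields the llama4 chat template consumes. A hedged sketch of the equivalent transform on a made-up sample:

```python
# Hypothetical illustration of message_property_mappings {role: from, content: value}:
sample = {
    "conversations": [
        {"from": "human", "value": "Hello!"},
        {"from": "gpt", "value": "Hi, how can I help?"},
    ]
}

# rename each turn's keys so the chat template sees role/content
messages = [
    {"role": turn["from"], "content": turn["value"]}
    for turn in sample["conversations"]
]
```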
diff --git a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
new file mode 100644
index 000000000..9fbd34107
--- /dev/null
+++ b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
@@ -0,0 +1,89 @@
+base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
+model_type: Llama4ForConditionalGeneration
+processor_type: Llama4Processor
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+# these 3 lines are needed for now to handle vision chat templates w/ images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+sequence_len: 4096
+
+plugins:
+  - axolotl.integrations.liger.LigerPlugin
+
+liger_glu_activation: true
+liger_rms_norm: true
+liger_layer_norm: true
+
+llama4_linearized_experts: true  # use Axolotl's customized model
+load_in_4bit: true
+adapter: qlora
+lora_r: 32
+lora_alpha: 64
+lora_target_modules:
+  - self_attn.q_proj
+  - self_attn.k_proj
+  - self_attn.v_proj
+  - self_attn.o_proj
+  - shared_expert.gate_proj
+  - shared_expert.up_proj
+  - shared_expert.down_proj
+  - vision_adapter.mlp.fc1
+  - vision_adapter.mlp.fc2
+  # - experts.gate_projs.[0-9]+$
+  # - experts.up_projs.[0-9]+$
+  # - experts.down_projs.[0-9]+$
+lora_modules_to_save:
+  - lm_head
+  - embed_tokens
+
+chat_template: llama4
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+    field_messages: messages
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch_4bit
+lr_scheduler: cosine
+learning_rate: 1e-4
+
+bf16: true
+tf32: true
+
+logging_steps: 1
+flex_attention: true
+flex_attn_compile_kwargs:
+  dynamic: false
+  mode: max-autotune-no-cudagraphs
+
+warmup_steps: 10
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
+fsdp:
+  - auto_wrap
+  - full_shard
+fsdp_config:
+  fsdp_version: 2
+  fsdp_offload_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_reshard_after_forward: true
+  fsdp_activation_checkpointing: true
+special_tokens:
+  pad_token: <|finetune_right_pad_id|>
+  eos_token: <|eot|>
diff --git a/requirements.txt b/requirements.txt
index e8377b880..76e8be8c1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@
 liger-kernel==0.5.6
 packaging==23.2
 peft==0.15.1
-transformers==4.51.1
+transformers==4.51.3
 tokenizers>=0.21.1
 accelerate==1.6.0
 datasets==3.5.0
diff --git a/setup.py b/setup.py
index 29719b1f3..6c911d8f7 100644
--- a/setup.py
+++ b/setup.py
@@ -67,7 +67,7 @@ def parse_requirements(extras_require_map):
         if (major, minor) >= (2, 6):
             _install_requires.pop(_install_requires.index(xformers_version))
             _install_requires.append("xformers==0.0.29.post2")
-            extras_require_map["vllm"] = ["vllm==0.8.1"]
+            extras_require_map["vllm"] = ["vllm==0.8.3"]
         elif (major, minor) >= (2, 5):
             _install_requires.pop(_install_requires.index(xformers_version))
             if patch == 0:
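The CLI tool added next reverses Axolotl's linearized-experts layout: per-expert `nn.Linear` weights are stacked and transposed back into the fused 3D parameters that `Llama4TextExperts` expects, with gate and up fused along the last dimension. A toy sketch of that fusion with hypothetical small dimensions:

```python
import torch

hidden, intermediate, num_experts = 4, 8, 2

# per-expert nn.Linear weights are [out_features, in_features]
gate_w = [torch.randn(intermediate, hidden) for _ in range(num_experts)]
up_w = [torch.randn(intermediate, hidden) for _ in range(num_experts)]

# stack + transpose -> [num_experts, hidden, intermediate], then fuse on last dim
gate_stacked = torch.stack([w.transpose(0, 1) for w in gate_w])
up_stacked = torch.stack([w.transpose(0, 1) for w in up_w])
gate_up_proj = torch.cat((gate_stacked, up_stacked), dim=-1)
assert gate_up_proj.shape == (num_experts, hidden, 2 * intermediate)
```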
diff --git a/src/axolotl/cli/delinearize_llama4.py b/src/axolotl/cli/delinearize_llama4.py
new file mode 100644
index 000000000..c92bae930
--- /dev/null
+++ b/src/axolotl/cli/delinearize_llama4.py
@@ -0,0 +1,156 @@
+"""
+CLI tool to delinearize quantized/linearized Llama-4 models.
+"""
+
+import os
+from pathlib import Path
+from typing import Generator, Union
+
+import fire
+import torch
+from accelerate import init_empty_weights
+from dotenv import load_dotenv
+from transformers import AutoProcessor
+
+
+def iter_convert_patched_to_hf(model_state_dict, num_experts) -> Generator:
+    keys = list(model_state_dict.keys())
+    for key in keys:
+        if ".feed_forward.experts." not in key:
+            yield key, model_state_dict[key]
+        if ".feed_forward.experts.gate_projs" in key:
+            # gate gets fused with up, so skip the yield here; we fuse it when
+            # handling the corresponding up_projs key
+            continue
+        if ".feed_forward.experts.up_projs" in key:
+            if ".feed_forward.experts.up_projs.0." in key:
+                # handle the re-shape and fusing of gate and up, and the
+                # conversion from linear to parameter
+                prefix = key.split(".up_projs.0.")[0]
+                key = f"{prefix}.gate_up_proj"
+                # grab all the up_projs and gate_projs across all experts
+                gate_stacked = torch.stack(
+                    [
+                        model_state_dict[
+                            f"{prefix}.gate_projs.{expert_idx}.weight"
+                        ].transpose(0, 1)
+                        for expert_idx in range(num_experts)
+                    ]
+                )
+                up_stacked = torch.stack(
+                    [
+                        model_state_dict[
+                            f"{prefix}.up_projs.{expert_idx}.weight"
+                        ].transpose(0, 1)
+                        for expert_idx in range(num_experts)
+                    ]
+                )
+                gate_up_proj = torch.cat((gate_stacked, up_stacked), dim=-1)
+                del gate_stacked, up_stacked
+                yield key, gate_up_proj
+            else:
+                del model_state_dict[key]
+                continue
+        if ".feed_forward.experts.down_projs" in key:
+            if ".feed_forward.experts.down_projs.0." in key:
+                # handle the re-shape of the down projections and the
+                # conversion from linear to parameter
+                prefix = key.split(".down_projs.0.")[0]
+                key = f"{prefix}.down_proj"
+                # grab all the down_projs across all experts
+                down_stacked = torch.stack(
+                    [
+                        model_state_dict[
+                            f"{prefix}.down_projs.{expert_idx}.weight"
+                        ].transpose(0, 1)
+                        for expert_idx in range(num_experts)
+                    ]
+                )
+                yield key, down_stacked
+            else:
+                del model_state_dict[key]
+                continue
+
+
+def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
+    """
+    Convert a patched HF format Llama4 model (with separated projections)
+    back to the original HF format (with fused projections).
+
+    Args:
+        model: Path to the patched HF model
+        output: Path to save the converted model
+    """
+    print(f"Loading model from {model}")
+    from axolotl.monkeypatch.models.llama4.modeling import (
+        patch_llama4_linearized_modeling,
+    )
+
+    unpatch_llama4 = patch_llama4_linearized_modeling()
+    from transformers import Llama4ForConditionalGeneration
+
+    model_ = Llama4ForConditionalGeneration.from_pretrained(
+        model, torch_dtype=torch.bfloat16
+    )
+    processor = AutoProcessor.from_pretrained(model)
+    processor.save_pretrained(output)
+
+    device = model_.device.type
+    if device == "cuda":
+        print(
+            f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
+        )
+        print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
+    model_config = model_.config
+    config = model_.config.get_text_config()
+
+    # Get key dimensions from the config
+    hidden_size = config.hidden_size
+    intermediate_size = config.intermediate_size
+    num_experts = config.num_local_experts
+
+    print(
+        f"Model dimensions: hidden_size={hidden_size}, intermediate_size={intermediate_size}, num_experts={num_experts}"
+    )
+
+    # Create output directory if it doesn't exist
+    os.makedirs(output, exist_ok=True)
+
+    # Get state dict
+    state_dict = model_.state_dict()
+    del model_
+
+    # Create a new state dict for the converted model
+    converted_state_dict = {}
+
+    # Copy all keys, fusing the per-expert projections where needed
+    for key, value in iter_convert_patched_to_hf(state_dict, num_experts):
+        converted_state_dict[key] = value
+
+    del state_dict
+    if device == "cuda":
+        torch.cuda.empty_cache()
+    print("State dict converted.")
+    if device == "cuda":
+        print(
+            f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
+        )
+        print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
+
+    # Re-instantiate the (unpatched) model on the meta device and load the
+    # converted state dict into it
+    with init_empty_weights():
+        unpatch_llama4()
+        model_ = Llama4ForConditionalGeneration(model_config)
+
+    model_.load_state_dict(converted_state_dict, strict=False, assign=True)
+    print("State dict loaded into model.")
+    if device == "cuda":
+        print(
+            f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
+        )
+        print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
+
+    # Save the converted model
+    print(f"Saving converted model to {output}...")
+    model_.save_pretrained(output)
+
+    print(f"Model successfully converted and saved to {output}")
+
+
+if __name__ == "__main__":
+    load_dotenv()
+    fire.Fire(do_cli)
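For reference, a hypothetical programmatic invocation of the converter above (both paths are placeholders); the same entry point is registered as an `axolotl` CLI subcommand in the next hunk:

```python
# Hypothetical programmatic use of the converter; paths are placeholders.
from axolotl.cli.delinearize_llama4 import do_cli

do_cli(
    model="./outputs/out/merged",       # linearized checkpoint to convert
    output="./outputs/delinearized",    # standard fused-format destination
)
```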
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index 7532a9689..593614733 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -330,6 +330,15 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
     do_vllm_serve(config, cli_args)
 
 
+@cli.command()
+@click.argument("model", type=click.Path(exists=True, path_type=str))
+@click.argument("output", type=click.Path(exists=False, path_type=str))
+def delinearize_llama4(model: str, output: str) -> None:
+    from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
+
+    do_delinearize_llama4(model, output)
+
+
 cli.add_command(lm_eval)
diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
index c7a3a3225..5c8802dd1 100644
--- a/src/axolotl/cli/merge_lora.py
+++ b/src/axolotl/cli/merge_lora.py
@@ -40,6 +40,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
             LOG.warning("Error raised: %s", e)
 
     model.generation_config.do_sample = True
+    model.config.use_cache = True
 
     if cfg.local_rank == 0:
         LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
index f08663f99..7204f5c90 100644
--- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
+++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
@@ -165,7 +165,7 @@ def cce_forward(
 )
 def cce_forward_multimodal(
     self,
-    input_ids: torch.LongTensor | None = None,
+    input_ids: torch.LongTensor | None = None,  # type: ignore
     pixel_values: torch.FloatTensor | None = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
@@ -254,7 +254,7 @@ def cce_forward_multimodal(
     )
 
     if inputs_embeds is None:
-        inputs_embeds = self.get_input_embeddings()(input_ids)
+        inputs_embeds = self.get_input_embeddings()(input_ids)  # type: ignore
 
     if pixel_values is not None:
         image_features = self.get_image_features(
@@ -263,13 +263,13 @@ def cce_forward_multimodal(
             vision_feature_select_strategy=vision_feature_select_strategy,
             image_sizes=image_sizes,
         )
-        original_inputs_embeds_shape = inputs_embeds.shape
+        original_inputs_embeds_shape = inputs_embeds.shape  # type: ignore
 
         vision_flat = image_features.view(-1, image_features.size(-1))
         projected_vision_flat = self.multi_modal_projector(vision_flat)
 
         special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-        final_mask = special_image_mask.to(inputs_embeds.device)
+        final_mask = special_image_mask.to(inputs_embeds.device)  # type: ignore
         inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))  # type: ignore
 
         final_mask_1d = final_mask[..., 0].reshape(-1)
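The FSDP2 monkeypatch below now calls `load_state_dict` with `assign=True`. A minimal standalone illustration of why that flag matters when parameters start out without real storage (e.g. meta tensors, or sharded DTensors that must be adopted rather than copied into):

```python
import torch
import torch.nn as nn

# A module built on the "meta" device has parameters with no storage;
# a plain load_state_dict would try to copy into them and fail, while
# assign=True adopts the incoming tensors (preserving their type, e.g. DTensor).
with torch.device("meta"):
    layer = nn.Linear(4, 4)

state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
layer.load_state_dict(state, assign=True)
print(layer.weight.device)  # cpu -- the module adopted the real tensors
```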
diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index 2a5d2151d..d8ec00c69 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -49,7 +49,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
         )
         sharded_sd[param_name] = sharded_tensor
 
-    model.load_state_dict(sharded_sd)
+    model.load_state_dict(sharded_sd, assign=True)
 
 
 def patch_accelerate_fsdp_utils():
diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py
index d65ee706f..3652a30b3 100644
--- a/src/axolotl/monkeypatch/attention/flex_attn.py
+++ b/src/axolotl/monkeypatch/attention/flex_attn.py
@@ -7,12 +7,11 @@
 import torch
 import transformers
 
 
-def patch_flex_wrapper():
+def patch_flex_wrapper(**flex_attn_compile_kwargs):
     # TODO remove this patch when transformers#37285 is merged and in a release
     is_torch_2_6 = torch.__version__.startswith("2.6")
-    is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
 
-    if not (is_torch_2_6 and is_transformers_below_4_51):
+    if not is_torch_2_6:
         return
 
     from torch.nn.attention.flex_attention import flex_attention
@@ -32,17 +31,24 @@ def patch_flex_wrapper():
                 cls._instance = super().__new__(cls)
             return cls._instance
 
+        @classmethod
+        def del_singleton(cls):
+            cls._instance = None
+
         @torch.compiler.disable(recursive=False)
-        def __init__(self):
+        def __init__(self, training):
             """
             Initialize or update the singleton instance.
             """
-            if not self._is_flex_compiled:
+            if not hasattr(self, "training"):
+                self.training = None
+            if not self._is_flex_compiled or training != self.training:
+                # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
+                # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs";
+                # see https://github.com/pytorch/pytorch/issues/146260
+                self.training = training
                 self._compiled_flex_attention = torch.compile(
                     flex_attention,
-                    dynamic=False,
-                    mode="max-autotune-no-cudagraphs",
-                    fullgraph=True,
+                    **flex_attn_compile_kwargs,
                 )
                 self._is_flex_compiled = True
@@ -50,15 +56,22 @@ def patch_flex_wrapper():
             return self._compiled_flex_attention
 
     transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention
+    setattr(
+        sys.modules["transformers.integrations.flex_attention"],
+        "WrappedFlexAttention",
+        WrappedFlexAttention,
+    )
 
 
 def patch_flex_make_mask():
     is_torch_2_6 = torch.__version__.startswith("2.6")
-    is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
 
-    if not (is_torch_2_6 and is_transformers_eq_4_51):
+    if not is_torch_2_6:
         return
 
+    from torch.nn.attention.flex_attention import (
+        _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
+    )
     from torch.nn.attention.flex_attention import (
         BlockMask,
     )
@@ -104,14 +117,16 @@ def patch_flex_make_mask():
         if not query_length:
             query_length = total_seq_len
         attention_mask_2d = torch.nn.functional.pad(
-            attention_mask_2d, value=0, pad=(0, key_length)
+            attention_mask_2d,
+            value=0,
+            pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
        )
         device = attention_mask_2d.device
         document_ids = attention_mask_2d.clone()
 
         if attention_chunk_size is not None:
             # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
-            document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (
+            chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
                 attention_chunk_size
             )
@@ -138,6 +153,18 @@ def patch_flex_make_mask():
             final_mask = causal_mask & padding_mask & document_mask
             return final_mask
 
+        def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+            """
+            Combines the chunk mask with the causal mask for chunked attention.
+            """
+            chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
+            causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
+            return chunk_mask & causal_doc_mask
+
+        mask_mod_maybe_combined = (
+            causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
+        )
+
         if offsets is not None:
             q_offset = offsets[0]
             kv_offset = offsets[1]
@@ -145,10 +172,10 @@ def patch_flex_make_mask():
             def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                 offset_q = q_idx + q_offset
                 offset_kv = kv_idx + kv_offset
-                return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
+                return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
 
         else:
-            mask_mod = causal_mask_mod
+            mask_mod = mask_mod_maybe_combined
         return create_block_causal_mask_flex(
             mask_mod=mask_mod,
             B=batch_size,
@@ -160,11 +187,16 @@ def patch_flex_make_mask():
         )
 
     for n in tuple(sys.modules):
-        if ".modeling_" in n and "llama4" not in n:
+        if ".modeling_" in n:
             if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
                 sys.modules[n].make_flex_block_causal_mask = (
                     patched_make_flex_block_causal_mask
                 )
+                setattr(
+                    sys.modules[n],
+                    "make_flex_block_causal_mask",
+                    patched_make_flex_block_causal_mask,
+                )
 
     transformers.integrations.flex_attention.make_flex_block_causal_mask = (
         patched_make_flex_block_causal_mask
diff --git a/src/axolotl/monkeypatch/models/llama4/modeling.py b/src/axolotl/monkeypatch/models/llama4/modeling.py
index b2a46ab86..4127793e7 100644
--- a/src/axolotl/monkeypatch/models/llama4/modeling.py
+++ b/src/axolotl/monkeypatch/models/llama4/modeling.py
@@ -93,9 +93,20 @@ def patch_llama4_linearized_modeling():
     """
     from transformers.models.llama4 import modeling_llama4
 
+    old_llama4_text_experts = modeling_llama4.Llama4TextExperts
     modeling_llama4.Llama4TextExperts = Llama4TextExperts
     setattr(
         sys.modules["transformers.models.llama4"],
         "Llama4TextExperts",
         Llama4TextExperts,
     )
+
+    def unpatch():
+        modeling_llama4.Llama4TextExperts = old_llama4_text_experts
+        setattr(
+            sys.modules["transformers.models.llama4"],
+            "Llama4TextExperts",
+            old_llama4_text_experts,
+        )
+
+    return unpatch
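The patched mask construction composes causal, document (packing), and chunk constraints into a single `mask_mod`. A small self-contained sketch of the same composition using `create_block_mask` directly (toy sizes, CPU device; this is an illustration, not the patched code path itself):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Toy packed batch: two documents of length 4 in one sequence of 8,
# with a hypothetical chunked-attention size of 4.
document_ids = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
chunk_idxs = (torch.ones(8, dtype=torch.long).cumsum(-1) - 1) // 4

def mask_mod(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    same_doc = document_ids[q_idx] == document_ids[kv_idx]
    same_chunk = chunk_idxs[q_idx] == chunk_idxs[kv_idx]
    return causal & same_doc & same_chunk

block_mask = create_block_mask(
    mask_mod, B=None, H=None, Q_LEN=8, KV_LEN=8, device="cpu"
)
print(block_mask)
```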
diff --git a/src/axolotl/monkeypatch/trainer_eval_guard.py b/src/axolotl/monkeypatch/trainer_eval_guard.py
new file mode 100644
index 000000000..e929ac766
--- /dev/null
+++ b/src/axolotl/monkeypatch/trainer_eval_guard.py
@@ -0,0 +1,78 @@
+"""
+fix for FSDP2 evals when using torch.compile
+"""
+
+import inspect
+import logging
+
+from transformers import Trainer
+
+from axolotl.monkeypatch.utils import detab_code
+
+LOG = logging.getLogger(__name__)
+
+ORIGINAL_TRAINER_CODE = """
+    model.eval()
+"""
+
+PATCHED_TRAINER_CODE = """
+    if hasattr(model, "eval") and callable(model.eval):
+        self.model.eval()
+"""
+
+
+def get_evaluation_loop_code() -> str:
+    training_loop = inspect.getsource(Trainer.evaluation_loop)
+    return training_loop
+
+
+def check_evaluation_loop_is_patchable() -> bool:
+    eval_loop = get_evaluation_loop_code()
+    eval_loop, _ = detab_code(eval_loop)
+    return ORIGINAL_TRAINER_CODE in eval_loop
+
+
+def patch_evaluation_loop_for_fsdp2():
+    """
+    monkeypatch for fixing the eval loop for fsdp2 with torch.compile
+    """
+
+    try:
+        evaluation_loop = get_evaluation_loop_code()
+    except OSError:
+        return
+    Trainer._original_evaluation_loop = (  # pylint: disable=protected-access
+        evaluation_loop
+    )
+    evaluation_loop, _ = detab_code(evaluation_loop)
+    if ORIGINAL_TRAINER_CODE not in evaluation_loop:
+        return
+
+    evaluation_loop = evaluation_loop.replace(
+        ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
+    )
+    evaluation_loop = evaluation_loop.replace(
+        "def evaluation_loop(",
+        "def _fixed_evaluation_loop(",
+        1,
+    )
+
+    # load imports necessary
+    import transformers.trainer
+
+    items_to_import = []
+    for item in dir(transformers.trainer):
+        if item in evaluation_loop:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from transformers.trainer import ("
+        + ", ".join(x for x in items_to_import)
+        + ")",
+        globals(),
+    )
+    exec(evaluation_loop, globals())  # pylint: disable=exec-used  # nosec B102
+    LOG.info("patching evaluation_loop for fsdp2 with torch.compile")
+    Trainer.evaluation_loop = (  # pylint: disable=protected-access
+        _fixed_evaluation_loop  # pylint: disable=undefined-variable  # noqa: F821
+    )
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index c2bddeeec..e003c8b67 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -81,6 +81,11 @@ def setup_model_and_tokenizer(
     # Apply freezing if specified
     if cfg.unfrozen_parameters:
         freeze_layers_except(model, cfg.unfrozen_parameters)
+        if any(
+            any(embed in param for embed in ["lm_head", "embed_tokens"])
+            for param in cfg.unfrozen_parameters
+        ):
+            model.enable_input_require_grads()
 
     return model, tokenizer, peft_config, processor
diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py
index 7199eaa36..65ca62137 100644
--- a/src/axolotl/utils/freeze.py
+++ b/src/axolotl/utils/freeze.py
@@ -2,13 +2,14 @@
 module to freeze/unfreeze parameters by name
 """
 
-import logging
 import re
 from typing import Callable, List, Tuple, Union
 
+from accelerate.logging import get_logger
+
 from axolotl.utils.distributed import is_main_process
 
-LOG = logging.getLogger("axolotl.utils.freeze")
+LOG = get_logger(__name__)
 
 
 def freeze_layers_except(model, regex_patterns):
@@ -184,7 +185,7 @@ class LayerNamePattern:
         """
         self.raw_pattern = pattern
         name_pattern, self.range = self._parse_pattern(pattern)
-        self.name_regex = re.compile(name_pattern.replace(".", "\\."))
+        self.name_regex = re.compile(re.sub(r"\.(?!\+)", "\\.", name_pattern))
 
     def match(self, name: str) -> bool:
         """
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c5e569f13..4d4366994 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -542,6 +542,17 @@ class ModelLoader:
             from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
 
             patch_accelerate_fsdp_utils()
+
+        if self.cfg.flex_attention:
+            from axolotl.monkeypatch.attention.flex_attn import (
+                patch_flex_make_mask,
+                patch_flex_wrapper,
+            )
+
+            flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
+            patch_flex_wrapper(**flex_attn_compile_kwargs)
+            patch_flex_make_mask()
+
         # patch gemma3 conditional generation forward before loading plugins
         # as it could be overridden by plugins
         if self.cfg.model_config_type == "llama4":
@@ -905,13 +916,6 @@ class ModelLoader:
             self.model_config._attn_implementation = (  # pylint: disable=protected-access
                 "flex_attention"
            )
-            from axolotl.monkeypatch.attention.flex_attn import (
-                patch_flex_make_mask,
-                patch_flex_wrapper,
-            )
-
-            patch_flex_wrapper()
-            patch_flex_make_mask()
 
         elif self.cfg.flash_attention:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
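The `freeze.py` change swaps naive dot-escaping for a negative lookahead, so that `.+` wildcards in freeze patterns survive escaping while literal dots are still escaped. A quick demonstration with a hypothetical pattern:

```python
import re

pattern = "model.layers.[0-9]+.self_attn..+"

# old behavior escaped every dot, turning the ".+" wildcard into "\.+"
# (one or more literal dots), which breaks matching:
old = pattern.replace(".", "\\.")
# new behavior keeps any dot that starts a ".+" wildcard:
new = re.sub(r"\.(?!\+)", "\\.", pattern)

assert not re.compile(old).match("model.layers.0.self_attn.q_proj.weight")
assert re.compile(new).match("model.layers.0.self_attn.q_proj.weight")
```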
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 882c9a248..8a76c4eb4 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -225,6 +225,7 @@ class AxolotlInputConfig(
     sdp_attention: bool | None = None
     s2_attention: bool | None = None
     flex_attention: bool | None = None
+    flex_attn_compile_kwargs: dict[str, Any] | None = None
     flash_attention: bool | None = None
     flash_attn_cross_entropy: bool | None = None
     flash_attn_rms_norm: bool | None = None
@@ -1276,11 +1277,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         ):
             capabilities = data.get("capabilities")
             is_fsdp = data.get("fsdp") is not None
-
-            if capabilities and capabilities.get("n_gpu", 0) > 1:
+            is_fsdp2 = (
+                data.get("fsdp_config") is not None
+                and str(data.get("fsdp_config").get("fsdp_version")) == "2"
+            )
+            if capabilities and capabilities.get("n_gpu", 0) > 1 and not is_fsdp2:
                 if is_fsdp:
                     raise ValueError(
-                        "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
+                        "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP1."
                     )
 
         return data
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 964b17086..c1154be68 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -17,6 +17,7 @@
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -235,7 +236,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
 
 
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-    if cfg.model_config_type in ["mamba", "gemma3"]:
+    drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
+    if drop_attn_mask:
         LOG.info("dropping attention_mask column")
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
@@ -625,6 +627,12 @@ def setup_trainer(
         A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based on
         the provided parameters.
     """
+    if (
+        cfg.torch_compile
+        and cfg.fsdp_config
+        and str(cfg.fsdp_config.fsdp_version) == "2"
+    ):
+        patch_evaluation_loop_for_fsdp2()
     if cfg.rl:
         trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
         trainer_builder.model_ref = model_ref
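`patch_evaluation_loop_for_fsdp2` (defined earlier in this diff and wired up here) relies on source-level rewriting: fetch the method's source with `inspect`, replace the target snippet, and `exec` the rewritten definition. A miniature of the technique with toy names (run as a script, since `inspect.getsource` needs file-backed code):

```python
import inspect

def loop():
    return "original"

# fetch the source, rewrite the target snippet, exec the new definition --
# the same trick patch_evaluation_loop_for_fsdp2 applies to Trainer.evaluation_loop
src = inspect.getsource(loop)
src = src.replace('"original"', '"patched"').replace("def loop", "def patched_loop", 1)
exec(src, globals())  # defines patched_loop in this namespace

print(patched_loop())  # -> patched
```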
diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py
index 3af6d5a76..cbe3794b3 100644
--- a/tests/e2e/multigpu/solo/test_flex.py
+++ b/tests/e2e/multigpu/solo/test_flex.py
@@ -56,11 +56,12 @@ class TestPackedFlex:
                 "num_epochs": 1,
                 "micro_batch_size": 2,
                 "gradient_accumulation_steps": 2,
+                "gradient_checkpointing": True,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch_fused",
                 "lr_scheduler": "cosine",
-                "max_steps": 5,
+                "max_steps": 2,
                 "use_tensorboard": True,
                 "save_strategy": "no",
             }
diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py
index bd999e2f3..f4914ed1a 100644
--- a/tests/e2e/multigpu/solo/test_grpo.py
+++ b/tests/e2e/multigpu/solo/test_grpo.py
@@ -177,6 +177,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
         "NCCL_P2P_LEVEL": "LOC",
         **current_env,
         "CUDA_VISIBLE_DEVICES": "1",
+        "VLLM_USE_V1": "0",
     }
     vllm_process_id = start_vllm(
         cfg.base_model,
@@ -264,6 +265,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
         "NCCL_P2P_LEVEL": "LOC",  # nccl can be brittle, assume P2P isn't reliable
         **current_env,
         "CUDA_VISIBLE_DEVICES": "1",
+        "VLLM_USE_V1": "0",
     }
     vllm_process_id = start_vllm(
         cfg.base_model,
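Both GRPO tests now pin `VLLM_USE_V1=0`: vLLM 0.8.x defaults to its V1 engine, and these tests appear to need the legacy V0 engine. A hypothetical standalone equivalent of the environment handed to `start_vllm`:

```python
import os

# Pin the legacy V0 engine (vLLM 0.8.x defaults to V1) and a single GPU;
# NCCL P2P can be brittle in CI, so it is restricted as well.
env = {
    "NCCL_P2P_LEVEL": "LOC",
    **os.environ.copy(),
    "CUDA_VISIBLE_DEVICES": "1",
    "VLLM_USE_V1": "0",
}
```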