Fix: add delinearization and make qlora work with fsdp2 (#2515)
* fixes for delinearization, and make qlora work with fsdp2 * Add back mistakenly removed lm_eval * typo [skip ci] * patch evals for torch.compile + fsdp2 * also check torch_compile w fsdp2 * lots of fixes for flex attn with llama4 * fix patch check and patch llama4 too * attempt to make the patches stick * use transformers 4.51.2 * update configs and README for llama4 * remove torch.compile for CI test * cleanup any existing singletons * set singleton cache to None instead of deleting * use importlib reload with monkeypatch * don't worry about transformers version, mark inputs with grads, fix regex * make sure embeds aren't on cpu * logging and mem improvements * vllm version and add to docker, make sure to save processor on conversion * fix ambiguous tensor bool check * fix vllm to not use v1, upgrade hf transformers * fix tests * make flex_attn_compile_kwargs configurable, since this depends on model params --------- Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
This commit is contained in:
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
|
||||
@@ -1,16 +1,28 @@
|
||||
# Llama 4 by Meta AI
|
||||
|
||||
## Flash Attention vs Flex Attention
|
||||
|
||||
While Flash Attention to support is "enabled" for Llama-4, the upstream implementation is not correct and usage of Flex Attention is recommended.
|
||||
|
||||
## Available Examples
|
||||
|
||||
### Llama 4 Scout 17Bx16Experts (109B)
|
||||
- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
|
||||
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
|
||||
- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)
|
||||
|
||||
Our Single H100 implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-sft/runs/zic56rhd)
|
||||
Flex Attention
|
||||
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100-flex.yaml)
|
||||
- [Text Multi GPU QLoRA w/ FSDP2](./scout-qlora-flexattn-fsdp2.yaml)
|
||||
|
||||
[//]: # (Flash Attention (Do not use))
|
||||
|
||||
[//]: # (- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml))
|
||||
|
||||
[//]: # (- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml))
|
||||
|
||||
[//]: # (- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml))
|
||||
|
||||
Our Single H100 implementation for Llama 4 Scout uses only 64.5GB VRAM for post-training with 4k context length @ 519 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/wpie7dkj)
|
||||
Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k contenxt length @ 280tps/gpu, [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/2lkezdj8)
|
||||
|
||||
### Llama 4 Maverick 17Bx128Experts (400B)
|
||||
|
||||
- [Text Multi GPU QLoRA w/FSDP1](./maverick-qlora-fsdp1.yaml)
|
||||
|
||||
Our 4xH100 implementation for Llama 4 Maverick uses 79.5GB VRAM/GPU for post-training with 4k context length @ 206 tokens/second. [WandB logs here.](https://wandb.ai/axolotl-ai/llama-sft/runs/siyvwuxc?nw=nwuserwinglian)
|
||||
Coming Soon
|
||||
|
||||
86
examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
Normal file
86
examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
Normal file
@@ -0,0 +1,86 @@
|
||||
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
|
||||
model_type: Llama4ForConditionalGeneration
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm: true
|
||||
liger_layer_norm: true
|
||||
|
||||
llama4_linearized_experts: true
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_modules:
|
||||
- self_attn.q_proj
|
||||
- self_attn.k_proj
|
||||
- self_attn.v_proj
|
||||
- self_attn.o_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert.up_proj
|
||||
- shared_expert.down_proj
|
||||
# - experts.gate_projs.[0-9]+$
|
||||
# - experts.up_projs.[0-9]+$
|
||||
# - experts.down_projs.[0-9]+$
|
||||
lora_modules_to_save:
|
||||
# - lm_head
|
||||
# - embed_tokens
|
||||
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 1e-4
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- auto_wrap
|
||||
- full_shard
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_activation_checkpointing: true
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot|>
|
||||
85
examples/llama-4/scout-qlora-single-h100-flex.yaml
Normal file
85
examples/llama-4/scout-qlora-single-h100-flex.yaml
Normal file
@@ -0,0 +1,85 @@
|
||||
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
|
||||
model_type: Llama4ForConditionalGeneration
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm: true
|
||||
liger_layer_norm: true
|
||||
cut_cross_entropy: true
|
||||
|
||||
llama4_linearized_experts: true # needed with custom linearized experts model
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_modules:
|
||||
- self_attn.q_proj
|
||||
- self_attn.k_proj
|
||||
- self_attn.v_proj
|
||||
- self_attn.o_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert.up_proj
|
||||
- shared_expert.down_proj
|
||||
# - experts.gate_projs.[0-9]+$ # optionally train the moe experts
|
||||
# - experts.up_projs.[0-9]+$
|
||||
# - experts.down_projs.[0-9]+$
|
||||
lora_modules_to_save:
|
||||
# - lm_head # needed if modifying vocabulary
|
||||
# - embed_tokens
|
||||
|
||||
lora_mlp_kernel: true
|
||||
lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096 # up to 8k will work on a single H100
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 1e-4
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
torch_compile: true
|
||||
flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
gradient_checkpointing: offload
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
logging_steps: 1
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot|>
|
||||
89
examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
Normal file
89
examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
Normal file
@@ -0,0 +1,89 @@
|
||||
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
|
||||
model_type: Llama4ForConditionalGeneration
|
||||
processor_type: Llama4Processor
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
sequence_len: 4096
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm: true
|
||||
liger_layer_norm: true
|
||||
|
||||
llama4_linearized_experts: true # use Axolotl's customized model
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_modules:
|
||||
- self_attn.q_proj
|
||||
- self_attn.k_proj
|
||||
- self_attn.v_proj
|
||||
- self_attn.o_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert.up_proj
|
||||
- shared_expert.down_proj
|
||||
- vision_adapter.mlp.fc1
|
||||
- vision_adapter.mlp.fc2
|
||||
# - experts.gate_projs.[0-9]+$
|
||||
# - experts.up_projs.[0-9]+$
|
||||
# - experts.down_projs.[0-9]+$
|
||||
lora_modules_to_save:
|
||||
- lm_head
|
||||
- embed_tokens
|
||||
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 1e-4
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- auto_wrap
|
||||
- full_shard
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_activation_checkpointing: true
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot|>
|
||||
@@ -12,7 +12,7 @@ liger-kernel==0.5.6
|
||||
packaging==23.2
|
||||
|
||||
peft==0.15.1
|
||||
transformers==4.51.1
|
||||
transformers==4.51.3
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.6.0
|
||||
datasets==3.5.0
|
||||
|
||||
2
setup.py
2
setup.py
@@ -67,7 +67,7 @@ def parse_requirements(extras_require_map):
|
||||
if (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.29.post2")
|
||||
extras_require_map["vllm"] = ["vllm==0.8.1"]
|
||||
extras_require_map["vllm"] = ["vllm==0.8.3"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
|
||||
156
src/axolotl/cli/delinearize_llama4.py
Normal file
156
src/axolotl/cli/delinearize_llama4.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
CLI tool to delinearize quantized/Linearized Llama-4 models.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Generator, Union
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from dotenv import load_dotenv
|
||||
from transformers import AutoProcessor
|
||||
|
||||
|
||||
def iter_convert_patched_to_hf(model_state_dict, num_experts) -> Generator:
|
||||
keys = list(model_state_dict.keys())
|
||||
for key in keys:
|
||||
if ".feed_forward.experts." not in key:
|
||||
yield key, model_state_dict[key]
|
||||
if ".feed_forward.experts.gate_projs" in key:
|
||||
# gate gets fused with up so skip the yield on this and we'll fuse it when asking for the up
|
||||
continue
|
||||
if ".feed_forward.experts.up_projs" in key:
|
||||
if ".feed_forward.experts.up_projs.0." in key:
|
||||
# handle the re-shape and fusing of gate and up, and conversion from linear to parameter
|
||||
prefix = key.split(".up_projs.0.")[0]
|
||||
key = f"{prefix}.gate_up_proj"
|
||||
# grab all the up_projs and gate_projs across all experts
|
||||
gate_stacked = torch.stack(
|
||||
[
|
||||
model_state_dict[
|
||||
f"{prefix}.gate_projs.{expert_idx}.weight"
|
||||
].transpose(0, 1)
|
||||
for expert_idx in range(num_experts)
|
||||
]
|
||||
)
|
||||
up_stacked = torch.stack(
|
||||
[
|
||||
model_state_dict[
|
||||
f"{prefix}.up_projs.{expert_idx}.weight"
|
||||
].transpose(0, 1)
|
||||
for expert_idx in range(num_experts)
|
||||
]
|
||||
)
|
||||
gate_up_proj = torch.cat((gate_stacked, up_stacked), dim=-1)
|
||||
del gate_stacked, up_stacked
|
||||
yield key, gate_up_proj
|
||||
else:
|
||||
del model_state_dict[key]
|
||||
continue
|
||||
if ".feed_forward.experts.down_projs" in key:
|
||||
if ".feed_forward.experts.down_projs.0." in key:
|
||||
# handle the re-shape and fusing of gate and up, and conversion from linear to parameter
|
||||
prefix = key.split(".down_projs.0.")[0]
|
||||
key = f"{prefix}.down_proj"
|
||||
# grab all the down_projs across all experts
|
||||
down_stacked = torch.stack(
|
||||
[
|
||||
model_state_dict[
|
||||
f"{prefix}.down_projs.{expert_idx}.weight"
|
||||
].transpose(0, 1)
|
||||
for expert_idx in range(num_experts)
|
||||
]
|
||||
)
|
||||
yield key, down_stacked
|
||||
else:
|
||||
del model_state_dict[key]
|
||||
continue
|
||||
|
||||
|
||||
def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
|
||||
"""
|
||||
Convert a patched HF format Llama4 model (with separated projections)
|
||||
back to the original HF format (with fused projections).
|
||||
|
||||
Args:
|
||||
model: Path to the patched HF model
|
||||
output: Path to save the converted model
|
||||
"""
|
||||
print(f"Loading model from {model}")
|
||||
from axolotl.monkeypatch.models.llama4.modeling import (
|
||||
patch_llama4_linearized_modeling,
|
||||
)
|
||||
|
||||
unpatch_llama4 = patch_llama4_linearized_modeling()
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
|
||||
model_ = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model, torch_dtype=torch.bfloat16
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(model)
|
||||
processor.save_pretrained(output)
|
||||
|
||||
device = model_.device.type
|
||||
if device == "cuda":
|
||||
print(
|
||||
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
|
||||
)
|
||||
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
|
||||
model_config = model_.config
|
||||
config = model_.config.get_text_config()
|
||||
|
||||
# Get key dimensions from the config
|
||||
hidden_size = config.hidden_size
|
||||
intermediate_size = config.intermediate_size
|
||||
num_experts = config.num_local_experts
|
||||
|
||||
print(
|
||||
f"Model dimensions: hidden_size={hidden_size}, intermediate_size={intermediate_size}, num_experts={num_experts}"
|
||||
)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
os.makedirs(output, exist_ok=True)
|
||||
|
||||
# Get state dict
|
||||
state_dict = model_.state_dict()
|
||||
del model_
|
||||
|
||||
# Create a new state dict for the converted model
|
||||
converted_state_dict = {}
|
||||
|
||||
# First, copy all keys that don't need modification
|
||||
for key, value in iter_convert_patched_to_hf(state_dict, num_experts):
|
||||
converted_state_dict[key] = value
|
||||
|
||||
del state_dict
|
||||
if device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
print("State dict converted.")
|
||||
print(
|
||||
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
|
||||
)
|
||||
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
|
||||
# Ideally re-load the model import to load the converted state dict
|
||||
# Save the converted model
|
||||
with init_empty_weights():
|
||||
unpatch_llama4()
|
||||
model_ = Llama4ForConditionalGeneration(model_config)
|
||||
|
||||
if device == "cuda":
|
||||
print("State dict loaded into model.")
|
||||
print(
|
||||
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
|
||||
)
|
||||
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
|
||||
model_.load_state_dict(converted_state_dict, strict=False, assign=True)
|
||||
print(f"Saving converted model to {output}...")
|
||||
model_.save_pretrained(output)
|
||||
|
||||
print(f"Model successfully converted and saved to {output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
@@ -330,6 +330,15 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
|
||||
do_vllm_serve(config, cli_args)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("model", type=click.Path(exists=True, path_type=str))
|
||||
@click.argument("output", type=click.Path(exists=False, path_type=str))
|
||||
def delinearize_llama4(model: str, output: str) -> None:
|
||||
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
|
||||
|
||||
do_delinearize_llama4(model, output)
|
||||
|
||||
|
||||
cli.add_command(lm_eval)
|
||||
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
LOG.warning("Error raised: %s", e)
|
||||
|
||||
model.generation_config.do_sample = True
|
||||
model.config.use_cache = True
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
|
||||
|
||||
@@ -165,7 +165,7 @@ def cce_forward(
|
||||
)
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
input_ids: torch.LongTensor | None = None, # type: ignore
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -254,7 +254,7 @@ def cce_forward_multimodal(
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids) # type: ignore
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
@@ -263,13 +263,13 @@ def cce_forward_multimodal(
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
original_inputs_embeds_shape = inputs_embeds.shape
|
||||
original_inputs_embeds_shape = inputs_embeds.shape # type: ignore
|
||||
|
||||
vision_flat = image_features.view(-1, image_features.size(-1))
|
||||
projected_vision_flat = self.multi_modal_projector(vision_flat)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
final_mask = special_image_mask.to(inputs_embeds.device)
|
||||
final_mask = special_image_mask.to(inputs_embeds.device) # type: ignore
|
||||
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore
|
||||
|
||||
final_mask_1d = final_mask[..., 0].reshape(-1)
|
||||
|
||||
@@ -49,7 +49,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
|
||||
)
|
||||
sharded_sd[param_name] = sharded_tensor
|
||||
|
||||
model.load_state_dict(sharded_sd)
|
||||
model.load_state_dict(sharded_sd, assign=True)
|
||||
|
||||
|
||||
def patch_accelerate_fsdp_utils():
|
||||
|
||||
@@ -7,12 +7,11 @@ import torch
|
||||
import transformers
|
||||
|
||||
|
||||
def patch_flex_wrapper():
|
||||
def patch_flex_wrapper(**flex_attn_compile_kwargs):
|
||||
# TODO remove this patch when transformers#37285 is merged and in a release
|
||||
is_torch_2_6 = torch.__version__.startswith("2.6")
|
||||
is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
|
||||
|
||||
if not (is_torch_2_6 and is_transformers_below_4_51):
|
||||
if not is_torch_2_6:
|
||||
return
|
||||
|
||||
from torch.nn.attention.flex_attention import flex_attention
|
||||
@@ -32,17 +31,24 @@ def patch_flex_wrapper():
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def del_singleton(cls):
|
||||
cls._instance = None
|
||||
|
||||
@torch.compiler.disable(recursive=False)
|
||||
def __init__(self):
|
||||
def __init__(self, training):
|
||||
"""
|
||||
Initialize or update the singleton instance.
|
||||
"""
|
||||
if not self._is_flex_compiled:
|
||||
self.training = None
|
||||
if not self._is_flex_compiled or training != self.training:
|
||||
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
|
||||
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
|
||||
# see https://github.com/pytorch/pytorch/issues/146260 for training
|
||||
self.training = training
|
||||
self._compiled_flex_attention = torch.compile(
|
||||
flex_attention,
|
||||
dynamic=False,
|
||||
mode="max-autotune-no-cudagraphs",
|
||||
fullgraph=True,
|
||||
**flex_attn_compile_kwargs,
|
||||
)
|
||||
self._is_flex_compiled = True
|
||||
|
||||
@@ -50,15 +56,22 @@ def patch_flex_wrapper():
|
||||
return self._compiled_flex_attention
|
||||
|
||||
transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention
|
||||
setattr(
|
||||
sys.modules["transformers.integrations.flex_attention"],
|
||||
"WrappedFlexAttention",
|
||||
WrappedFlexAttention,
|
||||
)
|
||||
|
||||
|
||||
def patch_flex_make_mask():
|
||||
is_torch_2_6 = torch.__version__.startswith("2.6")
|
||||
is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
|
||||
|
||||
if not (is_torch_2_6 and is_transformers_eq_4_51):
|
||||
if not is_torch_2_6:
|
||||
return
|
||||
|
||||
from torch.nn.attention.flex_attention import (
|
||||
_DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
|
||||
)
|
||||
from torch.nn.attention.flex_attention import (
|
||||
BlockMask,
|
||||
)
|
||||
@@ -104,14 +117,16 @@ def patch_flex_make_mask():
|
||||
if not query_length:
|
||||
query_length = total_seq_len
|
||||
attention_mask_2d = torch.nn.functional.pad(
|
||||
attention_mask_2d, value=0, pad=(0, key_length)
|
||||
attention_mask_2d,
|
||||
value=0,
|
||||
pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
|
||||
)
|
||||
device = attention_mask_2d.device
|
||||
document_ids = attention_mask_2d.clone()
|
||||
|
||||
if attention_chunk_size is not None:
|
||||
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
|
||||
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (
|
||||
chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
|
||||
attention_chunk_size
|
||||
)
|
||||
|
||||
@@ -138,6 +153,18 @@ def patch_flex_make_mask():
|
||||
final_mask = causal_mask & padding_mask & document_mask
|
||||
return final_mask
|
||||
|
||||
def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
"""
|
||||
Combines the chunk mask with the causal mask for chunked attention.
|
||||
"""
|
||||
chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
|
||||
causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
|
||||
return chunk_mask & causal_doc_mask
|
||||
|
||||
mask_mod_maybe_combined = (
|
||||
causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
|
||||
)
|
||||
|
||||
if offsets is not None:
|
||||
q_offset = offsets[0]
|
||||
kv_offset = offsets[1]
|
||||
@@ -145,10 +172,10 @@ def patch_flex_make_mask():
|
||||
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
offset_q = q_idx + q_offset
|
||||
offset_kv = kv_idx + kv_offset
|
||||
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
|
||||
return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
|
||||
|
||||
else:
|
||||
mask_mod = causal_mask_mod
|
||||
mask_mod = mask_mod_maybe_combined
|
||||
return create_block_causal_mask_flex(
|
||||
mask_mod=mask_mod,
|
||||
B=batch_size,
|
||||
@@ -160,11 +187,16 @@ def patch_flex_make_mask():
|
||||
)
|
||||
|
||||
for n in tuple(sys.modules):
|
||||
if ".modeling_" in n and "llama4" not in n:
|
||||
if ".modeling_" in n:
|
||||
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
|
||||
sys.modules[n].make_flex_block_causal_mask = (
|
||||
patched_make_flex_block_causal_mask
|
||||
)
|
||||
setattr(
|
||||
sys.modules[n],
|
||||
"make_flex_block_causal_mask",
|
||||
patched_make_flex_block_causal_mask,
|
||||
)
|
||||
|
||||
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
|
||||
patched_make_flex_block_causal_mask
|
||||
|
||||
@@ -93,9 +93,20 @@ def patch_llama4_linearized_modeling():
|
||||
"""
|
||||
from transformers.models.llama4 import modeling_llama4
|
||||
|
||||
old_lamma_4_text_experts = modeling_llama4.Llama4TextExperts
|
||||
modeling_llama4.Llama4TextExperts = Llama4TextExperts
|
||||
setattr(
|
||||
sys.modules["transformers.models.llama4"],
|
||||
"Llama4TextExperts",
|
||||
Llama4TextExperts,
|
||||
)
|
||||
|
||||
def unpatch():
|
||||
modeling_llama4.Llama4TextExperts = old_lamma_4_text_experts
|
||||
setattr(
|
||||
sys.modules["transformers.models.llama4"],
|
||||
"Llama4TextExperts",
|
||||
old_lamma_4_text_experts,
|
||||
)
|
||||
|
||||
return unpatch
|
||||
|
||||
78
src/axolotl/monkeypatch/trainer_eval_guard.py
Normal file
78
src/axolotl/monkeypatch/trainer_eval_guard.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
fix for FSDP2 evals when using torch.compile
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
ORIGINAL_TRAINER_CODE = """
|
||||
model.eval()
|
||||
"""
|
||||
|
||||
PATCHED_TRAINER_CODE = """
|
||||
if hasattr(model, "eval") and callable(model.eval):
|
||||
self.model.eval()
|
||||
"""
|
||||
|
||||
|
||||
def get_evaluation_loop_code() -> str:
|
||||
training_loop = inspect.getsource(Trainer.evaluation_loop)
|
||||
return training_loop
|
||||
|
||||
|
||||
def check_evaluation_loop_is_patchable() -> bool:
|
||||
eval_loop = get_evaluation_loop_code()
|
||||
eval_loop, _ = detab_code(eval_loop)
|
||||
return ORIGINAL_TRAINER_CODE in eval_loop
|
||||
|
||||
|
||||
def patch_evaluation_loop_for_fsdp2():
|
||||
"""
|
||||
monkeypatch for fixing the eval loop for fsdp2 with torch.compile
|
||||
"""
|
||||
|
||||
try:
|
||||
evaluation_loop = get_evaluation_loop_code()
|
||||
except OSError:
|
||||
return
|
||||
Trainer._original_evaluation_loop = ( # pylint: disable=protected-access
|
||||
evaluation_loop
|
||||
)
|
||||
evaluation_loop, _ = detab_code(evaluation_loop)
|
||||
if ORIGINAL_TRAINER_CODE not in evaluation_loop:
|
||||
return
|
||||
|
||||
evaluation_loop = evaluation_loop.replace(
|
||||
ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
|
||||
)
|
||||
evaluation_loop = evaluation_loop.replace(
|
||||
"def evaluation_loop(",
|
||||
"def _fixed_evaluation_loop(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.trainer
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.trainer):
|
||||
if item in evaluation_loop:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.trainer import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
||||
Trainer.evaluation_loop = ( # pylint: disable=protected-access
|
||||
_fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
@@ -81,6 +81,11 @@ def setup_model_and_tokenizer(
|
||||
# Apply freezing if specified
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||
if any(
|
||||
any(embed in param for embed in ["lm_head", "embed_tokens"])
|
||||
for param in cfg.unfrozen_parameters
|
||||
):
|
||||
model.enable_input_require_grads()
|
||||
|
||||
return model, tokenizer, peft_config, processor
|
||||
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
module to freeze/unfreeze parameters by name
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.freeze")
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def freeze_layers_except(model, regex_patterns):
|
||||
@@ -184,7 +185,7 @@ class LayerNamePattern:
|
||||
"""
|
||||
self.raw_pattern = pattern
|
||||
name_pattern, self.range = self._parse_pattern(pattern)
|
||||
self.name_regex = re.compile(name_pattern.replace(".", "\\."))
|
||||
self.name_regex = re.compile(re.sub(r"\.(?!\+)", "\\.", name_pattern))
|
||||
|
||||
def match(self, name: str) -> bool:
|
||||
"""
|
||||
|
||||
@@ -542,6 +542,17 @@ class ModelLoader:
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
|
||||
|
||||
patch_accelerate_fsdp_utils()
|
||||
|
||||
if self.cfg.flex_attention:
|
||||
from axolotl.monkeypatch.attention.flex_attn import (
|
||||
patch_flex_make_mask,
|
||||
patch_flex_wrapper,
|
||||
)
|
||||
|
||||
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
||||
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
||||
patch_flex_make_mask()
|
||||
|
||||
# patch gemma3 conditional generation forward before loading plugins
|
||||
# as it could be overridden by plugins
|
||||
if self.cfg.model_config_type == "llama4":
|
||||
@@ -905,13 +916,6 @@ class ModelLoader:
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flex_attention"
|
||||
)
|
||||
from axolotl.monkeypatch.attention.flex_attn import (
|
||||
patch_flex_make_mask,
|
||||
patch_flex_wrapper,
|
||||
)
|
||||
|
||||
patch_flex_wrapper()
|
||||
patch_flex_make_mask()
|
||||
|
||||
elif self.cfg.flash_attention:
|
||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||
|
||||
@@ -225,6 +225,7 @@ class AxolotlInputConfig(
|
||||
sdp_attention: bool | None = None
|
||||
s2_attention: bool | None = None
|
||||
flex_attention: bool | None = None
|
||||
flex_attn_compile_kwargs: dict[str, Any] | None = None
|
||||
flash_attention: bool | None = None
|
||||
flash_attn_cross_entropy: bool | None = None
|
||||
flash_attn_rms_norm: bool | None = None
|
||||
@@ -1276,11 +1277,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
):
|
||||
capabilities = data.get("capabilities")
|
||||
is_fsdp = data.get("fsdp") is not None
|
||||
|
||||
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
||||
is_fsdp2 = (
|
||||
data.get("fsdp_config") is not None
|
||||
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
|
||||
)
|
||||
if capabilities and capabilities.get("n_gpu", 0) > 1 and not is_fsdp2:
|
||||
if is_fsdp:
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP1."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
||||
from axolotl.utils.distributed import reduce_and_broadcast
|
||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
@@ -235,7 +236,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if cfg.model_config_type in ["mamba", "gemma3"]:
|
||||
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
|
||||
if drop_attn_mask:
|
||||
LOG.info("dropping attention_mask column")
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
@@ -625,6 +627,12 @@ def setup_trainer(
|
||||
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
|
||||
on the provided parameters.
|
||||
"""
|
||||
if (
|
||||
cfg.torch_compile
|
||||
and cfg.fsdp_config
|
||||
and str(cfg.fsdp_config.fsdp_version) == "2"
|
||||
):
|
||||
patch_evaluation_loop_for_fsdp2()
|
||||
if cfg.rl:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
|
||||
trainer_builder.model_ref = model_ref
|
||||
|
||||
@@ -56,11 +56,12 @@ class TestPackedFlex:
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"max_steps": 2,
|
||||
"use_tensorboard": True,
|
||||
"save_strategy": "no",
|
||||
}
|
||||
|
||||
@@ -177,6 +177,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_USE_V1": "0",
|
||||
}
|
||||
vllm_process_id = start_vllm(
|
||||
cfg.base_model,
|
||||
@@ -264,6 +265,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_USE_V1": "0",
|
||||
}
|
||||
vllm_process_id = start_vllm(
|
||||
cfg.base_model,
|
||||
|
||||
Reference in New Issue
Block a user