Compare commits


4 Commits

Author SHA1 Message Date
Wing Lian
747dafe5b2  Add Llama4 maverick examples  2025-04-09 08:27:46 -04:00

NanoCode012
f85861a0b2  fix: liger swiglu for llama4 (#2504)  2025-04-09 02:53:17 -04:00
    * fix: liger swiglu for llama4
    * feat: add liger to deepseek v3
    * fix: unpack not found
    * fix: spelling
    * fix: comment out deepseek v3
    * fix: retest deepseek
    * fix: map glu
    * fix: patch model forward
    * chore: add temp code to save
    * fix: remove deepseek to move into separate PR

Wing Lian
630e40dd13  upgrade transformers to 4.51.1 (#2508)  2025-04-09 02:53:00 -04:00
    * upgrade transformers to 4.51.1
    * multigpu longer timeout

Wing Lian
bf9efe2a09  [llama4] fix the mm yaml, add scout single gpu yaml (#2510)  2025-04-09 02:52:45 -04:00
    * [llama4] fix the mm yaml, add scout single gpu yaml
    * add README for llama4
    * rename to specify fsdp
8 changed files with 242 additions and 25 deletions

View File

@@ -68,7 +68,7 @@ def run_cmd(cmd: str, run_folder: str):
 @app.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
-    timeout=60 * 60,
+    timeout=90 * 60,
     cpu=8.0,
     memory=131072 * N_GPUS,
     volumes=VOLUME_CONFIG,

View File

@@ -0,0 +1,16 @@
# Llama 4 by Meta AI

## Available Examples

### Llama 4 Scout 17Bx16Experts (109B)

- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)

Our single-H100 implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-sft/runs/zic56rhd).

### Llama 4 Maverick 17Bx128Experts (400B)

- [Text Multi GPU QLoRA w/ FSDP1](./maverick-qlora-fsdp1.yaml)

Our 4xH100 implementation for Llama 4 Maverick uses 79.5GB VRAM/GPU for post-training with 4k context length @ 206 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama-sft/runs/siyvwuxc?nw=nwuserwinglian).
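Configs like the ones listed above are normally launched with the axolotl CLI. A minimal launch sketch, assuming the `axolotl train` entry point is installed and that the example YAMLs live under examples/llama-4/ in your checkout (both assumptions, not stated in this diff):

# Launch sketch: the `axolotl train <config>` CLI and the examples/llama-4/ path are
# assumptions about the local environment.
import subprocess

subprocess.run(
    ["axolotl", "train", "examples/llama-4/scout-qlora-single-h100.yaml"],
    check=True,
)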

View File

@@ -1,13 +1,20 @@
-base_model: meta-llama/Llama-4-Scout-17B-16E
+base_model: axolotl-quants/Llama-4-Maverick-17B-128E-Linearized-bnb-nf4-bf16
 model_type: Llama4ForConditionalGeneration

 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name
 strict: false

-# torch_compile: true
+plugins:
+  - axolotl.integrations.liger.LigerPlugin
+liger_glu_activation: true
+liger_rms_norm: true
+liger_layer_norm: true
+llama4_linearized_experts: true
+load_in_4bit: true

-adapter: lora
+adapter: qlora
 lora_r: 32
 lora_alpha: 64
 lora_target_modules:
@@ -15,9 +22,15 @@ lora_target_modules:
   - self_attn.k_proj
   - self_attn.v_proj
   - self_attn.o_proj
+  - shared_expert.gate_proj
+  - shared_expert.up_proj
+  - shared_expert.down_proj
+  # - experts.gate_projs.[0-9]+$
+  # - experts.up_projs.[0-9]+$
+  # - experts.down_projs.[0-9]+$
 lora_modules_to_save:
-  - lm_head
-  - embed_tokens
+  # - lm_head
+  # - embed_tokens

 chat_template: llama4
 datasets:
@@ -40,36 +53,37 @@ pad_to_sequence_len: true
 gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 1
-optimizer: adamw_torch_8bit
+optimizer: adamw_torch_fused
 lr_scheduler: cosine
-learning_rate: 2e-5
+learning_rate: 1e-4

 bf16: true
 tf32: true

-# gradient_checkpointing: true
-# gradient_checkpointing_kwargs:
-#   use_reentrant: false
 logging_steps: 1
 flash_attention: true

-warmup_steps: 100
-evals_per_epoch: 2
+gradient_checkpointing: offload
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+
+warmup_steps: 20
+evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
 fsdp:
   - auto_wrap
   - full_shard
 fsdp_config:
-  fsdp_version: 2
-  fsdp_offload_params: false
+  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: true
+  fsdp_use_orig_params: false
   fsdp_cpu_ram_efficient_loading: true
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
-  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_sharding_strategy: FULL_SHARD
-  fsdp_reshard_after_forward: true
-  fsdp_activation_checkpointing: true
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
   eos_token: <|eot|>
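For orientation, the QLoRA-related keys in the new config roughly map onto the following transformers/peft setup. This is a sketch of what the keys mean, not Axolotl's actual model-loading code; the dtype and task_type choices below are assumptions.

# Rough peft/transformers reading of `load_in_4bit: true`, `adapter: qlora`,
# `lora_r: 32`, `lora_alpha: 64`, and the `lora_target_modules` list above.
import torch
from transformers import BitsAndBytesConfig, Llama4ForConditionalGeneration
from peft import LoraConfig, get_peft_model

bnb = BitsAndBytesConfig(            # what `load_in_4bit: true` implies when starting
    load_in_4bit=True,               # from an unquantized base; the checkpoint named in
    bnb_4bit_quant_type="nf4",       # the diff already ships pre-quantized nf4 weights
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = Llama4ForConditionalGeneration.from_pretrained(
    "axolotl-quants/Llama-4-Maverick-17B-128E-Linearized-bnb-nf4-bf16",
    torch_dtype=torch.bfloat16,
)
lora = LoraConfig(
    r=32,
    lora_alpha=64,
    # PEFT matches these as suffixes of full module names, so the dotted paths from the
    # YAML work as-is; the commented-out experts.* regexes are likewise omitted here.
    target_modules=[
        "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj",
        "shared_expert.gate_proj", "shared_expert.up_proj", "shared_expert.down_proj",
    ],
    task_type="CAUSAL_LM",           # assumption; Axolotl sets this internally
)
model = get_peft_model(model, lora)
model.print_trainable_parameters()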

View File

@@ -0,0 +1,86 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
# - lm_head
# - embed_tokens
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
chat_template: llama4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096 # up to 8k will work on a single H100
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
gradient_checkpointing: offload
gradient_checkpointing_kwargs:
  use_reentrant: false
warmup_steps: 20
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>
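The `field_messages` / `message_property_mappings` keys in this config describe how ShareGPT-style records are renamed into role/content messages before the chat template is applied. A small illustration with a made-up record; the `from`/`value` field names are the ones the config references for FineTome-100k.

# Illustration of message_property_mappings: FineTome-style turns use "from"/"value",
# and the mapping renames them to the "role"/"content" keys the chat template expects.
raw_example = {
    "conversations": [
        {"from": "human", "value": "What is the capital of France?"},
        {"from": "gpt", "value": "Paris."},
    ]
}
mapping = {"role": "from", "content": "value"}  # target_key: source_key, as in the YAML
messages = [
    {target: turn[source] for target, source in mapping.items()}
    for turn in raw_example["conversations"]
]
print(messages)
# [{'role': 'human', 'content': '...'}, {'role': 'gpt', 'content': 'Paris.'}]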

View File

@@ -0,0 +1,89 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
processor_type: Llama4Processor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
strict: false
# these 3 lines are needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
sequence_len: 4096
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true # use Axolotl's customized model
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
- vision_adapter.mlp.fc1
- vision_adapter.mlp.fc2
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
- lm_head
- embed_tokens
chat_template: llama4
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_steps: 100
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - auto_wrap
  - full_shard
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_activation_checkpointing: true
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>
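The three overrides near the top of this file (skip_prepare_dataset, remove_unused_columns: false, sample_packing: false) exist because multimodal rows carry images that must survive untouched until the processor builds the batch. A sketch of the assumed shape of one vision-chat record; the exact fields of llava-instruct-mix-vsft are an assumption here, shown only to illustrate why rows cannot be pre-tokenized or packed.

# Assumed shape of a vision chat record (illustrative only): the image travels with the
# row, so columns must not be dropped and examples cannot be packed ahead of time.
example = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this picture?"},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "A dog catching a frisbee in a park."}],
        },
    ],
    "images": ["<PIL.Image.Image object>"],  # placeholder for the actual image data
}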

View File

@@ -12,7 +12,7 @@ liger-kernel==0.5.6
 packaging==23.2
 peft==0.15.1
-transformers==4.51.0
+transformers==4.51.1
 tokenizers>=0.21.1
 accelerate==1.6.0
 datasets==3.5.0

View File

@@ -185,5 +185,7 @@ class LigerPlugin(BasePlugin):
                 rms_norm=cfg.liger_rms_norm,
                 layer_norm=cfg.liger_layer_norm,
             )
-        elif cfg.model_config_type in ["deepseek_v3"]:
-            raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
+        else:
+            logging.warning(
+                f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
+            )

View File

@@ -3,6 +3,7 @@ Liger FLCE for llama4
""" """
import sys import sys
from copy import deepcopy
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
@@ -158,7 +159,16 @@ def apply_liger_kernel_to_llama4(
if rms_norm: if rms_norm:
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
if glu_activation: if glu_activation:
modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"Accepts intermediate_size to pass to LigerSwiGLUMLP"
# clone config to avoid modifying the original
config = deepcopy(config)
if intermediate_size:
setattr(config, "intermediate_size", intermediate_size)
return LigerSwiGLUMLP(config, **kwargs)
modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper
if layer_norm: if layer_norm:
modeling_llama4.nn.LayerNorm = LigerLayerNorm modeling_llama4.nn.LayerNorm = LigerLayerNorm
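The wrapper is needed because some Llama 4 text MLPs are constructed with an explicit intermediate_size argument (dense and expert widths differ), while LigerSwiGLUMLP sizes itself purely from config.intermediate_size; copying the config and overriding that field bridges the two. A minimal standalone sketch of the same behavior, with toy sizes and a SimpleNamespace standing in for the real Llama 4 text config:

# Mirrors _liger_swiglu_mlp_wrapper above; assumes liger_kernel is installed, and the
# config object here is a toy stand-in, not the actual Llama4 text config.
from copy import deepcopy
from types import SimpleNamespace

from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

def swiglu_wrapper(config, intermediate_size=None, **kwargs):
    config = deepcopy(config)            # do not mutate the caller's config
    if intermediate_size:
        config.intermediate_size = intermediate_size
    return LigerSwiGLUMLP(config, **kwargs)

cfg = SimpleNamespace(hidden_size=64, intermediate_size=128, hidden_act="silu")
mlp = swiglu_wrapper(cfg, intermediate_size=256)  # builds a 256-wide SwiGLU MLP
assert cfg.intermediate_size == 128               # original config left untouched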