Compare commits
2 Commits
maverick-e…...transforme…

| Author | SHA1 | Date |
|---|---|---|
| | 37a66e6866 | |
| | 9f69597a5f | |
@@ -1,16 +0,0 @@
-# Llama 4 by Meta AI
-
-## Available Examples
-
-### Llama 4 Scout 17Bx16Experts (109B)
-- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
-- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
-- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)
-
-Our Single H100 implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-sft/runs/zic56rhd)
-
-### Llama 4 Maverick 17Bx128Experts (400B)
-
-- [Text Multi GPU QLoRA w/FSDP1](./maverick-qlora-fsdp1.yaml)
-
-Our 4xH100 implementation for Llama 4 Maverick uses 79.5GB VRAM/GPU for post-training with 4k context length @ 206 tokens/second. [WandB logs here.](https://wandb.ai/axolotl-ai/llama-sft/runs/siyvwuxc?nw=nwuserwinglian)
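Each YAML referenced in the removed README above is a self-contained axolotl config. Assuming a standard axolotl install (the exact CLI invocation is an assumption here, not part of this compare), such an example is typically launched with `axolotl train <config>.yaml` on a single GPU, or with `accelerate launch -m axolotl.cli.train <config>.yaml` for the multi-GPU FSDP variants.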
@@ -1,20 +1,13 @@
-base_model: axolotl-quants/Llama-4-Maverick-17B-128E-Linearized-bnb-nf4-bf16
+base_model: meta-llama/Llama-4-Scout-17B-16E
 model_type: Llama4ForConditionalGeneration
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name

 strict: false

-plugins:
-- axolotl.integrations.liger.LigerPlugin
-
-liger_glu_activation: true
-liger_rms_norm: true
-liger_layer_norm: true
-
-llama4_linearized_experts: true
-load_in_4bit: true
-adapter: qlora
+# torch_compile: true
+
+adapter: lora
 lora_r: 32
 lora_alpha: 64
 lora_target_modules:
@@ -22,15 +15,9 @@ lora_target_modules:
 - self_attn.k_proj
 - self_attn.v_proj
 - self_attn.o_proj
-- shared_expert.gate_proj
-- shared_expert.up_proj
-- shared_expert.down_proj
-# - experts.gate_projs.[0-9]+$
-# - experts.up_projs.[0-9]+$
-# - experts.down_projs.[0-9]+$
 lora_modules_to_save:
-# - lm_head
-# - embed_tokens
+- lm_head
+- embed_tokens

 chat_template: llama4
 datasets:
@@ -53,37 +40,36 @@ pad_to_sequence_len: true
 gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 1
-optimizer: adamw_torch_fused
+optimizer: adamw_torch_8bit
 lr_scheduler: cosine
-learning_rate: 1e-4
+learning_rate: 2e-5

 bf16: true
 tf32: true

+# gradient_checkpointing: true
+# gradient_checkpointing_kwargs:
+#   use_reentrant: false
 logging_steps: 1
 flash_attention: true

-gradient_checkpointing: offload
-gradient_checkpointing_kwargs:
-  use_reentrant: false
-
-warmup_steps: 20
-evals_per_epoch: 1
+warmup_steps: 100
+evals_per_epoch: 2
 saves_per_epoch: 1
 weight_decay: 0.0
 fsdp:
 - auto_wrap
 - full_shard
 fsdp_config:
-  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
-  fsdp_limit_all_gathers: true
-  fsdp_sync_module_states: true
-  fsdp_offload_params: true
-  fsdp_use_orig_params: false
+  fsdp_version: 2
+  fsdp_offload_params: false
   fsdp_cpu_ram_efficient_loading: true
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
+  fsdp_state_dict_type: SHARDED_STATE_DICT
   fsdp_sharding_strategy: FULL_SHARD
+  fsdp_reshard_after_forward: true
+  fsdp_activation_checkpointing: true
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
   eos_token: <|eot|>
@@ -1,86 +0,0 @@
-base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
-model_type: Llama4ForConditionalGeneration
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-strict: false
-
-plugins:
-- axolotl.integrations.liger.LigerPlugin
-
-liger_glu_activation: true
-liger_rms_norm: true
-liger_layer_norm: true
-
-llama4_linearized_experts: true
-load_in_4bit: true
-adapter: qlora
-lora_r: 32
-lora_alpha: 64
-lora_target_modules:
-- self_attn.q_proj
-- self_attn.k_proj
-- self_attn.v_proj
-- self_attn.o_proj
-- shared_expert.gate_proj
-- shared_expert.up_proj
-- shared_expert.down_proj
-# - experts.gate_projs.[0-9]+$
-# - experts.up_projs.[0-9]+$
-# - experts.down_projs.[0-9]+$
-lora_modules_to_save:
-# - lm_head
-# - embed_tokens
-
-lora_mlp_kernel: true
-lora_qkv_kernel: true
-lora_o_kernel: true
-
-chat_template: llama4
-datasets:
-  - path: mlabonne/FineTome-100k
-    type: chat_template
-    split: train[:20%]
-    field_messages: conversations
-    message_property_mappings:
-      role: from
-      content: value
-
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.0
-output_dir: ./outputs/out
-
-sequence_len: 4096 # up to 8k will work on a single H100
-sample_packing: true
-pad_to_sequence_len: true
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 1
-num_epochs: 1
-optimizer: adamw_torch_4bit
-lr_scheduler: cosine
-learning_rate: 1e-4
-
-bf16: true
-tf32: true
-
-logging_steps: 1
-flash_attention: true
-
-gradient_checkpointing: offload
-gradient_checkpointing_kwargs:
-  use_reentrant: false
-
-warmup_steps: 20
-evals_per_epoch: 1
-saves_per_epoch: 1
-weight_decay: 0.0
-special_tokens:
-  pad_token: <|finetune_right_pad_id|>
-  eos_token: <|eot|>
@@ -1,89 +0,0 @@
-base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
-model_type: Llama4ForConditionalGeneration
-processor_type: Llama4Processor
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-strict: false
-
-# these 3 lines are needed for now to handle vision chat templates w images
-skip_prepare_dataset: true
-remove_unused_columns: false
-sample_packing: false
-
-sequence_len: 4096
-
-plugins:
-- axolotl.integrations.liger.LigerPlugin
-
-liger_glu_activation: true
-liger_rms_norm: true
-liger_layer_norm: true
-
-llama4_linearized_experts: true # use Axolotl's customized model
-load_in_4bit: true
-adapter: qlora
-lora_r: 32
-lora_alpha: 64
-lora_target_modules:
-- self_attn.q_proj
-- self_attn.k_proj
-- self_attn.v_proj
-- self_attn.o_proj
-- shared_expert.gate_proj
-- shared_expert.up_proj
-- shared_expert.down_proj
-- vision_adapter.mlp.fc1
-- vision_adapter.mlp.fc2
-# - experts.gate_projs.[0-9]+$
-# - experts.up_projs.[0-9]+$
-# - experts.down_projs.[0-9]+$
-lora_modules_to_save:
-- lm_head
-- embed_tokens
-
-chat_template: llama4
-datasets:
-  - path: HuggingFaceH4/llava-instruct-mix-vsft
-    type: chat_template
-    split: train[:1%]
-    field_messages: messages
-
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.0
-output_dir: ./outputs/out
-
-gradient_accumulation_steps: 1
-micro_batch_size: 1
-num_epochs: 1
-optimizer: adamw_torch_4bit
-lr_scheduler: cosine
-learning_rate: 2e-5
-
-bf16: true
-tf32: true
-
-logging_steps: 1
-flash_attention: true
-
-warmup_steps: 100
-evals_per_epoch: 1
-saves_per_epoch: 1
-weight_decay: 0.0
-fsdp:
-- auto_wrap
-- full_shard
-fsdp_config:
-  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
-  fsdp_limit_all_gathers: true
-  fsdp_sync_module_states: true
-  fsdp_offload_params: true
-  fsdp_use_orig_params: false
-  fsdp_cpu_ram_efficient_loading: true
-  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  fsdp_state_dict_type: FULL_STATE_DICT
-  fsdp_sharding_strategy: FULL_SHARD
-  fsdp_activation_checkpointing: true
-special_tokens:
-  pad_token: <|finetune_right_pad_id|>
-  eos_token: <|eot|>
@@ -185,7 +185,5 @@ class LigerPlugin(BasePlugin):
                 rms_norm=cfg.liger_rms_norm,
                 layer_norm=cfg.liger_layer_norm,
             )
-        else:
-            logging.warning(
-                f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
-            )
+        elif cfg.model_config_type in ["deepseek_v3"]:
+            raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
@@ -3,7 +3,6 @@ Liger FLCE for llama4
 """

 import sys
-from copy import deepcopy
 from typing import List, Optional, Tuple, Union

 import torch
@@ -159,16 +158,7 @@ def apply_liger_kernel_to_llama4(
     if rms_norm:
         modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
     if glu_activation:
-
-        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
-            "Accepts intermediate_size to pass to LigerSwiGLUMLP"
-            # clone config to avoid modifying the original
-            config = deepcopy(config)
-            if intermediate_size:
-                setattr(config, "intermediate_size", intermediate_size)
-            return LigerSwiGLUMLP(config, **kwargs)
-
-        modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
     if layer_norm:
         modeling_llama4.nn.LayerNorm = LigerLayerNorm
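The change above collapses the GLU patch to a plain class swap: the removed `_liger_swiglu_mlp_wrapper` existed only to accept an `intermediate_size` override and clone the config before constructing `LigerSwiGLUMLP`, so dropping it assumes the value already on the config is the one the MLP should use. For readers unfamiliar with the pattern, a minimal sketch of this kind of module-level monkey-patching follows; the import paths are assumptions inferred from the names in the diff, and the swap has to happen before the model is instantiated for it to take effect.

```python
# Minimal sketch of the class-swap patching used by the Liger integration above.
# Import paths are assumptions; the class names themselves come from the diff.
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from transformers.models.llama4 import modeling_llama4

# Rebind the classes on the modeling module *before* the model is built, so every
# Llama4TextRMSNorm / Llama4TextMLP constructed afterwards is the Liger version.
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
```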