Compare commits
2 Commits
llama-4-ex
...
fix/cce-li
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4581d6a8de | ||
|
|
1a85fab2ca |
@@ -1,10 +0,0 @@
|
|||||||
# Llama 4 by Meta AI
|
|
||||||
|
|
||||||
## Available Examples
|
|
||||||
|
|
||||||
### Llama 4 Scout 17Bx16Experts (109B)
|
|
||||||
- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
|
|
||||||
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
|
|
||||||
- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)
|
|
||||||
|
|
||||||
Our Single GPU implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second.
|
|
||||||
@@ -1,28 +1,13 @@
|
|||||||
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
|
base_model: meta-llama/Llama-4-Scout-17B-16E
|
||||||
model_type: Llama4ForConditionalGeneration
|
model_type: Llama4ForConditionalGeneration
|
||||||
processor_type: Llama4Processor
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
# hub_model_id: username/custom_model_name
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
strict: false
|
strict: false
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
# torch_compile: true
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
sequence_len: 4096
|
adapter: lora
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.liger.LigerPlugin
|
|
||||||
|
|
||||||
liger_glu_activation: true
|
|
||||||
liger_rms_norm: true
|
|
||||||
liger_layer_norm: true
|
|
||||||
|
|
||||||
llama4_linearized_experts: true # use Axolotl's customized model
|
|
||||||
load_in_4bit: true
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 64
|
lora_alpha: 64
|
||||||
lora_target_modules:
|
lora_target_modules:
|
||||||
@@ -30,59 +15,60 @@ lora_target_modules:
|
|||||||
- self_attn.k_proj
|
- self_attn.k_proj
|
||||||
- self_attn.v_proj
|
- self_attn.v_proj
|
||||||
- self_attn.o_proj
|
- self_attn.o_proj
|
||||||
- shared_expert.gate_proj
|
|
||||||
- shared_expert.up_proj
|
|
||||||
- shared_expert.down_proj
|
|
||||||
- vision_adapter.mlp.fc1
|
|
||||||
- vision_adapter.mlp.fc2
|
|
||||||
# - experts.gate_projs.[0-9]+$
|
|
||||||
# - experts.up_projs.[0-9]+$
|
|
||||||
# - experts.down_projs.[0-9]+$
|
|
||||||
lora_modules_to_save:
|
lora_modules_to_save:
|
||||||
- lm_head
|
- lm_head
|
||||||
- embed_tokens
|
- embed_tokens
|
||||||
|
|
||||||
chat_template: llama4
|
chat_template: llama4
|
||||||
datasets:
|
datasets:
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: mlabonne/FineTome-100k
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:20%]
|
||||||
field_messages: messages
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
optimizer: adamw_torch_4bit
|
optimizer: adamw_torch_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 2e-5
|
learning_rate: 2e-5
|
||||||
|
|
||||||
bf16: true
|
bf16: true
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
|
# gradient_checkpointing: true
|
||||||
|
# gradient_checkpointing_kwargs:
|
||||||
|
# use_reentrant: false
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 2
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
fsdp:
|
fsdp:
|
||||||
- auto_wrap
|
- auto_wrap
|
||||||
- full_shard
|
- full_shard
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
|
fsdp_version: 2
|
||||||
fsdp_limit_all_gathers: true
|
fsdp_offload_params: false
|
||||||
fsdp_sync_module_states: true
|
|
||||||
fsdp_offload_params: true
|
|
||||||
fsdp_use_orig_params: false
|
|
||||||
fsdp_cpu_ram_efficient_loading: true
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
|
||||||
|
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
fsdp_reshard_after_forward: true
|
||||||
fsdp_activation_checkpointing: true
|
fsdp_activation_checkpointing: true
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|finetune_right_pad_id|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
|
|
||||||
model_type: Llama4ForConditionalGeneration
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.liger.LigerPlugin
|
|
||||||
|
|
||||||
liger_glu_activation: true
|
|
||||||
liger_rms_norm: true
|
|
||||||
liger_layer_norm: true
|
|
||||||
|
|
||||||
llama4_linearized_experts: true
|
|
||||||
load_in_4bit: true
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 64
|
|
||||||
lora_target_modules:
|
|
||||||
- self_attn.q_proj
|
|
||||||
- self_attn.k_proj
|
|
||||||
- self_attn.v_proj
|
|
||||||
- self_attn.o_proj
|
|
||||||
- shared_expert.gate_proj
|
|
||||||
- shared_expert.up_proj
|
|
||||||
- shared_expert.down_proj
|
|
||||||
# - experts.gate_projs.[0-9]+$
|
|
||||||
# - experts.up_projs.[0-9]+$
|
|
||||||
# - experts.down_projs.[0-9]+$
|
|
||||||
lora_modules_to_save:
|
|
||||||
# - lm_head
|
|
||||||
# - embed_tokens
|
|
||||||
|
|
||||||
lora_mlp_kernel: true
|
|
||||||
lora_qkv_kernel: true
|
|
||||||
lora_o_kernel: true
|
|
||||||
|
|
||||||
chat_template: llama4
|
|
||||||
datasets:
|
|
||||||
- path: mlabonne/FineTome-100k
|
|
||||||
type: chat_template
|
|
||||||
split: train[:20%]
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
sequence_len: 4096 # up to 8k will work on a single H100
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_4bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 1e-4
|
|
||||||
|
|
||||||
bf16: true
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
gradient_checkpointing: offload
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
|
|
||||||
warmup_steps: 20
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <|finetune_right_pad_id|>
|
|
||||||
eos_token: <|eot|>
|
|
||||||
@@ -26,6 +26,7 @@ from transformers.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
RESET_LM_HEAD = True
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
|
||||||
@@ -308,7 +309,16 @@ def cce_forward_multimodal(
|
|||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
assert labels is not None
|
assert labels is not None
|
||||||
# TODO: check if need to handle attention_mask
|
|
||||||
|
# reset lm head gradient on first pass.
|
||||||
|
# linear model has some lm_head weight issue
|
||||||
|
# see https://github.com/axolotl-ai-cloud/axolotl/pull/2505
|
||||||
|
global RESET_LM_HEAD # pylint: disable=global-statement
|
||||||
|
if RESET_LM_HEAD:
|
||||||
|
RESET_LM_HEAD = False
|
||||||
|
self.language_model.lm_head.weight.requires_grad_(False) # Detach
|
||||||
|
self.language_model.lm_head.weight.requires_grad_(True) # Reattach
|
||||||
|
|
||||||
loss = apply_lce(
|
loss = apply_lce(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self.language_model.lm_head.weight,
|
self.language_model.lm_head.weight,
|
||||||
@@ -373,11 +383,7 @@ def patch_llama4_text(
|
|||||||
|
|
||||||
return maybe_model
|
return maybe_model
|
||||||
|
|
||||||
setattr(
|
modeling_llama4.Llama4ForCausalLM.forward = cce_forward
|
||||||
modeling_llama4.Llama4ForCausalLM,
|
|
||||||
"forward",
|
|
||||||
cce_forward,
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -403,12 +409,8 @@ def patch_llama4(
|
|||||||
)
|
)
|
||||||
return maybe_model
|
return maybe_model
|
||||||
|
|
||||||
setattr(
|
modeling_llama4.Llama4ForConditionalGeneration.forward = cce_forward_multimodal
|
||||||
modeling_llama4.Llama4ForConditionalGeneration,
|
|
||||||
"forward",
|
|
||||||
cce_forward_multimodal,
|
|
||||||
)
|
|
||||||
|
|
||||||
# patch the causal language model
|
# patch the causal language model
|
||||||
setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
|
modeling_llama4.Llama4ForCausalLM.forward = cce_forward
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user