Compare commits


2 Commits

Author       SHA1        Message                                           Date
NanoCode012  4581d6a8de  fix: accidentally reassigning tensor to weight    2025-04-09 13:45:29 +07:00
NanoCode012  1a85fab2ca  fix: lm_head is a view or related view modified   2025-04-08 17:32:28 +07:00
4 changed files with 37 additions and 145 deletions
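
Both commit messages point at the same underlying problem: `lm_head.weight` ending up as a view of (or sharing storage with) another tensor, so reassigning it or modifying it in place also changes the related tensor. A minimal sketch of how one might detect that situation, assuming a generic Hugging Face causal LM rather than the exact Llama 4 code path touched here:

```python
# Hypothetical check (not part of this PR): is the output head a view of,
# or does it share storage with, the input embeddings?
from transformers import AutoModelForCausalLM

# tiny test checkpoint used purely for illustration
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

head = model.get_output_embeddings().weight
embed = model.get_input_embeddings().weight

print("lm_head is a view of another tensor:", head._base is not None)
print("lm_head shares storage with embed_tokens:", head.data_ptr() == embed.data_ptr())
# If either is True, an in-place edit of lm_head.weight also mutates the tensor
# it is derived from (and vice versa), which is what these fixes guard against.
```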

View File

@@ -1,10 +0,0 @@
# Llama 4 by Meta AI
## Available Examples
### Llama 4 Scout 17Bx16Experts (109B)
- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)
Our Single GPU implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second.

View File

@@ -1,28 +1,13 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
base_model: meta-llama/Llama-4-Scout-17B-16E
model_type: Llama4ForConditionalGeneration
processor_type: Llama4Processor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# torch_compile: true
sequence_len: 4096
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true # use Axolotl's customized model
load_in_4bit: true
adapter: qlora
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_modules:
@@ -30,59 +15,60 @@ lora_target_modules:
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
- vision_adapter.mlp.fc1
- vision_adapter.mlp.fc2
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
- lm_head
- embed_tokens
chat_template: llama4
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:1%]
field_messages: messages
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
# gradient_checkpointing: true
# gradient_checkpointing_kwargs:
# use_reentrant: false
logging_steps: 1
flash_attention: true
warmup_steps: 100
evals_per_epoch: 1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- auto_wrap
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>

View File

@@ -1,86 +0,0 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
# - lm_head
# - embed_tokens
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
chat_template: llama4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096 # up to 8k will work on a single H100
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
gradient_checkpointing: offload
gradient_checkpointing_kwargs:
use_reentrant: false
warmup_steps: 20
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -26,6 +26,7 @@ from transformers.utils import (
)
_PATCH_OPTS: PatchOptions | None = None
RESET_LM_HEAD = True
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
@@ -308,7 +309,16 @@ def cce_forward_multimodal(
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
# TODO: check if need to handle attention_mask
# reset lm head gradient on first pass.
# linear model has some lm_head weight issue
# see https://github.com/axolotl-ai-cloud/axolotl/pull/2505
global RESET_LM_HEAD # pylint: disable=global-statement
if RESET_LM_HEAD:
RESET_LM_HEAD = False
self.language_model.lm_head.weight.requires_grad_(False) # Detach
self.language_model.lm_head.weight.requires_grad_(True) # Reattach
loss = apply_lce(
hidden_states,
self.language_model.lm_head.weight,
@@ -373,11 +383,7 @@ def patch_llama4_text(
return maybe_model
setattr(
modeling_llama4.Llama4ForCausalLM,
"forward",
cce_forward,
)
modeling_llama4.Llama4ForCausalLM.forward = cce_forward
return None
@@ -403,12 +409,8 @@ def patch_llama4(
)
return maybe_model
setattr(
modeling_llama4.Llama4ForConditionalGeneration,
"forward",
cce_forward_multimodal,
)
modeling_llama4.Llama4ForConditionalGeneration.forward = cce_forward_multimodal
# patch the causal language model
setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
modeling_llama4.Llama4ForCausalLM.forward = cce_forward
return None
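
For reference, the `RESET_LM_HEAD` logic added above amounts to a one-time `requires_grad_` toggle on `lm_head.weight` before the weight is handed to the fused linear-cross-entropy loss. A standalone sketch of that pattern in plain PyTorch (illustrative only, not the axolotl implementation):

```python
# Standalone sketch of the one-time "detach/reattach" pattern from the hunk above;
# module and flag names here are illustrative.
import torch
import torch.nn as nn

lm_head = nn.Linear(16, 32, bias=False)
_reset_done = False  # mirrors the module-level RESET_LM_HEAD flag


def reset_lm_head_once(head: nn.Linear) -> None:
    """Toggle requires_grad off and back on exactly once for the head weight."""
    global _reset_done
    if not _reset_done:
        _reset_done = True
        head.weight.requires_grad_(False)  # detach
        head.weight.requires_grad_(True)   # reattach


reset_lm_head_once(lm_head)
loss = lm_head(torch.randn(4, 16)).sum()
loss.backward()
assert lm_head.weight.grad is not None  # gradients still flow to the head
```

The flag makes the toggle a no-op after the first forward pass, matching how the patch resets the head only once per process.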