llama4 support (#2493)

* llama4 support

* add xet support [skip ci]

* be flexible on transformers version and skip test on version

* don't use deepspeed for the fix_untrained_tokens test

* reordering to trigger torch 2.6.0 tests first

* slightly smaller train set

* use 4.51.0 for now

* remove stray print, add llama4 chat template to schema, bump peft to 0.15.1

* patches to make llama4 performant

* add preliminary fp8 support
This commit is contained in:
Wing Lian
2025-04-07 10:49:15 -04:00
committed by GitHub
parent 5f4af3665d
commit 8bbad21bfd
17 changed files with 409 additions and 34 deletions

View File

@@ -24,6 +24,13 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -38,13 +45,6 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -211,7 +211,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
steps:
@@ -258,7 +258,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.5.1
num_gpus: 1
axolotl_extras: vllm
steps:

View File

@@ -0,0 +1,75 @@
base_model: meta-llama/Llama-4-Scout-17B-16E
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
strict: false
# torch_compile: true
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
lora_modules_to_save:
- lm_head
- embed_tokens
chat_template: llama4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
# gradient_checkpointing: true
# gradient_checkpointing_kwargs:
# use_reentrant: false
logging_steps: 1
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- auto_wrap
- full_shard
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -6,18 +6,19 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.5
liger-kernel==0.5.6
# END section
packaging==23.2
peft==0.15.0
peft==0.15.1
transformers==4.51.0
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.0
deepspeed>=0.15.4
trl==0.16.1
hf_xet==1.0.0
optimum==1.16.2
hf_transfer

View File

@@ -562,6 +562,19 @@ class AxolotlTrainer(
return res
def additional_accelerator_args(
self, fp8=None, **kwargs
): # pylint: disable=unused-argument
ret_kwargs = {}
if fp8:
from accelerate.utils import AORecipeKwargs
ret_kwargs["mixed_precision"] = "fp8"
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
return ret_kwargs
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.

View File

@@ -173,5 +173,17 @@ class LigerPlugin(BasePlugin):
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,
)
apply_liger_kernel_to_llama4(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type in ["deepseek_v3"]:
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")

View File

@@ -0,0 +1,171 @@
"""
Liger FLCE for llama4
"""
import sys
from typing import List, Optional, Tuple, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[
Union["Cache", List[torch.FloatTensor]] # noqa: F821
] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
raise Exception( # pylint: disable=broad-exception-raised
"Liger Kernel does not support pretraining_tp!!"
)
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**loss_kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**loss_kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_llama4(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.llama4.modeling_llama4 # noqa: F401 # pylint: disable=unused-import
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
modeling_llama4 = sys.modules["transformers.models.llama4.modeling_llama4"]
if rms_norm:
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
if glu_activation:
modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
if layer_norm:
modeling_llama4.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_llama4.Llama4ForCausalLM.forward = lce_forward

View File

@@ -162,7 +162,6 @@ def patch_flex_make_mask():
for n in tuple(sys.modules):
if ".modeling_" in n and "llama4" not in n:
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
print(n)
sys.modules[n].make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)

View File

@@ -13,6 +13,7 @@ from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mllama_text_model",
"llama",
"llama4",
"mistral",
"mixtral",
"qwen2",

View File

@@ -0,0 +1,80 @@
"""
allow adding additional kwargs to Accelerator init
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger(__name__)
ORIGINAL_TRAINER_CODE = """
# create accelerator object
self.accelerator = Accelerator(**args)
"""
PATCHED_TRAINER_CODE = """
if hasattr(self, "additional_accelerator_args"):
additional_args = self.additional_accelerator_args(fp8=True, **args)
if additional_args:
args.update(additional_args)
# create accelerator object
self.accelerator = Accelerator(**args)
"""
def get_create_accelerate_code() -> str:
training_loop = inspect.getsource(Trainer.create_accelerator_and_postprocess)
return training_loop
def check_create_accelerate_code_is_patchable() -> bool:
create_code = get_create_accelerate_code()
create_code, _ = detab_code(create_code)
return ORIGINAL_TRAINER_CODE in create_code
def patch_create_accelerate_code_for_fp8():
"""
monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
"""
try:
create_code = get_create_accelerate_code()
except OSError:
return
Trainer._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access
create_code
)
create_code, _ = detab_code(create_code)
if ORIGINAL_TRAINER_CODE not in create_code:
return
create_code = create_code.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
create_code = create_code.replace(
"def create_accelerator_and_postprocess(",
"def fixed_create_accelerator_and_postprocess(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in create_code:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(create_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching create_accelerator_and_postprocess to allow for overrides")
Trainer.create_accelerator_and_postprocess = fixed_create_accelerator_and_postprocess # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821

File diff suppressed because one or more lines are too long

View File

@@ -557,6 +557,14 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
# monkey patch to allow additional Accelerator init kwargs
if self.cfg.fp8:
from axolotl.monkeypatch.trainer_accelerator_args import (
patch_create_accelerate_code_for_fp8,
)
patch_create_accelerate_code_for_fp8()
if self.cfg.adapter:
from axolotl.monkeypatch.transformers_fa_utils import (
patch_fa_peft_integration,
@@ -988,10 +996,11 @@ class ModelLoader:
)
skip_move_to_device = True
elif (
self.model_config.model_type == "llama"
self.model_config.model_type in ["llama", "llama4"]
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
# TODO do we need to open this up for all models?
if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in self.model_kwargs:

View File

@@ -169,6 +169,7 @@ class AxolotlInputConfig(
bf16: Literal["auto"] | bool | None = "auto"
fp16: bool | None = None
fp8: bool | None = None
bfloat16: bool | None = None # for non-AMP cases
float16: bool | None = None # for non-AMP cases
tf32: bool | None = None
@@ -464,9 +465,10 @@ class AxolotlInputConfig(
data.get("sample_packing")
and not data.get("flash_attention")
and not data.get("sdp_attention")
and not data.get("flex_attention")
):
LOG.warning(
"sample_packing without flash_attention or sdp_attention does not handle cross-attention."
"sample_packing without flash, sdp or flex attention does not handle cross sample decontamination."
)
return data

View File

@@ -26,6 +26,7 @@ class ChatTemplate(str, Enum):
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama4 = "llama4" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name

View File

@@ -582,7 +582,9 @@ def prepare_optim_env(cfg):
setup_torch_compile_env(cfg)
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
if cfg.fp8:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
elif (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"

View File

@@ -7,9 +7,11 @@ import os
from pathlib import Path
import pytest
import transformers
import yaml
from accelerate.test_utils import execute_subprocess_async
from huggingface_hub import snapshot_download
from packaging import version
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
@@ -28,6 +30,10 @@ def download_model():
snapshot_download("HuggingFaceTB/SmolLM2-135M")
def transformers_version_eq(required_version):
return version.parse(transformers.__version__) == version.parse(required_version)
class TestMultiGPULlama:
"""
Test case for Llama models using LoRA
@@ -56,7 +62,7 @@ class TestMultiGPULlama:
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
@@ -108,7 +114,7 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.01,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
@@ -116,6 +122,7 @@ class TestMultiGPULlama:
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:20%]",
},
],
"num_epochs": 1,
@@ -193,7 +200,7 @@ class TestMultiGPULlama:
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
@@ -390,7 +397,7 @@ class TestMultiGPULlama:
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
@@ -403,7 +410,7 @@ class TestMultiGPULlama:
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
@@ -493,9 +500,7 @@ class TestMultiGPULlama:
],
"fsdp_config": {
"fsdp_version": 2,
"fsdp_forward_prefetch": True,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": True,
# "fsdp_forward_prefetch": True, # not yet implemented in accelerate
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": False,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
@@ -551,7 +556,7 @@ class TestMultiGPULlama:
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
@@ -565,7 +570,7 @@ class TestMultiGPULlama:
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
@@ -612,8 +617,11 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.skip(
reason="ds-zero3 broken in main until transformers#37281 resolved"
# TODO: remove skip once deepspeed regression is fixed
# see https://github.com/huggingface/transformers/pull/37324
@pytest.mark.skipif(
transformers_version_eq("4.51.0"),
reason="zero3 is not supported with transformers==4.51.0",
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
@@ -651,7 +659,7 @@ class TestMultiGPULlama:
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
@@ -724,7 +732,7 @@ class TestMultiGPULlama:
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
@@ -797,7 +805,7 @@ class TestMultiGPULlama:
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
@@ -885,7 +893,7 @@ class TestMultiGPULlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
# "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
}
)

View File

@@ -31,7 +31,7 @@ class TestMultiGPURay:
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
@@ -94,8 +94,8 @@ class TestMultiGPURay:
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},