Compare commits

...

3 Commits

Author      SHA1        Message                                                                        Date
Wing Lian   1aec93cf9e  add preliminary fp8 support                                                    2025-04-06 23:54:50 -04:00
Wing Lian   37630fc6ef  patches to make llama4 performant                                              2025-04-06 22:50:48 -04:00
Wing Lian   4b28b2a0b4  remove stray print, add llama4 chat template to schema, bump peft to 0.15.1   2025-04-06 19:59:48 -04:00
12 changed files with 371 additions and 9 deletions

View File

@@ -0,0 +1,75 @@
base_model: meta-llama/Llama-4-Scout-17B-16E
model_type: Llama4ForConditionalGeneration

# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

strict: false

# torch_compile: true

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_modules:
  - self_attn.q_proj
  - self_attn.k_proj
  - self_attn.v_proj
  - self_attn.o_proj
lora_modules_to_save:
  - lm_head
  - embed_tokens

chat_template: llama4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 2e-5

bf16: true
tf32: true

# gradient_checkpointing: true
# gradient_checkpointing_kwargs:
#   use_reentrant: false

logging_steps: 1
flash_attention: true

warmup_steps: 100
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

fsdp:
  - auto_wrap
  - full_shard
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_reshard_after_forward: true
  fsdp_activation_checkpointing: true

special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>
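
A config like the one above is normally launched through the axolotl CLI. The snippet below is a minimal sketch of kicking off a run programmatically, assuming the example has been saved locally as llama4-scout-lora.yaml (a hypothetical filename):

# Minimal sketch: launch the example config via the axolotl CLI entry point.
# "llama4-scout-lora.yaml" is a placeholder path for the YAML above.
import subprocess

subprocess.run(
    ["accelerate", "launch", "-m", "axolotl.cli.train", "llama4-scout-lora.yaml"],
    check=True,
)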

View File

@@ -6,12 +6,12 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
-liger-kernel==0.5.5
+liger-kernel==0.5.6
# END section
packaging==23.2
-peft==0.15.0
+peft==0.15.1
transformers==4.51.0
tokenizers>=0.21.1
accelerate==1.6.0

View File

@@ -562,6 +562,19 @@ class AxolotlTrainer(
        return res

    def additional_accelerator_args(
        self, fp8=None, **kwargs
    ):  # pylint: disable=unused-argument
        ret_kwargs = {}

        if fp8:
            from accelerate.utils import AORecipeKwargs

            ret_kwargs["mixed_precision"] = "fp8"
            ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
            os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"

        return ret_kwargs

    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.
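
Outside of the Trainer, the same fp8 wiring can be reproduced directly against accelerate. A rough sketch, assuming an accelerate build that ships AORecipeKwargs (as the accelerate==1.6.0 pin above does) and has torchao installed:

# Standalone sketch of what the extra kwargs amount to: an Accelerator configured
# for fp8 mixed precision via torchao recipe kwargs.
import os

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs

os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])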

View File

@@ -173,5 +173,17 @@ class LigerPlugin(BasePlugin):
            raise NotImplementedError(
                "Fused linear cross entropy is not yet supported for Gemma3."
            )
        elif cfg.model_config_type == "llama4":
            from axolotl.integrations.liger.models.llama4 import (
                apply_liger_kernel_to_llama4,
            )

            apply_liger_kernel_to_llama4(
                cross_entropy=cfg.liger_cross_entropy,
                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
                glu_activation=cfg.liger_glu_activation,
                rms_norm=cfg.liger_rms_norm,
                layer_norm=cfg.liger_layer_norm,
            )
        elif cfg.model_config_type in ["deepseek_v3"]:
            raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
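
The plugin passes the liger_* config flags straight through to the new llama4 patch. A hedged sketch of calling it by hand, before the model is instantiated (flag values here are illustrative, not defaults):

# Sketch: applying the new Llama4 Liger patches manually, mirroring what the
# Liger integration does when cfg.model_config_type == "llama4".
from axolotl.integrations.liger.models.llama4 import apply_liger_kernel_to_llama4

apply_liger_kernel_to_llama4(
    fused_linear_cross_entropy=True,  # patches Llama4ForCausalLM.forward with lce_forward
    glu_activation=True,  # swaps Llama4TextMLP for LigerSwiGLUMLP
    rms_norm=True,  # swaps Llama4TextRMSNorm for LigerRMSNorm
)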

View File

@@ -0,0 +1,171 @@
"""
Liger FLCE for llama4
"""
import sys
from typing import List, Optional, Tuple, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[
Union["Cache", List[torch.FloatTensor]] # noqa: F821
] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
raise Exception( # pylint: disable=broad-exception-raised
"Liger Kernel does not support pretraining_tp!!"
)
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**loss_kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**loss_kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_llama4(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = False,
    rms_norm: bool = False,
    glu_activation: bool = False,
    layer_norm: bool = False,
    **kwargs,  # pylint: disable=unused-argument
) -> None:
    """
    Apply Liger kernels to replace the original implementations in HuggingFace Llama 4 models.

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is False.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
    """
    import transformers.models.llama4.modeling_llama4  # noqa: F401 # pylint: disable=unused-import
    from liger_kernel.transformers.functional import liger_cross_entropy
    from liger_kernel.transformers.layer_norm import LigerLayerNorm
    from liger_kernel.transformers.rms_norm import LigerRMSNorm
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    modeling_llama4 = sys.modules["transformers.models.llama4.modeling_llama4"]

    if rms_norm:
        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
    if glu_activation:
        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
    if layer_norm:
        modeling_llama4.nn.LayerNorm = LigerLayerNorm
    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy
    if fused_linear_cross_entropy:
        modeling_llama4.Llama4ForCausalLM.forward = lce_forward
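
Since these are module-level attribute swaps, they only affect Llama 4 modules constructed after the call. A small sanity-check sketch, assuming a transformers build that ships transformers.models.llama4:

# Apply the patch, then confirm the module-level symbols were actually swapped.
import sys

from liger_kernel.transformers.rms_norm import LigerRMSNorm

from axolotl.integrations.liger.models.llama4 import (
    apply_liger_kernel_to_llama4,
    lce_forward,
)

apply_liger_kernel_to_llama4(rms_norm=True, fused_linear_cross_entropy=True)

modeling_llama4 = sys.modules["transformers.models.llama4.modeling_llama4"]
assert modeling_llama4.Llama4TextRMSNorm is LigerRMSNorm
assert modeling_llama4.Llama4ForCausalLM.forward is lce_forward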

View File

@@ -162,7 +162,6 @@ def patch_flex_make_mask():
    for n in tuple(sys.modules):
        if ".modeling_" in n and "llama4" not in n:
            if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
-               print(n)
                sys.modules[n].make_flex_block_causal_mask = (
                    patched_make_flex_block_causal_mask
                )

View File

@@ -0,0 +1,80 @@
"""
allow adding additional kwargs to Accelerator init
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger(__name__)
ORIGINAL_TRAINER_CODE = """
# create accelerator object
self.accelerator = Accelerator(**args)
"""
PATCHED_TRAINER_CODE = """
if hasattr(self, "additional_accelerator_args"):
additional_args = self.additional_accelerator_args(fp8=True, **args)
if additional_args:
args.update(additional_args)
# create accelerator object
self.accelerator = Accelerator(**args)
"""
def get_create_accelerate_code() -> str:
training_loop = inspect.getsource(Trainer.create_accelerator_and_postprocess)
return training_loop
def check_create_accelerate_code_is_patchable() -> bool:
create_code = get_create_accelerate_code()
create_code, _ = detab_code(create_code)
return ORIGINAL_TRAINER_CODE in create_code
def patch_create_accelerate_code_for_fp8():
"""
monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
"""
try:
create_code = get_create_accelerate_code()
except OSError:
return
Trainer._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access
create_code
)
create_code, _ = detab_code(create_code)
if ORIGINAL_TRAINER_CODE not in create_code:
return
create_code = create_code.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
create_code = create_code.replace(
"def create_accelerator_and_postprocess(",
"def fixed_create_accelerator_and_postprocess(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in create_code:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(create_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching create_accelerator_and_postprocess to allow for overrides")
Trainer.create_accelerator_and_postprocess = fixed_create_accelerator_and_postprocess # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
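
The patch follows the source-rewrite-and-exec pattern used by axolotl's other Trainer monkeypatches: fetch the method source, text-replace a known block, exec the rewritten function, and rebind it on the class. A self-contained toy illustration of that pattern (the class and names below are made up for the example):

# Toy illustration of the source-rewrite monkeypatch pattern; Greeter is hypothetical.
import inspect
import textwrap


class Greeter:
    def greet(self):
        return "hello"


source = textwrap.dedent(inspect.getsource(Greeter.greet))
# swap a known snippet, then rename the function so the rebind below is explicit
source = source.replace('return "hello"', 'return "hello, patched"')
source = source.replace("def greet(", "def patched_greet(", 1)

exec(source, globals())  # defines patched_greet  # nosec B102
Greeter.greet = patched_greet  # rebind, as done above for create_accelerator_and_postprocess

assert Greeter().greet() == "hello, patched"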

View File

@@ -557,6 +557,14 @@ class ModelLoader:
        plugin_manager = PluginManager.get_instance()
        plugin_manager.pre_model_load(self.cfg)

        # monkey patch to allow additional Accelerator init kwargs
        if self.cfg.fp8:
            from axolotl.monkeypatch.trainer_accelerator_args import (
                patch_create_accelerate_code_for_fp8,
            )

            patch_create_accelerate_code_for_fp8()

        if self.cfg.adapter:
            from axolotl.monkeypatch.transformers_fa_utils import (
                patch_fa_peft_integration,
@@ -988,10 +996,11 @@ class ModelLoader:
            )
            skip_move_to_device = True
        elif (
-           self.model_config.model_type == "llama"
+           self.model_config.model_type in ["llama", "llama4"]
            and not self.cfg.trust_remote_code
            and not self.cfg.gptq
        ):
            # TODO do we need to open this up for all models?
            if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
                skip_move_to_device = True
                if "device_map" in self.model_kwargs:

View File

@@ -169,6 +169,7 @@ class AxolotlInputConfig(
    bf16: Literal["auto"] | bool | None = "auto"
    fp16: bool | None = None
    fp8: bool | None = None
    bfloat16: bool | None = None  # for non-AMP cases
    float16: bool | None = None  # for non-AMP cases
    tf32: bool | None = None
@@ -464,9 +465,10 @@ class AxolotlInputConfig(
data.get("sample_packing")
and not data.get("flash_attention")
and not data.get("sdp_attention")
and not data.get("flex_attention")
):
LOG.warning(
"sample_packing without flash_attention or sdp_attention does not handle cross-attention."
"sample_packing without flash, sdp or flex attention does not handle cross sample decontamination."
)
return data

View File

@@ -26,6 +26,7 @@ class ChatTemplate(str, Enum):
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama4 = "llama4" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name

View File

@@ -582,7 +582,9 @@ def prepare_optim_env(cfg):
    setup_torch_compile_env(cfg)

-   if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
+   if cfg.fp8:
+       os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
+   elif (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
        os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
    elif cfg.fp16:
        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
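
The net effect is a simple precedence order for ACCELERATE_MIXED_PRECISION: fp8 wins over bf16, which wins over fp16. A compressed standalone sketch of that selection (a hypothetical helper, not the axolotl function itself):

# Hypothetical helper mirroring the branch order above: fp8 > bf16 > fp16.
import os


def select_mixed_precision(fp8: bool, bf16: bool, fp16: bool) -> None:
    if fp8:
        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
    elif bf16:
        os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
    elif fp16:
        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"


select_mixed_precision(fp8=True, bf16=True, fp16=False)
assert os.environ["ACCELERATE_MIXED_PRECISION"] == "fp8"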

View File

@@ -500,9 +500,7 @@ class TestMultiGPULlama:
            ],
            "fsdp_config": {
                "fsdp_version": 2,
-               "fsdp_forward_prefetch": True,
-               "fsdp_sync_module_states": True,
-               "fsdp_use_orig_params": True,
+               # "fsdp_forward_prefetch": True,  # not yet implemented in accelerate
                "fsdp_offload_params": False,
                "fsdp_cpu_ram_efficient_loading": False,
                "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",