diff --git a/examples/llama4/scout-lora.yaml b/examples/llama4/scout-lora.yaml
new file mode 100644
index 000000000..26534b560
--- /dev/null
+++ b/examples/llama4/scout-lora.yaml
@@ -0,0 +1,75 @@
+base_model: meta-llama/Llama-4-Scout-17B-16E
+model_type: Llama4ForConditionalGeneration
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+strict: false
+
+# torch_compile: true
+
+adapter: lora
+lora_r: 32
+lora_alpha: 64
+lora_target_modules:
+  - self_attn.q_proj
+  - self_attn.k_proj
+  - self_attn.v_proj
+  - self_attn.o_proj
+lora_modules_to_save:
+  - lm_head
+  - embed_tokens
+
+chat_template: llama4
+datasets:
+  - path: mlabonne/FineTome-100k
+    type: chat_template
+    split: train[:20%]
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch_8bit
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+bf16: true
+tf32: true
+
+# gradient_checkpointing: true
+# gradient_checkpointing_kwargs:
+#   use_reentrant: false
+logging_steps: 1
+flash_attention: true
+
+warmup_steps: 100
+evals_per_epoch: 2
+saves_per_epoch: 1
+weight_decay: 0.0
+fsdp:
+  - auto_wrap
+  - full_shard
+fsdp_config:
+  fsdp_version: 2
+  fsdp_offload_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_reshard_after_forward: true
+  fsdp_activation_checkpointing: true
+special_tokens:
+  pad_token: <|finetune_right_pad_id|>
+  eos_token: <|eot|>
diff --git a/requirements.txt b/requirements.txt
index 6dac24f27..f2b2df5fb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@
 triton>=3.0.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
 autoawq==0.2.7.post3
-liger-kernel==0.5.5
+liger-kernel==0.5.6
 # END section
 packaging==23.2
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 9fed78eb7..dd50f8ce7 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -562,6 +562,16 @@ class AxolotlTrainer(
 
         return res
 
+    def override_accelerator_args(self, **kwargs):  # pylint: disable=unused-argument
+        ret_kwargs = {}
+        if os.environ.get("ACCELERATE_MIXED_PRECISION") == "fp8":
+            from accelerate.utils import AORecipeKwargs
+
+            ret_kwargs["mixed_precision"] = "fp8"
+            ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
+
+        return ret_kwargs
+
     def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
         """
         Log `logs` on the various objects watching training, including stored metrics.
diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py
index 82a46d9cf..8d737175e 100644
--- a/src/axolotl/integrations/liger/__init__.py
+++ b/src/axolotl/integrations/liger/__init__.py
@@ -173,5 +173,17 @@ class LigerPlugin(BasePlugin):
             raise NotImplementedError(
                 "Fused linear cross entropy is not yet supported for Gemma3."
             )
+        elif cfg.model_config_type == "llama4":
+            from axolotl.integrations.liger.models.llama4 import (
+                apply_liger_kernel_to_llama4,
+            )
+
+            apply_liger_kernel_to_llama4(
+                cross_entropy=cfg.liger_cross_entropy,
+                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+                glu_activation=cfg.liger_glu_activation,
+                rms_norm=cfg.liger_rms_norm,
+                layer_norm=cfg.liger_layer_norm,
+            )
         elif cfg.model_config_type in ["deepseek_v3"]:
             raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
diff --git a/src/axolotl/integrations/liger/models/llama4.py b/src/axolotl/integrations/liger/models/llama4.py
new file mode 100644
index 000000000..ee7f226cd
--- /dev/null
+++ b/src/axolotl/integrations/liger/models/llama4.py
@@ -0,0 +1,171 @@
+"""
+Liger FLCE for llama4
+"""
+
+import sys
+from typing import List, Optional, Tuple, Union
+
+import torch
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[
+        Union["Cache", List[torch.FloatTensor]]  # noqa: F821
+    ] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+    """
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
+        raise Exception(  # pylint: disable=broad-exception-raised
+            "Liger Kernel does not support pretraining_tp!!"
+        )
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
+
+    else:  # if in inference mode, materialize logits
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+def apply_liger_kernel_to_llama4(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = False,
+    rms_norm: bool = False,
+    glu_activation: bool = False,
+    layer_norm: bool = False,
+    **kwargs,  # pylint: disable=unused-argument
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementations in HuggingFace Llama 4 models.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
+        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
+    """
+
+    import transformers.models.llama4.modeling_llama4  # noqa: F401 # pylint: disable=unused-import
+    from liger_kernel.transformers.functional import liger_cross_entropy
+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
+
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    modeling_llama4 = sys.modules["transformers.models.llama4.modeling_llama4"]
+
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if glu_activation:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+    if layer_norm:
+        modeling_llama4.nn.LayerNorm = LigerLayerNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = lce_forward
diff --git a/src/axolotl/monkeypatch/trainer_accelerator_args.py b/src/axolotl/monkeypatch/trainer_accelerator_args.py
new file mode 100644
index 000000000..8c68f2c8a
--- /dev/null
+++ b/src/axolotl/monkeypatch/trainer_accelerator_args.py
@@ -0,0 +1,80 @@
+"""
+allow adding additional kwargs to Accelerator init
+"""
+
+import inspect
+import logging
+
+from transformers import Trainer
+
+from axolotl.monkeypatch.utils import detab_code
+
+LOG = logging.getLogger(__name__)
+
+ORIGINAL_TRAINER_CODE = """
+    # create accelerator object
+    self.accelerator = Accelerator(**args)
+"""
+
+PATCHED_TRAINER_CODE = """
+    if hasattr(self, "override_accelerator_args"):
+        additional_args = self.override_accelerator_args(**args)
+        if additional_args:
+            args.update(additional_args)
+
+    # create accelerator object
+    self.accelerator = Accelerator(**args)
+"""
+
+
+def get_create_accelerate_code() -> str:
+    training_loop = inspect.getsource(Trainer.create_accelerator_and_postprocess)
+    return training_loop
+
+
+def check_create_accelerate_code_is_patchable() -> bool:
+    create_code = get_create_accelerate_code()
+    create_code, _ = detab_code(create_code)
+    return ORIGINAL_TRAINER_CODE in create_code
+
+
+def patch_create_accelerate_code_for_fp8():
+    """
+    monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
+    """
+
+    try:
+        create_code = get_create_accelerate_code()
+    except OSError:
+        return
+    Trainer._original_create_accelerator_and_postprocess = (  # pylint: disable=protected-access
+        create_code
+    )
+    create_code, _ = detab_code(create_code)
+    if ORIGINAL_TRAINER_CODE not in create_code:
+        return
+
+    create_code = create_code.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
+    create_code = create_code.replace(
+        "def create_accelerator_and_postprocess(",
+        "def fixed_create_accelerator_and_postprocess(",
+        1,
+    )
+
+    # load imports necessary
+    import transformers.trainer
+
+    items_to_import = []
+    for item in dir(transformers.trainer):
+        if item in create_code:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from transformers.trainer import ("
+        + ", ".join(x for x in items_to_import)
+        + ")",
+        globals(),
+    )
+    exec(create_code, globals())  # pylint: disable=exec-used  # nosec B102
+    LOG.info("patching create_accelerator_and_postprocess to allow for overrides")
+    Trainer.create_accelerator_and_postprocess = fixed_create_accelerator_and_postprocess  # pylint: disable=protected-access  # pylint: disable=undefined-variable  # noqa: F821
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0e1329b97..79f6c5a9b 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -557,6 +557,13 @@ class ModelLoader:
         plugin_manager = PluginManager.get_instance()
         plugin_manager.pre_model_load(self.cfg)
 
+        # monkey patch to allow additional Accelerator init kwargs
+        from axolotl.monkeypatch.trainer_accelerator_args import (
+            patch_create_accelerate_code_for_fp8,
+        )
+
+        patch_create_accelerate_code_for_fp8()
+
         if self.cfg.adapter:
             from axolotl.monkeypatch.transformers_fa_utils import (
                 patch_fa_peft_integration,
@@ -988,10 +995,11 @@ class ModelLoader:
             )
             skip_move_to_device = True
         elif (
-            self.model_config.model_type == "llama"
+            self.model_config.model_type in ["llama", "llama4"]
             and not self.cfg.trust_remote_code
             and not self.cfg.gptq
         ):
+            # TODO do we need to open this up for all models?
             if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
                 skip_move_to_device = True
                 if "device_map" in self.model_kwargs:
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 3ceae4273..4995962df 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -464,9 +464,10 @@ class AxolotlInputConfig(
             data.get("sample_packing")
             and not data.get("flash_attention")
             and not data.get("sdp_attention")
+            and not data.get("flex_attention")
         ):
             LOG.warning(
-                "sample_packing without flash_attention or sdp_attention does not handle cross-attention."
+                "sample_packing without flash, sdp, or flex attention does not handle cross-sample decontamination."
             )
         return data
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index c5c9e5599..964b17086 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -582,7 +582,9 @@ def prepare_optim_env(cfg):
 
     setup_torch_compile_env(cfg)
 
-    if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
+    if cfg.fp8:
+        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
+    elif (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
         os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
     elif cfg.fp16:
         os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index 758491d63..f44c775c8 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -500,9 +500,7 @@ class TestMultiGPULlama:
                 ],
                 "fsdp_config": {
                     "fsdp_version": 2,
-                    "fsdp_forward_prefetch": True,
-                    "fsdp_sync_module_states": True,
-                    "fsdp_use_orig_params": True,
+                    # "fsdp_forward_prefetch": True,  # not yet implemented in accelerate
                     "fsdp_offload_params": False,
                     "fsdp_cpu_ram_efficient_loading": False,
                     "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
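
Usage note: a minimal config sketch of how the pieces added above might be combined. The liger_* keys mirror the arguments passed to apply_liger_kernel_to_llama4, and fp8 is the new flag read in prepare_optim_env, which sets ACCELERATE_MIXED_PRECISION=fp8 and triggers override_accelerator_args; the plugin path and key names are assumed from the existing LigerPlugin integration and should be checked against the plugin docs rather than this diff.

    # sketch only; assumes the existing LigerPlugin config keys
    base_model: meta-llama/Llama-4-Scout-17B-16E   # as in examples/llama4/scout-lora.yaml
    plugins:
      - axolotl.integrations.liger.LigerPlugin     # assumed plugin path
    liger_rms_norm: true
    liger_glu_activation: true
    liger_fused_linear_cross_entropy: true
    # new flag handled in prepare_optim_env: enables fp8 mixed precision via AORecipeKwargs
    fp8: true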