From 8bbad21bfdbd312aede4d7277adc5a7448d7e941 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 7 Apr 2025 10:49:15 -0400 Subject: [PATCH] llama4 support (#2493) * llama4 support * add xet support [skip ci] * be flexible on transformers version and skip test on version * don't use deepspeed for the fix_untrained_tokens test * reordering to trigger torch 2.6.0 tests first * slightly smaller train set * use 4.51.0 for now * remove stray print, add llama4 chat template to schema, bump peft to 0.15.1 * patches to make llama4 performant * add preliminary fp8 support --- .github/workflows/multi-gpu-e2e.yml | 14 +- .github/workflows/tests.yml | 4 +- examples/llama4/scout-lora.yaml | 75 ++++++++ requirements.txt | 5 +- src/axolotl/core/trainers/base.py | 13 ++ src/axolotl/integrations/liger/__init__.py | 12 ++ .../integrations/liger/models/llama4.py | 171 ++++++++++++++++++ .../monkeypatch/attention/flex_attn.py | 1 - src/axolotl/monkeypatch/multipack.py | 1 + .../monkeypatch/trainer_accelerator_args.py | 80 ++++++++ src/axolotl/utils/chat_templates.py | 1 + src/axolotl/utils/models.py | 11 +- src/axolotl/utils/schemas/config.py | 4 +- src/axolotl/utils/schemas/enums.py | 1 + src/axolotl/utils/trainer.py | 4 +- tests/e2e/multigpu/test_llama.py | 40 ++-- tests/e2e/multigpu/test_ray.py | 6 +- 17 files changed, 409 insertions(+), 34 deletions(-) create mode 100644 examples/llama4/scout-lora.yaml create mode 100644 src/axolotl/integrations/liger/models/llama4.py create mode 100644 src/axolotl/monkeypatch/trainer_accelerator_args.py diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index dfa315618..f89de494d 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -24,6 +24,13 @@ jobs: fail-fast: false matrix: include: + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.6.0 + axolotl_extras: vllm + num_gpus: 2 + nightly_build: "true" - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" @@ -38,13 +45,6 @@ jobs: axolotl_extras: vllm num_gpus: 2 nightly_build: "true" - - cuda: 124 - cuda_version: 12.4.1 - python_version: "3.11" - pytorch: 2.6.0 - axolotl_extras: vllm - num_gpus: 2 - nightly_build: "true" runs-on: [self-hosted, modal] timeout-minutes: 120 steps: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 434803d2c..9eb85a5b1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -211,7 +211,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.5.1 + pytorch: 2.6.0 num_gpus: 1 axolotl_extras: vllm steps: @@ -258,7 +258,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.6.0 + pytorch: 2.5.1 num_gpus: 1 axolotl_extras: vllm steps: diff --git a/examples/llama4/scout-lora.yaml b/examples/llama4/scout-lora.yaml new file mode 100644 index 000000000..26534b560 --- /dev/null +++ b/examples/llama4/scout-lora.yaml @@ -0,0 +1,75 @@ +base_model: meta-llama/Llama-4-Scout-17B-16E +model_type: Llama4ForConditionalGeneration + # Automatically upload checkpoint and final model to HF + # hub_model_id: username/custom_model_name + +strict: false + + # torch_compile: true + +adapter: lora +lora_r: 32 +lora_alpha: 64 +lora_target_modules: + - self_attn.q_proj + - self_attn.k_proj + - self_attn.v_proj + - self_attn.o_proj +lora_modules_to_save: + - lm_head + - embed_tokens + +chat_template: llama4 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:20%] + field_messages: conversations + message_property_mappings: + role: from + content: value + +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_8bit +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: true +tf32: true + +# gradient_checkpointing: true +# gradient_checkpointing_kwargs: +# use_reentrant: false +logging_steps: 1 +flash_attention: true + +warmup_steps: 100 +evals_per_epoch: 2 +saves_per_epoch: 1 +weight_decay: 0.0 +fsdp: + - auto_wrap + - full_shard +fsdp_config: + fsdp_version: 2 + fsdp_offload_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD + fsdp_reshard_after_forward: true + fsdp_activation_checkpointing: true +special_tokens: + pad_token: <|finetune_right_pad_id|> + eos_token: <|eot|> diff --git a/requirements.txt b/requirements.txt index d82489203..f2b2df5fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,18 +6,19 @@ triton>=3.0.0 mamba-ssm==1.2.0.post1 xformers>=0.0.23.post1 autoawq==0.2.7.post3 -liger-kernel==0.5.5 +liger-kernel==0.5.6 # END section packaging==23.2 -peft==0.15.0 +peft==0.15.1 transformers==4.51.0 tokenizers>=0.21.1 accelerate==1.6.0 datasets==3.5.0 deepspeed>=0.15.4 trl==0.16.1 +hf_xet==1.0.0 optimum==1.16.2 hf_transfer diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 9fed78eb7..bc3a200d4 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -562,6 +562,19 @@ class AxolotlTrainer( return res + def additional_accelerator_args( + self, fp8=None, **kwargs + ): # pylint: disable=unused-argument + ret_kwargs = {} + if fp8: + from accelerate.utils import AORecipeKwargs + + ret_kwargs["mixed_precision"] = "fp8" + ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()] + os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8" + + return ret_kwargs + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: """ Log `logs` on the various objects watching training, including stored metrics. diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 82a46d9cf..8d737175e 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -173,5 +173,17 @@ class LigerPlugin(BasePlugin): raise NotImplementedError( "Fused linear cross entropy is not yet supported for Gemma3." ) + elif cfg.model_config_type == "llama4": + from axolotl.integrations.liger.models.llama4 import ( + apply_liger_kernel_to_llama4, + ) + + apply_liger_kernel_to_llama4( + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + layer_norm=cfg.liger_layer_norm, + ) elif cfg.model_config_type in ["deepseek_v3"]: raise ValueError(f"Unsupported model config type: {cfg.model_config_type}") diff --git a/src/axolotl/integrations/liger/models/llama4.py b/src/axolotl/integrations/liger/models/llama4.py new file mode 100644 index 000000000..da35b114c --- /dev/null +++ b/src/axolotl/integrations/liger/models/llama4.py @@ -0,0 +1,171 @@ +""" +Liger FLCE for llama4 +""" + +import sys +from typing import List, Optional, Tuple, Union + +import torch +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from transformers.modeling_outputs import CausalLMOutputWithPast + + +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[ + Union["Cache", List[torch.FloatTensor]] # noqa: F821 + ] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + """ + + # pylint: disable=duplicate-code + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1: + raise Exception( # pylint: disable=broad-exception-raised + "Liger Kernel does not support pretraining_tp!!" + ) + + logits = None + loss = None + # if in training mode, don't materialize logits + if self.training and (labels is not None): + loss = LigerForCausalLMLoss( + hidden_states=hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + hidden_size=self.config.hidden_size, + **loss_kwargs, + ) + + else: # if in inference mode materialize logits + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **loss_kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def apply_liger_kernel_to_llama4( + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = False, + rms_norm: bool = False, + glu_activation: bool = False, + layer_norm: bool = False, + **kwargs, # pylint: disable=unused-argument +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3) + + Args: + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is False. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be False. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False. + glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False. + layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False. + """ + + import transformers.models.llama4.modeling_llama4 # noqa: F401 # pylint: disable=unused-import + from liger_kernel.transformers.functional import liger_cross_entropy + from liger_kernel.transformers.layer_norm import LigerLayerNorm + from liger_kernel.transformers.rms_norm import LigerRMSNorm + from liger_kernel.transformers.swiglu import LigerSwiGLUMLP + + assert not ( + cross_entropy and fused_linear_cross_entropy + ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + + modeling_llama4 = sys.modules["transformers.models.llama4.modeling_llama4"] + + if rms_norm: + modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm + if glu_activation: + modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP + if layer_norm: + modeling_llama4.nn.LayerNorm = LigerLayerNorm + + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + + if fused_linear_cross_entropy: + modeling_llama4.Llama4ForCausalLM.forward = lce_forward diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py index 2ca5b09a6..d65ee706f 100644 --- a/src/axolotl/monkeypatch/attention/flex_attn.py +++ b/src/axolotl/monkeypatch/attention/flex_attn.py @@ -162,7 +162,6 @@ def patch_flex_make_mask(): for n in tuple(sys.modules): if ".modeling_" in n and "llama4" not in n: if hasattr(sys.modules[n], "make_flex_block_causal_mask"): - print(n) sys.modules[n].make_flex_block_causal_mask = ( patched_make_flex_block_causal_mask ) diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 015743329..2b02699bd 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -13,6 +13,7 @@ from axolotl.monkeypatch.utils import get_unpad_data SUPPORTED_MULTIPACK_MODEL_TYPES = [ "mllama_text_model", "llama", + "llama4", "mistral", "mixtral", "qwen2", diff --git a/src/axolotl/monkeypatch/trainer_accelerator_args.py b/src/axolotl/monkeypatch/trainer_accelerator_args.py new file mode 100644 index 000000000..d87812c9f --- /dev/null +++ b/src/axolotl/monkeypatch/trainer_accelerator_args.py @@ -0,0 +1,80 @@ +""" +allow adding additional kwargs to Accelerator init +""" + +import inspect +import logging + +from transformers import Trainer + +from axolotl.monkeypatch.utils import detab_code + +LOG = logging.getLogger(__name__) + +ORIGINAL_TRAINER_CODE = """ + # create accelerator object + self.accelerator = Accelerator(**args) +""" + +PATCHED_TRAINER_CODE = """ + if hasattr(self, "additional_accelerator_args"): + additional_args = self.additional_accelerator_args(fp8=True, **args) + if additional_args: + args.update(additional_args) + + # create accelerator object + self.accelerator = Accelerator(**args) +""" + + +def get_create_accelerate_code() -> str: + training_loop = inspect.getsource(Trainer.create_accelerator_and_postprocess) + return training_loop + + +def check_create_accelerate_code_is_patchable() -> bool: + create_code = get_create_accelerate_code() + create_code, _ = detab_code(create_code) + return ORIGINAL_TRAINER_CODE in create_code + + +def patch_create_accelerate_code_for_fp8(): + """ + monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs + """ + + try: + create_code = get_create_accelerate_code() + except OSError: + return + Trainer._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access + create_code + ) + create_code, _ = detab_code(create_code) + if ORIGINAL_TRAINER_CODE not in create_code: + return + + create_code = create_code.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE) + create_code = create_code.replace( + "def create_accelerator_and_postprocess(", + "def fixed_create_accelerator_and_postprocess(", + 1, + ) + + # load imports necessary + import transformers.trainer + + items_to_import = [] + for item in dir(transformers.trainer): + if item in create_code: + items_to_import.append(item) + + exec( # pylint: disable=exec-used # nosec B102 + "from transformers.trainer import (" + + ", ".join(x for x in items_to_import) + + ")", + globals(), + ) + exec(create_code, globals()) # pylint: disable=exec-used # nosec B102 + LOG.info("patching create_accelerator_and_postprocess to allow for overrides") + Trainer.create_accelerator_and_postprocess = fixed_create_accelerator_and_postprocess # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index ba0516eb9..234b42d8d 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -26,6 +26,7 @@ _CHAT_TEMPLATES = { "gemma3": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'model\n'}}\n{%- endif -%}\n", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "llama4": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %} \n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- else %}\n {#- FIXME: The processor requires an array, always. #}\n {%- set system_message = messages[0]['content'][0]['text']|trim %}\n {%- endif %}\n {%- set messages = messages[1:] %}\n {%- set user_supplied_system_message = true %}\n{%- else %}\n {%- set system_message = \"\" %}\n {%- set user_supplied_system_message = false %}\n{%- endif %}\n\n{#- System message if the user supplied one #}\n{%- if user_supplied_system_message %}\n {{- \"<|header_start|>system<|header_end|>\\n\\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n {%- endif %}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|header_start|>user<|header_end|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|header_start|>' + message['role'] + '<|header_end|>\\n\\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}\n {{- '<|header_start|>assistant<|header_end|>\\n\\n' -}}\n {{- '<|python_start|>' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|python_end|>' }}\n {%- for tool_call in message.tool_calls %}\n {{- '{\"name\": \"' + tool_call.function.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.function.arguments | tojson }}\n {{- \"}\" }}\n {%- endfor %}\n {{- \"<|eot|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|header_start|>ipython<|header_end|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|header_start|>assistant<|header_end|>\\n\\n' }}\n{%- endif %}\n", "llama3_2_vision": '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n', "llava": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 0e1329b97..367e69850 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -557,6 +557,14 @@ class ModelLoader: plugin_manager = PluginManager.get_instance() plugin_manager.pre_model_load(self.cfg) + # monkey patch to allow additional Accelerator init kwargs + if self.cfg.fp8: + from axolotl.monkeypatch.trainer_accelerator_args import ( + patch_create_accelerate_code_for_fp8, + ) + + patch_create_accelerate_code_for_fp8() + if self.cfg.adapter: from axolotl.monkeypatch.transformers_fa_utils import ( patch_fa_peft_integration, @@ -988,10 +996,11 @@ class ModelLoader: ) skip_move_to_device = True elif ( - self.model_config.model_type == "llama" + self.model_config.model_type in ["llama", "llama4"] and not self.cfg.trust_remote_code and not self.cfg.gptq ): + # TODO do we need to open this up for all models? if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: skip_move_to_device = True if "device_map" in self.model_kwargs: diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 3ceae4273..0f9a3a1f9 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -169,6 +169,7 @@ class AxolotlInputConfig( bf16: Literal["auto"] | bool | None = "auto" fp16: bool | None = None + fp8: bool | None = None bfloat16: bool | None = None # for non-AMP cases float16: bool | None = None # for non-AMP cases tf32: bool | None = None @@ -464,9 +465,10 @@ class AxolotlInputConfig( data.get("sample_packing") and not data.get("flash_attention") and not data.get("sdp_attention") + and not data.get("flex_attention") ): LOG.warning( - "sample_packing without flash_attention or sdp_attention does not handle cross-attention." + "sample_packing without flash, sdp or flex attention does not handle cross sample decontamination." ) return data diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index ad735afe0..16b91ec41 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -26,6 +26,7 @@ class ChatTemplate(str, Enum): gemma = "gemma" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name + llama4 = "llama4" # pylint: disable=invalid-name llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name phi_3 = "phi_3" # pylint: disable=invalid-name phi_35 = "phi_35" # pylint: disable=invalid-name diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index c5c9e5599..964b17086 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -582,7 +582,9 @@ def prepare_optim_env(cfg): setup_torch_compile_env(cfg) - if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True: + if cfg.fp8: + os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8" + elif (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True: os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16" elif cfg.fp16: os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index d71fa25c8..f44c775c8 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -7,9 +7,11 @@ import os from pathlib import Path import pytest +import transformers import yaml from accelerate.test_utils import execute_subprocess_async from huggingface_hub import snapshot_download +from packaging import version from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault @@ -28,6 +30,10 @@ def download_model(): snapshot_download("HuggingFaceTB/SmolLM2-135M") +def transformers_version_eq(required_version): + return version.parse(transformers.__version__) == version.parse(required_version) + + class TestMultiGPULlama: """ Test case for Llama models using LoRA @@ -56,7 +62,7 @@ class TestMultiGPULlama: ], "num_epochs": 1, "max_steps": 2, - "micro_batch_size": 4, + "micro_batch_size": 1, "gradient_accumulation_steps": 4, # "gradient_checkpointing": True, "output_dir": temp_dir, @@ -108,7 +114,7 @@ class TestMultiGPULlama: "lora_alpha": 16, "lora_dropout": 0.05, "lora_target_linear": True, - "val_set_size": 0.01, + "val_set_size": 0.05, "special_tokens": { "pad_token": "<|endoftext|>", }, @@ -116,6 +122,7 @@ class TestMultiGPULlama: { "path": "tatsu-lab/alpaca", "type": "alpaca", + "split": "train[:20%]", }, ], "num_epochs": 1, @@ -193,7 +200,7 @@ class TestMultiGPULlama: ], "num_epochs": 1, "max_steps": 2, - "micro_batch_size": 4, + "micro_batch_size": 2, "gradient_accumulation_steps": 4, # "gradient_checkpointing": True, "output_dir": temp_dir, @@ -390,7 +397,7 @@ class TestMultiGPULlama: "base_model": "HuggingFaceTB/SmolLM2-135M", "sample_packing": True, "pad_to_sequence_len": True, - "sequence_len": 2048, + "sequence_len": 1024, "val_set_size": 0.01, "special_tokens": { "pad_token": "<|endoftext|>", @@ -403,7 +410,7 @@ class TestMultiGPULlama: ], "num_epochs": 1, "max_steps": 2, - "micro_batch_size": 4, + "micro_batch_size": 2, "gradient_accumulation_steps": 2, # "gradient_checkpointing": True, "output_dir": temp_dir, @@ -493,9 +500,7 @@ class TestMultiGPULlama: ], "fsdp_config": { "fsdp_version": 2, - "fsdp_forward_prefetch": True, - "fsdp_sync_module_states": True, - "fsdp_use_orig_params": True, + # "fsdp_forward_prefetch": True, # not yet implemented in accelerate "fsdp_offload_params": False, "fsdp_cpu_ram_efficient_loading": False, "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer", @@ -551,7 +556,7 @@ class TestMultiGPULlama: "sample_packing": True, "eval_sample_packing": False, "pad_to_sequence_len": True, - "sequence_len": 2048, + "sequence_len": 1024, "val_set_size": 0.01, "special_tokens": { "pad_token": "<|endoftext|>", @@ -565,7 +570,7 @@ class TestMultiGPULlama: ], "num_epochs": 1, "max_steps": 2, - "micro_batch_size": 4, + "micro_batch_size": 2, "gradient_accumulation_steps": 2, # "gradient_checkpointing": True, "output_dir": temp_dir, @@ -612,8 +617,11 @@ class TestMultiGPULlama: temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" ) - @pytest.mark.skip( - reason="ds-zero3 broken in main until transformers#37281 resolved" + # TODO: remove skip once deepspeed regression is fixed + # see https://github.com/huggingface/transformers/pull/37324 + @pytest.mark.skipif( + transformers_version_eq("4.51.0"), + reason="zero3 is not supported with transformers==4.51.0", ) @pytest.mark.parametrize( "gradient_accumulation_steps", @@ -651,7 +659,7 @@ class TestMultiGPULlama: "base_model": "HuggingFaceTB/SmolLM2-135M", "sample_packing": True, "pad_to_sequence_len": True, - "sequence_len": 2048, + "sequence_len": 1024, "val_set_size": 0.01, "special_tokens": { "pad_token": "<|endoftext|>", @@ -724,7 +732,7 @@ class TestMultiGPULlama: "base_model": "HuggingFaceTB/SmolLM2-135M", "sample_packing": True, "pad_to_sequence_len": True, - "sequence_len": 2048, + "sequence_len": 1024, "val_set_size": 0.01, "special_tokens": { "pad_token": "<|endoftext|>", @@ -797,7 +805,7 @@ class TestMultiGPULlama: "base_model": "HuggingFaceTB/SmolLM2-135M", "sample_packing": True, "pad_to_sequence_len": True, - "sequence_len": 2048, + "sequence_len": 1024, "val_set_size": 0.01, "special_tokens": { "pad_token": "<|endoftext|>", @@ -885,7 +893,7 @@ class TestMultiGPULlama: "sample_packing": True, "bf16": True, "save_safetensors": True, - "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"), + # "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"), "use_tensorboard": True, } ) diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py index 14b1c0a86..9be7c6f50 100644 --- a/tests/e2e/multigpu/test_ray.py +++ b/tests/e2e/multigpu/test_ray.py @@ -31,7 +31,7 @@ class TestMultiGPURay: cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", - "sequence_len": 2048, + "sequence_len": 1024, "adapter": "lora", "lora_r": 8, "lora_alpha": 16, @@ -94,8 +94,8 @@ class TestMultiGPURay: "base_model": "HuggingFaceTB/SmolLM2-135M", "sample_packing": True, "pad_to_sequence_len": True, - "sequence_len": 2048, - "val_set_size": 0.05, + "sequence_len": 1024, + "val_set_size": 0.01, "special_tokens": { "pad_token": "<|endoftext|>", },