Compare commits


2 Commits

Author     SHA1        Message                         Date
Wing Lian  37a66e6866  multigpu longer timeout         2025-04-09 01:54:35 -04:00
Wing Lian  9f69597a5f  upgrade transformers to 4.51.1  2025-04-09 00:20:50 -04:00
11 changed files with 231 additions and 256 deletions

View File

@@ -1,10 +0,0 @@
# Llama 4 by Meta AI
## Available Examples
### Llama 4 Scout 17Bx16Experts (109B)
- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)
Our Single GPU implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second.

View File

@@ -1,28 +1,13 @@
-base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
+base_model: meta-llama/Llama-4-Scout-17B-16E
 model_type: Llama4ForConditionalGeneration
-processor_type: Llama4Processor
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name
 strict: false
-# these 3 lines are needed for now to handle vision chat templates w images
-skip_prepare_dataset: true
-remove_unused_columns: false
-sample_packing: false
-sequence_len: 4096
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-liger_glu_activation: true
-liger_rms_norm: true
-liger_layer_norm: true
-llama4_linearized_experts: true  # use Axolotl's customized model
-load_in_4bit: true
-adapter: qlora
+# torch_compile: true
+adapter: lora
 lora_r: 32
 lora_alpha: 64
 lora_target_modules:
@@ -30,59 +15,60 @@ lora_target_modules:
   - self_attn.k_proj
   - self_attn.v_proj
   - self_attn.o_proj
-  - shared_expert.gate_proj
-  - shared_expert.up_proj
-  - shared_expert.down_proj
-  - vision_adapter.mlp.fc1
-  - vision_adapter.mlp.fc2
-  # - experts.gate_projs.[0-9]+$
-  # - experts.up_projs.[0-9]+$
-  # - experts.down_projs.[0-9]+$
 lora_modules_to_save:
   - lm_head
   - embed_tokens
 chat_template: llama4
 datasets:
-  - path: HuggingFaceH4/llava-instruct-mix-vsft
+  - path: mlabonne/FineTome-100k
     type: chat_template
-    split: train[:1%]
-    field_messages: messages
+    split: train[:20%]
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: ./outputs/out
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
 gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 1
-optimizer: adamw_torch_4bit
+optimizer: adamw_torch_8bit
 lr_scheduler: cosine
 learning_rate: 2e-5
 bf16: true
 tf32: true
+# gradient_checkpointing: true
+# gradient_checkpointing_kwargs:
+#   use_reentrant: false
 logging_steps: 1
 flash_attention: true
 warmup_steps: 100
-evals_per_epoch: 1
+evals_per_epoch: 2
 saves_per_epoch: 1
 weight_decay: 0.0
 fsdp:
   - auto_wrap
   - full_shard
 fsdp_config:
-  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
-  fsdp_limit_all_gathers: true
-  fsdp_sync_module_states: true
-  fsdp_offload_params: true
-  fsdp_use_orig_params: false
+  fsdp_version: 2
+  fsdp_offload_params: false
   fsdp_cpu_ram_efficient_loading: true
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
+  fsdp_state_dict_type: SHARDED_STATE_DICT
   fsdp_sharding_strategy: FULL_SHARD
+  fsdp_reshard_after_forward: true
   fsdp_activation_checkpointing: true
 special_tokens:
   pad_token: <|finetune_right_pad_id|>

View File

@@ -1,86 +0,0 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true

llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
  - self_attn.q_proj
  - self_attn.k_proj
  - self_attn.v_proj
  - self_attn.o_proj
  - shared_expert.gate_proj
  - shared_expert.up_proj
  - shared_expert.down_proj
  # - experts.gate_projs.[0-9]+$
  # - experts.up_projs.[0-9]+$
  # - experts.down_projs.[0-9]+$
lora_modules_to_save:
  # - lm_head
  # - embed_tokens

lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

chat_template: llama4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

sequence_len: 4096  # up to 8k will work on a single H100
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4

bf16: true
tf32: true

logging_steps: 1
flash_attention: true

gradient_checkpointing: offload
gradient_checkpointing_kwargs:
  use_reentrant: false

warmup_steps: 20
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>

View File

@@ -185,7 +185,5 @@ class LigerPlugin(BasePlugin):
                 rms_norm=cfg.liger_rms_norm,
                 layer_norm=cfg.liger_layer_norm,
             )
-        else:
-            logging.warning(
-                f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
-            )
+        elif cfg.model_config_type in ["deepseek_v3"]:
+            raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")

View File

@@ -3,7 +3,6 @@ Liger FLCE for llama4
""" """
import sys import sys
from copy import deepcopy
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
@@ -159,16 +158,7 @@ def apply_liger_kernel_to_llama4(
     if rms_norm:
         modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
     if glu_activation:
-
-        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
-            "Accepts intermediate_size to pass to LigerSwiGLUMLP"
-            # clone config to avoid modifying the original
-            config = deepcopy(config)
-            if intermediate_size:
-                setattr(config, "intermediate_size", intermediate_size)
-            return LigerSwiGLUMLP(config, **kwargs)
-
-        modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
     if layer_norm:
         modeling_llama4.nn.LayerNorm = LigerLayerNorm

View File

@@ -0,0 +1,171 @@
"""Flex attention monkey patch"""
import sys
from typing import Optional, Tuple, Union
import torch
import transformers
def patch_flex_wrapper():
# TODO remove this patch when transformers#37285 is merged and in a release
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
if not (is_torch_2_6 and is_transformers_below_4_51):
return
from torch.nn.attention.flex_attention import flex_attention
class WrappedFlexAttention:
"""
We are doing a singleton class so that flex attention is compiled once when it's first called.
"""
_instance = None
_is_flex_compiled = False
_compiled_flex_attention = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
# Create a new instance if one doesn't already exist
cls._instance = super().__new__(cls)
return cls._instance
@torch.compiler.disable(recursive=False)
def __init__(self):
"""
Initialize or update the singleton instance.
"""
if not self._is_flex_compiled:
self._compiled_flex_attention = torch.compile(
flex_attention,
dynamic=False,
mode="max-autotune-no-cudagraphs",
fullgraph=True,
)
self._is_flex_compiled = True
def __call__(self):
return self._compiled_flex_attention
transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention
def patch_flex_make_mask():
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
if not (is_torch_2_6 and is_transformers_eq_4_51):
return
from torch.nn.attention.flex_attention import (
BlockMask,
)
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
Offset = Union[torch.Tensor, int]
def patched_make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None,
query_length=None,
key_length=None,
offsets: Optional[Tuple[Offset, Offset]] = None,
) -> "BlockMask":
"""
Create a block causal document mask for a batch of sequences, both packed and unpacked.
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full block causal
mask. BlockMask is essential for performant computation of flex attention.
See: https://pytorch.org/blog/flexattention/
Args:
attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
of shape (batch_size, total_seq_len). e.g.
For unpacked sequence:
[[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0]]
For packed sequence:
[[1, 1, 1, 2, 2, 2, 0],
[1, 1, 2, 2, 2, 3, 3]]
Returns:
BlockMask
"""
batch_size, total_seq_len = attention_mask_2d.shape
if not key_length:
key_length = total_seq_len
if not query_length:
query_length = total_seq_len
attention_mask_2d = torch.nn.functional.pad(
attention_mask_2d, value=0, pad=(0, key_length)
)
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (
attention_chunk_size
)
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
# computation prior to the softmax. For sample packing, we need both the
# logic for both causal mask and document mask. See PyTorch's official
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
def causal_mask_mod(
batch_idx, head_idx, q_idx, kv_idx
): # pylint: disable=unused-argument
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx # not valid when decoding
document_mask = (
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
)
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
final_mask = causal_mask & padding_mask & document_mask
return final_mask
if offsets is not None:
q_offset = offsets[0]
kv_offset = offsets[1]
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = causal_mask_mod
return create_block_causal_mask_flex(
mask_mod=mask_mod,
B=batch_size,
H=None, # attention head
Q_LEN=query_length,
KV_LEN=key_length,
device=device,
_compile=True,
)
for n in tuple(sys.modules):
if ".modeling_" in n and "llama4" not in n:
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
sys.modules[n].make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)
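To make the mask_mod logic in this new patch concrete, here is a small standalone sketch (not part of the diff) that evaluates the same causal/document/padding conditions on the toy packed mask from the docstring; the helper name and the omission of head_idx are illustrative only:

```python
import torch

# Toy packed batch from the docstring above: two documents (ids 1 and 2), one pad slot.
attention_mask_2d = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
document_ids = attention_mask_2d.clone()

def toy_mask_mod(batch_idx, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    same_document = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
    not_padding = attention_mask_2d[batch_idx, q_idx] > 0
    return causal & same_document & not_padding

q = torch.arange(7).view(-1, 1)   # query positions
kv = torch.arange(7).view(1, -1)  # key/value positions
print(toy_mask_mod(0, q, kv).int())
# Row 3 (the first token of document 2) attends only to itself, never back into
# document 1, and row 6 (padding) attends to nothing.
```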

View File

@@ -906,7 +906,20 @@ class ModelLoader:
""" """
sample packing uses custom FA2 patch sample packing uses custom FA2 patch
""" """
if self.cfg.flash_attention: if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
patch_flex_wrapper,
)
patch_flex_wrapper()
patch_flex_make_mask()
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention: if not self.cfg.sample_packing and self.cfg.s2_attention:
pass pass
self.model_kwargs["attn_implementation"] = "flash_attention_2" self.model_kwargs["attn_implementation"] = "flash_attention_2"
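For reference, the branch added above makes flex attention take precedence over flash attention when both flags are set. A minimal sketch of that selection order (hypothetical helper, not the actual ModelLoader API):

```python
def pick_attn_implementation(cfg: dict) -> str:
    # Mirrors the ordering in the diff: flex_attention is checked first,
    # then flash_attention. The "eager" fallback is an assumption for this
    # sketch; the real loader handles further backends (e.g. sdpa, s2_attention).
    if cfg.get("flex_attention"):
        return "flex_attention"
    if cfg.get("flash_attention"):
        return "flash_attention_2"
    return "eager"

assert pick_attn_implementation({"flex_attention": True, "flash_attention": True}) == "flex_attention"
assert pick_attn_implementation({"flash_attention": True}) == "flash_attention_2"
```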

View File

@@ -1316,29 +1316,8 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         if version.parse(torch_version) < version.parse("2.6.0"):
             raise ValueError(
-                "Flex attention is not supported on torch version < 2.6.0."
+                "Flex attention is not supported on torch version < 2.6.0"
             )
-        if version.parse(torch_version) < version.parse("2.7.0"):
-            LOG.warning(
-                f"You are currently using torch version {torch_version}. "
-                "We recommend using the latest version of torch for flex attention. "
-                "You may encounter unexpected issues with flex attention on older versions of torch. "
-                "Please upgrade to the latest stable, or nightly version of torch. "
-            )
-
-        transformers_version = env_capabilities.get("transformers_version")
-        if transformers_version is None:
-            import transformers
-
-            transformers_version = str(transformers.__version__).split(
-                "+", maxsplit=1
-            )[0]
-        if version.parse(transformers_version) < version.parse("4.45.1"):
-            raise ValueError(
-                "Transformers version < 4.45.1 is not supported with flex attention. "
-            )

         return data

     @model_validator(mode="before")
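After this change, the validator only hard-fails on torch older than 2.6.0. A minimal sketch of the remaining gate (hypothetical function name, assuming packaging-style version strings):

```python
from packaging import version

def check_flex_attention_torch(torch_version: str) -> None:
    # Same check as the simplified validator above: flex attention needs torch >= 2.6.0.
    if version.parse(torch_version) < version.parse("2.6.0"):
        raise ValueError("Flex attention is not supported on torch version < 2.6.0")

check_flex_attention_torch("2.6.0")    # passes
# check_flex_attention_torch("2.5.1")  # raises ValueError
```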

View File

@@ -16,7 +16,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
 from axolotl.utils.dict import DictDefault
-from tests.e2e.utils import check_tensorboard, require_torch_2_6_0, require_torch_2_7_0
+from tests.e2e.utils import check_tensorboard, require_torch_2_6_0

 LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
 os.environ["WANDB_DISABLED"] = "true"
@@ -459,80 +459,16 @@ class TestMultiGPULlama:
     @require_torch_2_6_0
     @pytest.mark.parametrize(
-        "fsdp_reshard_after_forward",
-        [True, False],
+        "attention_backend",
+        ["flash", "flex"],
     )
-    def test_fsdp2_packed_flash(self, temp_dir, fsdp_reshard_after_forward):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
-                "sample_packing": True,
-                "pad_to_sequence_len": True,
-                "sequence_len": 2048,
-                "val_set_size": 0.05,
-                "special_tokens": {
-                    "pad_token": "<|endoftext|>",
-                },
-                "datasets": [
-                    {
-                        "path": "tatsu-lab/alpaca",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 1,
-                "max_steps": 2,
-                "micro_batch_size": 4,
-                "gradient_accumulation_steps": 2,
-                "gradient_checkpointing": True,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch_8bit",
-                "lr_scheduler": "cosine",
-                "fsdp": [
-                    "auto_wrap",
-                ],
-                "fsdp_config": {
-                    "fsdp_version": 2,
-                    # "fsdp_forward_prefetch": True,  # not yet implemented in accelerate
-                    "fsdp_offload_params": False,
-                    "fsdp_cpu_ram_efficient_loading": False,
-                    "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
-                    "fsdp_state_dict_type": "SHARDED_STATE_DICT",
-                    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
-                    "fsdp_reshard_after_forward": fsdp_reshard_after_forward,
-                },
-                "use_tensorboard": True,
-                "flash_attention": True,
-            }
-        )
-        # write cfg to yaml file
-        Path(temp_dir).mkdir(parents=True, exist_ok=True)
-        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
-            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
-
-        execute_subprocess_async(
-            [
-                "axolotl",
-                "train",
-                str(Path(temp_dir) / "config.yaml"),
-                "--num-processes",
-                "2",
-                "--main-process-port",
-                f"{get_torch_dist_unique_port()}",
-            ]
-        )
-
-        check_tensorboard(
-            temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
-        )
-
-    @require_torch_2_7_0
     @pytest.mark.parametrize(
         "fsdp_reshard_after_forward",
         [True, False],
     )
-    def test_fsdp2_packed_flex(self, temp_dir, fsdp_reshard_after_forward):
+    def test_fsdp2_packed(
+        self, temp_dir, attention_backend, fsdp_reshard_after_forward
+    ):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -573,9 +509,13 @@ class TestMultiGPULlama:
"fsdp_reshard_after_forward": fsdp_reshard_after_forward, "fsdp_reshard_after_forward": fsdp_reshard_after_forward,
}, },
"use_tensorboard": True, "use_tensorboard": True,
"flex_attention": True,
} }
) )
if attention_backend == "flash":
cfg.flash_attention = True
elif attention_backend == "flex":
cfg.flex_attention = True
# write cfg to yaml file # write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True) Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
@@ -677,6 +617,12 @@ class TestMultiGPULlama:
             temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
         )

+    # TODO: remove skip once deepspeed regression is fixed
+    # see https://github.com/huggingface/transformers/pull/37324
+    @pytest.mark.skipif(
+        transformers_version_eq("4.51.0"),
+        reason="zero3 is not supported with transformers==4.51.0",
+    )
     @pytest.mark.parametrize(
         "gradient_accumulation_steps",
         [1, 2],
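The new skip condition relies on a transformers_version_eq helper whose definition is not shown in this diff; a plausible sketch of such a helper (an assumption, not the repository's actual implementation):

```python
import transformers
from packaging import version

def transformers_version_eq(target: str) -> bool:
    # True only when the installed transformers release matches `target` exactly,
    # e.g. the 4.51.0 release affected by the deepspeed zero3 regression noted above.
    return version.parse(transformers.__version__) == version.parse(target)
```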

View File

@@ -14,7 +14,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
-from ..utils import check_tensorboard, require_torch_2_7_0, with_temp_dir
+from ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -25,7 +25,7 @@ class TestPackedFlex(unittest.TestCase):
     Test case for Packed training of llama models
     """

-    @require_torch_2_7_0
+    @require_torch_2_6_0
     @with_temp_dir
     def test_loss_llama(self, temp_dir):
         # pylint: disable=duplicate-code

View File

@@ -33,18 +33,6 @@ def with_temp_dir(test_func):
     return wrapper


-def require_torch_2_7_0(test_case):
-    """
-    Decorator marking a test that requires torch >= 2.7.0
-    """
-
-    def is_min_2_7_0():
-        torch_version = version.parse(torch.__version__)
-        return torch_version >= version.parse("2.7.0")
-
-    return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
-
-
 def most_recent_subdir(path):
     base_path = Path(path)
     subdirectories = [d for d in base_path.iterdir() if d.is_dir()]