Compare commits


3 Commits

Author     SHA1        Message                                              Date
Wing Lian  46afcf070f  rename to specify fsdp                               2025-04-09 02:39:03 -04:00
Wing Lian  3036ca349f  add README for llama4                                2025-04-09 02:15:09 -04:00
Wing Lian  dc4809f7dd  [llama4] fix the mm yaml, add scout single gpu yaml  2025-04-09 01:52:31 -04:00
10 changed files with 210 additions and 125 deletions

View File

@@ -68,7 +68,7 @@ def run_cmd(cmd: str, run_folder: str):
 @app.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
-    timeout=90 * 60,
+    timeout=60 * 60,
     cpu=8.0,
     memory=131072 * N_GPUS,
     volumes=VOLUME_CONFIG,

View File

@@ -12,7 +12,7 @@ liger-kernel==0.5.6
 packaging==23.2
 peft==0.15.1
-transformers==4.51.1
+transformers==4.51.0
 tokenizers>=0.21.1
 accelerate==1.6.0
 datasets==3.5.0

View File

@@ -185,7 +185,5 @@ class LigerPlugin(BasePlugin):
                 rms_norm=cfg.liger_rms_norm,
                 layer_norm=cfg.liger_layer_norm,
             )
-        else:
-            logging.warning(
-                f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
-            )
+        elif cfg.model_config_type in ["deepseek_v3"]:
+            raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
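
For orientation, this hunk sits in an elif chain keyed on cfg.model_config_type; a hedged sketch of the llama4 call it feeds into, based on the kwargs visible above and in the patch file below (the glu_activation flag name is an assumption, it is not shown in this hunk):

    apply_liger_kernel_to_llama4(
        rms_norm=cfg.liger_rms_norm,
        glu_activation=cfg.liger_glu_activation,  # assumed config key
        layer_norm=cfg.liger_layer_norm,
    )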

View File

@@ -3,7 +3,6 @@ Liger FLCE for llama4
""" """
import sys import sys
from copy import deepcopy
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
@@ -159,16 +158,7 @@ def apply_liger_kernel_to_llama4(
     if rms_norm:
         modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
     if glu_activation:
-
-        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
-            "Accepts intermediate_size to pass to LigerSwiGLUMLP"
-            # clone config to avoid modifying the original
-            config = deepcopy(config)
-            if intermediate_size:
-                setattr(config, "intermediate_size", intermediate_size)
-            return LigerSwiGLUMLP(config, **kwargs)
-
-        modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
     if layer_norm:
         modeling_llama4.nn.LayerNorm = LigerLayerNorm

View File

@@ -0,0 +1,171 @@
"""Flex attention monkey patch"""
import sys
from typing import Optional, Tuple, Union
import torch
import transformers
def patch_flex_wrapper():
# TODO remove this patch when transformers#37285 is merged and in a release
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
if not (is_torch_2_6 and is_transformers_below_4_51):
return
from torch.nn.attention.flex_attention import flex_attention
class WrappedFlexAttention:
"""
We are doing a singleton class so that flex attention is compiled once when it's first called.
"""
_instance = None
_is_flex_compiled = False
_compiled_flex_attention = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
# Create a new instance if one doesn't already exist
cls._instance = super().__new__(cls)
return cls._instance
@torch.compiler.disable(recursive=False)
def __init__(self):
"""
Initialize or update the singleton instance.
"""
if not self._is_flex_compiled:
self._compiled_flex_attention = torch.compile(
flex_attention,
dynamic=False,
mode="max-autotune-no-cudagraphs",
fullgraph=True,
)
self._is_flex_compiled = True
def __call__(self):
return self._compiled_flex_attention
transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention
def patch_flex_make_mask():
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
if not (is_torch_2_6 and is_transformers_eq_4_51):
return
from torch.nn.attention.flex_attention import (
BlockMask,
)
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
Offset = Union[torch.Tensor, int]
def patched_make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None,
query_length=None,
key_length=None,
offsets: Optional[Tuple[Offset, Offset]] = None,
) -> "BlockMask":
"""
Create a block causal document mask for a batch of sequences, both packed and unpacked.
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full block causal
mask. BlockMask is essential for performant computation of flex attention.
See: https://pytorch.org/blog/flexattention/
Args:
attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
of shape (batch_size, total_seq_len). e.g.
For unpacked sequence:
[[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0]]
For packed sequence:
[[1, 1, 1, 2, 2, 2, 0],
[1, 1, 2, 2, 2, 3, 3]]
Returns:
BlockMask
"""
batch_size, total_seq_len = attention_mask_2d.shape
if not key_length:
key_length = total_seq_len
if not query_length:
query_length = total_seq_len
attention_mask_2d = torch.nn.functional.pad(
attention_mask_2d, value=0, pad=(0, key_length)
)
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (
attention_chunk_size
)
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
# computation prior to the softmax. For sample packing, we need both the
# logic for both causal mask and document mask. See PyTorch's official
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
def causal_mask_mod(
batch_idx, head_idx, q_idx, kv_idx
): # pylint: disable=unused-argument
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx # not valid when decoding
document_mask = (
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
)
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
final_mask = causal_mask & padding_mask & document_mask
return final_mask
if offsets is not None:
q_offset = offsets[0]
kv_offset = offsets[1]
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = causal_mask_mod
return create_block_causal_mask_flex(
mask_mod=mask_mod,
B=batch_size,
H=None, # attention head
Q_LEN=query_length,
KV_LEN=key_length,
device=device,
_compile=True,
)
for n in tuple(sys.modules):
if ".modeling_" in n and "llama4" not in n:
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
sys.modules[n].make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)
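
A side note on the attention_chunk_size branch above: filling the mask with ones, taking a cumulative sum minus one, and integer-dividing by the chunk size turns a flat row into per-chunk document ids. A minimal standalone sketch with toy sizes (not part of the patch):

    import torch

    attention_mask_2d = torch.ones(1, 12, dtype=torch.long)
    attention_chunk_size = 3
    document_ids = (attention_mask_2d.clone().fill_(1).cumsum(-1) - 1) // attention_chunk_size
    print(document_ids)  # tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]])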

View File

@@ -906,7 +906,20 @@ class ModelLoader:
""" """
sample packing uses custom FA2 patch sample packing uses custom FA2 patch
""" """
if self.cfg.flash_attention: if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
patch_flex_wrapper,
)
patch_flex_wrapper()
patch_flex_make_mask()
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention: if not self.cfg.sample_packing and self.cfg.s2_attention:
pass pass
self.model_kwargs["attn_implementation"] = "flash_attention_2" self.model_kwargs["attn_implementation"] = "flash_attention_2"
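
In effect, the new branch boils down to the following sequence before the model is instantiated; a condensed sketch with the surrounding ModelLoader plumbing omitted (both patch functions are no-ops outside the torch/transformers ranges they target):

    from axolotl.monkeypatch.attention.flex_attn import (
        patch_flex_make_mask,
        patch_flex_wrapper,
    )

    model_kwargs = {"attn_implementation": "flex_attention"}  # forwarded to from_pretrained
    patch_flex_wrapper()    # torch 2.6 + transformers < 4.51: re-register WrappedFlexAttention
    patch_flex_make_mask()  # torch 2.6 + transformers == 4.51.0: swap make_flex_block_causal_mask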

View File

@@ -1316,29 +1316,8 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         if version.parse(torch_version) < version.parse("2.6.0"):
             raise ValueError(
-                "Flex attention is not supported on torch version < 2.6.0."
+                "Flex attention is not supported on torch version < 2.6.0"
             )
-
-        if version.parse(torch_version) < version.parse("2.7.0"):
-            LOG.warning(
-                f"You are currently using torch version {torch_version}. "
-                "We recommend using the latest version of torch for flex attention. "
-                "You may encounter unexpected issues with flex attention on older versions of torch. "
-                "Please upgrade to the latest stable, or nightly version of torch. "
-            )
-
-        transformers_version = env_capabilities.get("transformers_version")
-        if transformers_version is None:
-            import transformers
-
-            transformers_version = str(transformers.__version__).split(
-                "+", maxsplit=1
-            )[0]
-
-        if version.parse(transformers_version) < version.parse("4.45.1"):
-            raise ValueError(
-                "Transformers version < 4.45.1 is not supported with flex attention. "
-            )
-
         return data

     @model_validator(mode="before")

View File

@@ -16,7 +16,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
 from axolotl.utils.dict import DictDefault
-from tests.e2e.utils import check_tensorboard, require_torch_2_6_0, require_torch_2_7_0
+from tests.e2e.utils import check_tensorboard, require_torch_2_6_0

 LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
 os.environ["WANDB_DISABLED"] = "true"
@@ -459,80 +459,16 @@ class TestMultiGPULlama:
     @require_torch_2_6_0
     @pytest.mark.parametrize(
-        "fsdp_reshard_after_forward",
-        [True, False],
+        "attention_backend",
+        ["flash", "flex"],
     )
-    def test_fsdp2_packed_flash(self, temp_dir, fsdp_reshard_after_forward):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
-                "sample_packing": True,
-                "pad_to_sequence_len": True,
-                "sequence_len": 2048,
-                "val_set_size": 0.05,
-                "special_tokens": {
-                    "pad_token": "<|endoftext|>",
-                },
-                "datasets": [
-                    {
-                        "path": "tatsu-lab/alpaca",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 1,
-                "max_steps": 2,
-                "micro_batch_size": 4,
-                "gradient_accumulation_steps": 2,
-                "gradient_checkpointing": True,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch_8bit",
-                "lr_scheduler": "cosine",
-                "fsdp": [
-                    "auto_wrap",
-                ],
-                "fsdp_config": {
-                    "fsdp_version": 2,
-                    # "fsdp_forward_prefetch": True, # not yet implemented in accelerate
-                    "fsdp_offload_params": False,
-                    "fsdp_cpu_ram_efficient_loading": False,
-                    "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
-                    "fsdp_state_dict_type": "SHARDED_STATE_DICT",
-                    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
-                    "fsdp_reshard_after_forward": fsdp_reshard_after_forward,
-                },
-                "use_tensorboard": True,
-                "flash_attention": True,
-            }
-        )
-        # write cfg to yaml file
-        Path(temp_dir).mkdir(parents=True, exist_ok=True)
-        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
-            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
-
-        execute_subprocess_async(
-            [
-                "axolotl",
-                "train",
-                str(Path(temp_dir) / "config.yaml"),
-                "--num-processes",
-                "2",
-                "--main-process-port",
-                f"{get_torch_dist_unique_port()}",
-            ]
-        )
-
-        check_tensorboard(
-            temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
-        )
-
-    @require_torch_2_7_0
     @pytest.mark.parametrize(
         "fsdp_reshard_after_forward",
         [True, False],
     )
-    def test_fsdp2_packed_flex(self, temp_dir, fsdp_reshard_after_forward):
+    def test_fsdp2_packed(
+        self, temp_dir, attention_backend, fsdp_reshard_after_forward
+    ):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -573,9 +509,13 @@ class TestMultiGPULlama:
"fsdp_reshard_after_forward": fsdp_reshard_after_forward, "fsdp_reshard_after_forward": fsdp_reshard_after_forward,
}, },
"use_tensorboard": True, "use_tensorboard": True,
"flex_attention": True,
} }
) )
if attention_backend == "flash":
cfg.flash_attention = True
elif attention_backend == "flex":
cfg.flex_attention = True
# write cfg to yaml file # write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True) Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
@@ -677,6 +617,12 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
) )
# TODO: remove skip once deepspeed regression is fixed
# see https://github.com/huggingface/transformers/pull/37324
@pytest.mark.skipif(
transformers_version_eq("4.51.0"),
reason="zero3 is not supported with transformers==4.51.0",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"gradient_accumulation_steps", "gradient_accumulation_steps",
[1, 2], [1, 2],

View File

@@ -14,7 +14,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_tensorboard, require_torch_2_7_0, with_temp_dir
+from ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -25,7 +25,7 @@ class TestPackedFlex(unittest.TestCase):
     Test case for Packed training of llama models
     """

-    @require_torch_2_7_0
+    @require_torch_2_6_0
     @with_temp_dir
     def test_loss_llama(self, temp_dir):
         # pylint: disable=duplicate-code
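
For reference, a minimal config of the kind these flex-attention tests build; a hedged sketch assembled from keys visible elsewhere in this diff, not the test's actual dict:

    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "sample_packing": True,
            "pad_to_sequence_len": True,
            "sequence_len": 2048,
            "flex_attention": True,
            "datasets": [{"path": "tatsu-lab/alpaca", "type": "alpaca"}],
            "num_epochs": 1,
            "max_steps": 2,
            "output_dir": "/tmp/out",  # placeholder path
        }
    )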

View File

@@ -33,18 +33,6 @@ def with_temp_dir(test_func):
     return wrapper


-def require_torch_2_7_0(test_case):
-    """
-    Decorator marking a test that requires torch >= 2.7.0
-    """
-
-    def is_min_2_7_0():
-        torch_version = version.parse(torch.__version__)
-        return torch_version >= version.parse("2.7.0")
-
-    return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
-
-
 def most_recent_subdir(path):
     base_path = Path(path)
     subdirectories = [d for d in base_path.iterdir() if d.is_dir()]