Compare commits

...

11 Commits

Author SHA1 Message Date
Wing Lian
1aec93cf9e add preliminary fp8 support 2025-04-06 23:54:50 -04:00
Wing Lian
37630fc6ef patches to make llama4 performant 2025-04-06 22:50:48 -04:00
Wing Lian
4b28b2a0b4 remove stray print, add llama4 chat template to schema, bump peft to 0.15.1 2025-04-06 19:59:48 -04:00
Wing Lian
b38f70e068 use 4.51.0 for now 2025-04-06 18:14:14 -04:00
Wing Lian
cf4c84e21d slightly smaller train set 2025-04-06 17:11:52 -04:00
Wing Lian
98d98ea1dd reordering to trigger torch 2.6.0 tests first 2025-04-06 17:11:52 -04:00
Wing Lian
0cf42ab8a3 don't use deepspeed for the fix_untrained_tokens test 2025-04-06 17:11:52 -04:00
Wing Lian
3d0ab75a0c be flexible on transformers version and skip test on version 2025-04-06 17:11:50 -04:00
Wing Lian
d375be90ff add xet support [skip ci] 2025-04-06 17:09:23 -04:00
Wing Lian
98827e8f3b llama4 support 2025-04-06 17:08:57 -04:00
Wing Lian
5f4af3665d FSDP2 support (#2469)
* fsdp2 support

* use accelerate release 1.6.0

* allow 8bit optims with fsdp2

* liger + torch compile fix

* add fsdp2 e2e tests

* use transformers commit with fsdp2 support

* skip zero3 tests for this PR for now

* fix fsdp2 config for ci

* make sure both flex and flash attn work with fsdp2, skip fix untrained tokens

* okay, actually use fdsp2...

* more fixes to flex for fsdp2

* make sure to patch all the loaded models

* additional validation for fsdp2, bump dep versions
2025-04-06 17:08:01 -04:00
19 changed files with 718 additions and 66 deletions

View File

@@ -24,6 +24,13 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -38,13 +45,6 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -211,7 +211,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
steps:
@@ -258,7 +258,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.5.1
num_gpus: 1
axolotl_extras: vllm
steps:

View File

@@ -0,0 +1,75 @@
base_model: meta-llama/Llama-4-Scout-17B-16E
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
strict: false
# torch_compile: true
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
lora_modules_to_save:
- lm_head
- embed_tokens
chat_template: llama4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
# gradient_checkpointing: true
# gradient_checkpointing_kwargs:
# use_reentrant: false
logging_steps: 1
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- auto_wrap
- full_shard
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -6,18 +6,19 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.5
liger-kernel==0.5.6
# END section
packaging==23.2
peft==0.15.0
transformers==4.50.3
peft==0.15.1
transformers==4.51.0
tokenizers>=0.21.1
accelerate==1.5.2
accelerate==1.6.0
datasets==3.5.0
deepspeed==0.15.4
trl==0.16.0
deepspeed>=0.15.4
trl==0.16.1
hf_xet==1.0.0
optimum==1.16.2
hf_transfer

View File

@@ -562,6 +562,19 @@ class AxolotlTrainer(
return res
def additional_accelerator_args(
self, fp8=None, **kwargs
): # pylint: disable=unused-argument
ret_kwargs = {}
if fp8:
from accelerate.utils import AORecipeKwargs
ret_kwargs["mixed_precision"] = "fp8"
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()]
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
return ret_kwargs
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.

View File

@@ -27,6 +27,7 @@ from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable
LOG = logging.getLogger("axolotl.integrations.liger")
@@ -40,6 +41,18 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.torch_compile:
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
import liger_kernel.ops.fused_linear_cross_entropy
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_forward",
)
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_backward",
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
@@ -160,5 +173,17 @@ class LigerPlugin(BasePlugin):
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,
)
apply_liger_kernel_to_llama4(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type in ["deepseek_v3"]:
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")

View File

@@ -0,0 +1,171 @@
"""
Liger FLCE for llama4
"""
import sys
from typing import List, Optional, Tuple, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[
Union["Cache", List[torch.FloatTensor]] # noqa: F821
] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
raise Exception( # pylint: disable=broad-exception-raised
"Liger Kernel does not support pretraining_tp!!"
)
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**loss_kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**loss_kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_llama4(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.llama4.modeling_llama4 # noqa: F401 # pylint: disable=unused-import
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
modeling_llama4 = sys.modules["transformers.models.llama4.modeling_llama4"]
if rms_norm:
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
if glu_activation:
modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
if layer_norm:
modeling_llama4.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_llama4.Llama4ForCausalLM.forward = lce_forward

View File

@@ -0,0 +1,29 @@
"""
utils to patch liger kernel ops to disable torch.compile
"""
from functools import wraps
import torch
def patch_with_compile_disable(module, function_name):
"""
Patch a function in a module by wrapping it with torch.compile.disable
Args:
module: The module containing the function to patch
function_name: The name of the function to patch
"""
original_function = getattr(module, function_name)
@wraps(original_function)
@torch.compiler.disable
def wrapped_function(*args, **kwargs):
return original_function(*args, **kwargs)
# Replace the original function with the wrapped one
setattr(module, function_name, wrapped_function)
# Return the original function in case you need to restore it later
return original_function

View File

@@ -1,48 +1,171 @@
"""Flex attention monkey patch"""
import sys
from typing import Optional, Tuple, Union
import torch
import transformers
def patch_flex():
def patch_flex_wrapper():
# TODO remove this patch when transformers#37285 is merged and in a release
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
if is_torch_2_6 and is_transformers_below_4_51:
from torch.nn.attention.flex_attention import flex_attention
if not (is_torch_2_6 and is_transformers_below_4_51):
return
class WrappedFlexAttention:
from torch.nn.attention.flex_attention import flex_attention
class WrappedFlexAttention:
"""
We are doing a singleton class so that flex attention is compiled once when it's first called.
"""
_instance = None
_is_flex_compiled = False
_compiled_flex_attention = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
# Create a new instance if one doesn't already exist
cls._instance = super().__new__(cls)
return cls._instance
@torch.compiler.disable(recursive=False)
def __init__(self):
"""
We are doing a singleton class so that flex attention is compiled once when it's first called.
Initialize or update the singleton instance.
"""
if not self._is_flex_compiled:
self._compiled_flex_attention = torch.compile(
flex_attention,
dynamic=False,
mode="max-autotune-no-cudagraphs",
fullgraph=True,
)
self._is_flex_compiled = True
_instance = None
_is_flex_compiled = False
_compiled_flex_attention = None
def __call__(self):
return self._compiled_flex_attention
def __new__(cls, *args, **kwargs):
if cls._instance is None:
# Create a new instance if one doesn't already exist
cls._instance = super().__new__(cls)
return cls._instance
transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention
@torch.compiler.disable(recursive=False)
def __init__(self):
"""
Initialize or update the singleton instance.
"""
if not self._is_flex_compiled:
self._compiled_flex_attention = torch.compile(
flex_attention,
dynamic=False,
mode="max-autotune-no-cudagraphs",
fullgraph=True,
)
self._is_flex_compiled = True
def __call__(self):
return self._compiled_flex_attention
def patch_flex_make_mask():
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
transformers.integrations.flex_attention.WrappedFlexAttention = (
WrappedFlexAttention
if not (is_torch_2_6 and is_transformers_eq_4_51):
return
from torch.nn.attention.flex_attention import (
BlockMask,
)
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
Offset = Union[torch.Tensor, int]
def patched_make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None,
query_length=None,
key_length=None,
offsets: Optional[Tuple[Offset, Offset]] = None,
) -> "BlockMask":
"""
Create a block causal document mask for a batch of sequences, both packed and unpacked.
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full block causal
mask. BlockMask is essential for performant computation of flex attention.
See: https://pytorch.org/blog/flexattention/
Args:
attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
of shape (batch_size, total_seq_len). e.g.
For unpacked sequence:
[[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0]]
For packed sequence:
[[1, 1, 1, 2, 2, 2, 0],
[1, 1, 2, 2, 2, 3, 3]]
Returns:
BlockMask
"""
batch_size, total_seq_len = attention_mask_2d.shape
if not key_length:
key_length = total_seq_len
if not query_length:
query_length = total_seq_len
attention_mask_2d = torch.nn.functional.pad(
attention_mask_2d, value=0, pad=(0, key_length)
)
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (
attention_chunk_size
)
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
# computation prior to the softmax. For sample packing, we need both the
# logic for both causal mask and document mask. See PyTorch's official
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
def causal_mask_mod(
batch_idx, head_idx, q_idx, kv_idx
): # pylint: disable=unused-argument
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx # not valid when decoding
document_mask = (
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
)
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
final_mask = causal_mask & padding_mask & document_mask
return final_mask
if offsets is not None:
q_offset = offsets[0]
kv_offset = offsets[1]
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = causal_mask_mod
return create_block_causal_mask_flex(
mask_mod=mask_mod,
B=batch_size,
H=None, # attention head
Q_LEN=query_length,
KV_LEN=key_length,
device=device,
_compile=True,
)
for n in tuple(sys.modules):
if ".modeling_" in n and "llama4" not in n:
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
sys.modules[n].make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)

View File

@@ -13,6 +13,7 @@ from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mllama_text_model",
"llama",
"llama4",
"mistral",
"mixtral",
"qwen2",

View File

@@ -0,0 +1,80 @@
"""
allow adding additional kwargs to Accelerator init
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger(__name__)
ORIGINAL_TRAINER_CODE = """
# create accelerator object
self.accelerator = Accelerator(**args)
"""
PATCHED_TRAINER_CODE = """
if hasattr(self, "additional_accelerator_args"):
additional_args = self.additional_accelerator_args(fp8=True, **args)
if additional_args:
args.update(additional_args)
# create accelerator object
self.accelerator = Accelerator(**args)
"""
def get_create_accelerate_code() -> str:
training_loop = inspect.getsource(Trainer.create_accelerator_and_postprocess)
return training_loop
def check_create_accelerate_code_is_patchable() -> bool:
create_code = get_create_accelerate_code()
create_code, _ = detab_code(create_code)
return ORIGINAL_TRAINER_CODE in create_code
def patch_create_accelerate_code_for_fp8():
"""
monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
"""
try:
create_code = get_create_accelerate_code()
except OSError:
return
Trainer._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access
create_code
)
create_code, _ = detab_code(create_code)
if ORIGINAL_TRAINER_CODE not in create_code:
return
create_code = create_code.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
create_code = create_code.replace(
"def create_accelerator_and_postprocess(",
"def fixed_create_accelerator_and_postprocess(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in create_code:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(create_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching create_accelerator_and_postprocess to allow for overrides")
Trainer.create_accelerator_and_postprocess = fixed_create_accelerator_and_postprocess # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821

View File

@@ -217,7 +217,7 @@ def save_trained_model(
# Handle FSDP state dict type
state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled:
if trainer.is_fsdp_enabled and str(cfg.fsdp_config.fsdp_version) != "2":
if cfg.fsdp_final_state_dict_type:
state_dict_type = cfg.fsdp_final_state_dict_type
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)

File diff suppressed because one or more lines are too long

View File

@@ -557,6 +557,14 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
# monkey patch to allow additional Accelerator init kwargs
if self.cfg.fp8:
from axolotl.monkeypatch.trainer_accelerator_args import (
patch_create_accelerate_code_for_fp8,
)
patch_create_accelerate_code_for_fp8()
if self.cfg.adapter:
from axolotl.monkeypatch.transformers_fa_utils import (
patch_fa_peft_integration,
@@ -889,9 +897,13 @@ class ModelLoader:
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
from axolotl.monkeypatch.attention.flex_attn import patch_flex
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
patch_flex_wrapper,
)
patch_flex()
patch_flex_wrapper()
patch_flex_make_mask()
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
@@ -984,10 +996,11 @@ class ModelLoader:
)
skip_move_to_device = True
elif (
self.model_config.model_type == "llama"
self.model_config.model_type in ["llama", "llama4"]
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
# TODO do we need to open this up for all models?
if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in self.model_kwargs:

View File

@@ -169,6 +169,7 @@ class AxolotlInputConfig(
bf16: Literal["auto"] | bool | None = "auto"
fp16: bool | None = None
fp8: bool | None = None
bfloat16: bool | None = None # for non-AMP cases
float16: bool | None = None # for non-AMP cases
tf32: bool | None = None
@@ -464,9 +465,10 @@ class AxolotlInputConfig(
data.get("sample_packing")
and not data.get("flash_attention")
and not data.get("sdp_attention")
and not data.get("flex_attention")
):
LOG.warning(
"sample_packing without flash_attention or sdp_attention does not handle cross-attention."
"sample_packing without flash, sdp or flex attention does not handle cross sample decontamination."
)
return data
@@ -950,10 +952,23 @@ class AxolotlInputConfig(
and "8bit" in data.get("optimizer", "")
and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_offload_params")
and str(data["fsdp_config"].get("fsdp_version")) != "2"
):
raise ValueError(
f"FSDP Offload not compatible with {data.get('optimizer')}"
)
if (
data.get("fsdp")
and "8bit" in data.get("optimizer", "")
and data.get("fsdp_config")
and str(data["fsdp_config"].get("fsdp_version")) == "2"
):
if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]:
# CUDA ops errors with bnb 8bit optimizer + FSDP2
raise ValueError(
f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead"
)
return data
@model_validator(mode="before")

View File

@@ -26,6 +26,7 @@ class ChatTemplate(str, Enum):
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama4 = "llama4" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name

View File

@@ -538,6 +538,8 @@ def setup_deepspeed_env(cfg, stage=None):
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
if str(cfg.fsdp_config.fsdp_version) == "2":
os.environ["FSDP_VERSION"] = "2"
if cfg.fsdp_config.fsdp_activation_checkpointing:
os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
if cfg.fsdp_config.fsdp_offload_params:
@@ -556,6 +558,10 @@ def setup_fsdp_envs(cfg):
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = (
cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
)
if cfg.fsdp_config.fsdp_reshard_after_forward is not None:
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = (
"true" if cfg.fsdp_config.fsdp_reshard_after_forward else "false"
)
def prepare_optim_env(cfg):
@@ -576,7 +582,9 @@ def prepare_optim_env(cfg):
setup_torch_compile_env(cfg)
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
if cfg.fp8:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
elif (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"

View File

@@ -7,14 +7,16 @@ import os
from pathlib import Path
import pytest
import transformers
import yaml
from accelerate.test_utils import execute_subprocess_async
from huggingface_hub import snapshot_download
from packaging import version
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -28,6 +30,10 @@ def download_model():
snapshot_download("HuggingFaceTB/SmolLM2-135M")
def transformers_version_eq(required_version):
return version.parse(transformers.__version__) == version.parse(required_version)
class TestMultiGPULlama:
"""
Test case for Llama models using LoRA
@@ -56,7 +62,7 @@ class TestMultiGPULlama:
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
@@ -108,7 +114,7 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.01,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
@@ -116,6 +122,7 @@ class TestMultiGPULlama:
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:20%]",
},
],
"num_epochs": 1,
@@ -193,7 +200,7 @@ class TestMultiGPULlama:
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
@@ -390,7 +397,7 @@ class TestMultiGPULlama:
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
@@ -403,7 +410,7 @@ class TestMultiGPULlama:
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
@@ -450,6 +457,86 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@require_torch_2_6_0
@pytest.mark.parametrize(
"attention_backend",
["flash", "flex"],
)
@pytest.mark.parametrize(
"fsdp_reshard_after_forward",
[True, False],
)
def test_fsdp2_packed(
self, temp_dir, attention_backend, fsdp_reshard_after_forward
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_8bit",
"lr_scheduler": "cosine",
"fsdp": [
"auto_wrap",
],
"fsdp_config": {
"fsdp_version": 2,
# "fsdp_forward_prefetch": True, # not yet implemented in accelerate
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": False,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
},
"use_tensorboard": True,
}
)
if attention_backend == "flash":
cfg.flash_attention = True
elif attention_backend == "flex":
cfg.flex_attention = True
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
)
def test_fsdp_qlora_prequant_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -469,7 +556,7 @@ class TestMultiGPULlama:
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
@@ -483,7 +570,7 @@ class TestMultiGPULlama:
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
@@ -530,6 +617,12 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
# TODO: remove skip once deepspeed regression is fixed
# see https://github.com/huggingface/transformers/pull/37324
@pytest.mark.skipif(
transformers_version_eq("4.51.0"),
reason="zero3 is not supported with transformers==4.51.0",
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
@@ -566,7 +659,7 @@ class TestMultiGPULlama:
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
@@ -639,7 +732,7 @@ class TestMultiGPULlama:
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
@@ -712,7 +805,7 @@ class TestMultiGPULlama:
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
@@ -759,6 +852,9 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.skip(
reason="fix untrained tokens brittle with lots of edge cases in latest transformers"
)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -797,7 +893,7 @@ class TestMultiGPULlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
# "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
}
)

View File

@@ -31,7 +31,7 @@ class TestMultiGPURay:
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
@@ -94,8 +94,8 @@ class TestMultiGPURay:
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},