feat: cleanup old flex mask patch, suppress Matmul bnb warn, and misc (#3330) [skip-ci]

* feat: add pos id to flex attention for packing part 1

* feat: update to include sliding window mask patch

* fix: suppress the bitsandbytes "MatMul8bitLt: inputs will be cast from ..." warnings

* fix: remove redundant flex attention patch

* chore: update olmo docs

* feat: add validator patch for cross entropy
This commit is contained in:
NanoCode012
2025-12-25 17:56:20 +07:00
committed by GitHub
parent 97f1b1758d
commit 372f664c63
6 changed files with 41 additions and 167 deletions

View File

@@ -16,7 +16,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
This uses about 11.3 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
### TIPS

View File

@@ -42,10 +42,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -1,158 +0,0 @@
"""
monkeypatch for flex + packing
"""
import sys
from typing import Callable, Optional, Union
import torch
from torch.nn.attention.flex_attention import BlockMask
from transformers import Cache, PretrainedConfig
from transformers.masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
_preprocess_mask_arguments,
and_masks,
causal_mask_function,
or_masks,
)
from transformers.utils import is_torch_greater_or_equal
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
def create_causal_mask(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
    """
    Create a causal mask based on the attention implementation used (stored in the config), restricting
    attention to tokens that share the same value in `attention_mask` when one is given. If `past_key_values`
    has a HybridCache structure, this function will return the mask corresponding to one of the
    "full_attention" layers (to align to what is needed in the `modeling_xxx.py` files).

    Args:
        config (`PretrainedConfig`):
            The model config.
        input_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
            Positions are only allowed to attend to each other when their mask values are equal.
            When `None`, the plain causal mask function is used.
        cache_position (`torch.Tensor`):
            A tensor indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.

    Returns:
        The mask produced by the mask interface registered for `config._attn_implementation`, or the
        (possibly preprocessed) `attention_mask` when `_preprocess_mask_arguments` signals an early exit.

    Raises:
        ValueError: If `or_mask_function` or `and_mask_function` is passed with torch < 2.6.
    """
    # If we have a HybridCache structure, here we want to create the mask for the full layers
    if (
        past_key_values
        and hasattr(past_key_values, "is_sliding")
        and False in past_key_values.is_sliding
    ):
        layer_idx = past_key_values.is_sliding.index(False)
    else:
        layer_idx = 0

    # Keep a copy of the caller's mask values: `attention_mask` is rebound to the
    # output of `_preprocess_mask_arguments` below.
    # NOTE(review): presumably the preprocessing casts the mask to bool, which would
    # discard per-document ids — confirm against transformers.
    original_attention_mask = (
        None
        if attention_mask is None
        else attention_mask.clone().to(cache_position.device)
    )

    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
    )
    if early_exit:
        return attention_mask

    batch_size, dtype = input_embeds.shape[0], input_embeds.dtype

    # Only build the document-id lookup when a mask was actually supplied. Doing it
    # unconditionally passed `None` to `pad` and made the plain-causal fallback
    # below unreachable.
    if attention_mask is not None and original_attention_mask is not None:
        # NOTE(review): this unpacking requires a 2D `cache_position`; upstream
        # documents it as (query_length,) — confirm what callers pass.
        _, total_seq_len = cache_position.shape
        # Right-pad with zeros along the last dimension.
        # NOTE(review): presumably this keeps `kv_idx` lookups in bounds — confirm.
        document_ids = torch.nn.functional.pad(
            original_attention_mask, value=0, pad=(0, total_seq_len)
        )

        def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
            """
            Defines the logic of a block causal mask by combining both a standard causal mask
            and a block diagonal document mask.
            See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
            for an illustration.
            """
            causal_mask_ = q_idx >= kv_idx  # not valid when decoding
            document_mask = (
                document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
            )
            final_mask = causal_mask_ & document_mask
            return final_mask

        mask_factory_function = causal_doc_mask_mod
    else:
        mask_factory_function = causal_mask_function

    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # Do not allow skip if we are compiling (this is to match BC)
    allow_is_causal_skip = (
        not past_key_values.is_compileable if past_key_values is not None else True
    )

    # Allow slight deviations from causal mask
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError(
                "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
            )
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
        allow_is_causal_skip = False
    if and_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError(
                "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
            )
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
        allow_is_causal_skip = False

    # We now create the mask
    causal_mask = mask_interface(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        dtype=dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
    )
    return causal_mask
def patch_create_causal_mask(model_type):
    """
    Replace `create_causal_mask` in transformers with this module's implementation.

    `transformers.masking_utils.create_causal_mask` is always overridden. When
    `model_type` is truthy, the matching
    `transformers.models.<model_type>.modeling_<model_type>` module is patched too.

    Args:
        model_type: Model type used to build the modeling module path; falsy values
            skip the per-model patch.

    Raises:
        ValueError: If the modeling module cannot be imported or patched.
    """
    import importlib

    import transformers.masking_utils

    transformers.masking_utils.create_causal_mask = create_causal_mask

    if model_type:
        try:
            # Dynamically import the modeling module. `importlib.import_module`
            # returns the leaf module; `__import__` with a dotted name returns the
            # top-level package, so the attribute would land on `transformers`.
            module_path = f"transformers.models.{model_type}.modeling_{model_type}"
            module = importlib.import_module(module_path)
            module.create_causal_mask = create_causal_mask
            # NOTE(review): presumably dropped from the module cache so a later
            # import re-executes the module against the patched
            # `masking_utils` — confirm this is still wanted.
            del sys.modules[module_path]
        except (ImportError, AttributeError) as e:
            raise ValueError(
                f"Could not import attention class for model_type: {model_type}. "
                f"Error: {str(e)}"
            ) from e

View File

@@ -199,12 +199,6 @@ class PatchManager:
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
if self.cfg.sample_packing:
from axolotl.core.attention.flex_block_mask import (
patch_create_causal_mask,
)
patch_create_causal_mask(self.cfg.model_config_type)
def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""

View File

@@ -2,9 +2,17 @@
import functools
import logging
import warnings
from axolotl.utils.distributed import is_main_process
# Suppress noisy bitsandbytes warnings about dtype casting during quantization.
# This is a module-level statement, so the filter is registered when this module
# is imported. `message` is a regular expression that the warnings module matches
# against the start of the warning text; only `UserWarning`s are ignored.
warnings.filterwarnings(
    "ignore",
    message=".*MatMul8bitLt: inputs will be cast from.*",
    category=UserWarning,
)
# Adapted from Accelerate
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py

View File

@@ -801,6 +801,36 @@ class OptimizationValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_cross_entropy_conflicts(cls, data):
    """Reject configs that enable more than one cross entropy optimization.

    The following options are mutually exclusive; at most one may be truthy:
    - cut_cross_entropy (CutCrossEntropyPlugin)
    - chunked_cross_entropy
    - liger_cross_entropy (LigerPlugin)
    - liger_fused_linear_cross_entropy (LigerPlugin)

    Returns `data` unchanged when zero or one option is enabled, and raises
    `ValueError` naming the enabled options otherwise.
    """
    # Order matters: it is the order in which conflicts are listed in the error.
    option_names = (
        "cut_cross_entropy",
        "chunked_cross_entropy",
        "liger_cross_entropy",
        "liger_fused_linear_cross_entropy",
    )
    enabled_options = [name for name in option_names if data.get(name)]

    if len(enabled_options) <= 1:
        return data

    raise ValueError(
        f"Only one cross entropy optimization can be enabled at a time. "
        f"Found {len(enabled_options)} enabled: {', '.join(enabled_options)}. "
        "Please disable all but one."
    )
@model_validator(mode="before")
@classmethod
def check_fsdp_version(cls, data):