Allow batch_flattening with xformers attention as well as flash attention

This commit is contained in:
Wing Lian
2025-05-08 18:23:25 -04:00
parent 7fb01f0461
commit 5e50d1e8f0

View File

@@ -475,8 +475,14 @@ class AxolotlInputConfig(
def check_batch_flattening_fa(cls, data):
if data.get("batch_flattening"):
batch_flattening_auto = data.get("batch_flattening") == "auto"
if not data.get("flash_attention") and not batch_flattening_auto:
raise ValueError("batch_flattening requires flash attention")
if (
not data.get("flash_attention")
and not data.get("xformers_attention")
and not batch_flattening_auto
):
raise ValueError(
"batch_flattening requires flash attention or xformers"
)
if data.get("sample_packing") and not batch_flattening_auto:
raise ValueError("batch_flattening not compatible with sample_packing")
if data.get("micro_batch_size") == 1 and not batch_flattening_auto: