From 5e50d1e8f080c00c0affe6f926ea812e0f91d0dd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 8 May 2025 18:23:25 -0400 Subject: [PATCH] batch flattening with xformers too --- src/axolotl/utils/schemas/config.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 9db374409..cc7fc8e3d 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -475,8 +475,14 @@ class AxolotlInputConfig( def check_batch_flattening_fa(cls, data): if data.get("batch_flattening"): batch_flattening_auto = data.get("batch_flattening") == "auto" - if not data.get("flash_attention") and not batch_flattening_auto: - raise ValueError("batch_flattening requires flash attention") + if ( + not data.get("flash_attention") + and not data.get("xformers_attention") + and not batch_flattening_auto + ): + raise ValueError( + "batch_flattening requires flash attention or xformers" + ) if data.get("sample_packing") and not batch_flattening_auto: raise ValueError("batch_flattening not compatible with sample_packing") if data.get("micro_batch_size") == 1 and not batch_flattening_auto: