diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 9db374409..cc7fc8e3d 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -475,8 +475,14 @@ class AxolotlInputConfig( def check_batch_flattening_fa(cls, data): if data.get("batch_flattening"): batch_flattening_auto = data.get("batch_flattening") == "auto" - if not data.get("flash_attention") and not batch_flattening_auto: - raise ValueError("batch_flattening requires flash attention") + if ( + not data.get("flash_attention") + and not data.get("xformers_attention") + and not batch_flattening_auto + ): + raise ValueError( + "batch_flattening requires flash attention or xformers" + ) if data.get("sample_packing") and not batch_flattening_auto: raise ValueError("batch_flattening not compatible with sample_packing") if data.get("micro_batch_size") == 1 and not batch_flattening_auto: