From bf4cd672522729e5b0ad9a5523ad83a60de19243 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 5 Apr 2024 12:47:32 +0900 Subject: [PATCH] feat: validate sample packing requires flash_attention (#1465) * feat: validate sample packing requires flash_attention * fix: check for sdp_attn per suggestion * feat: add FA to tests --- .../utils/config/models/input/v0_4_1/__init__.py | 15 +++++++++++++++ tests/test_validation.py | 5 +++++ 2 files changed, 20 insertions(+) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 56307da0b..ad332da2d 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1,6 +1,7 @@ """ Module for pydantic models for configuration """ + # pylint: disable=too-many-lines import logging @@ -655,6 +656,20 @@ class AxolotlInputConfig( return data + @model_validator(mode="before") + @classmethod + def check_sample_packing_wo_flash(cls, data): + if ( + data.get("sample_packing") + and not data.get("flash_attention") + and not data.get("sdp_attention") + ): + raise ValueError( + "sample_packing requires flash_attention or sdp_attention to be set to true" + ) + + return data + @model_validator(mode="before") @classmethod def check_sample_packing_w_rl(cls, data): diff --git a/tests/test_validation.py b/tests/test_validation.py index 7a8d80cb7..4865712c4 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -600,6 +600,7 @@ class TestValidation(BaseValidation): { "sample_packing": True, "pad_to_sequence_len": None, + "flash_attention": True, } ) | minimal_cfg @@ -901,6 +902,7 @@ class TestValidation(BaseValidation): { "sample_packing": True, "eval_table_size": 100, + "flash_attention": True, } ) | minimal_cfg @@ -916,6 +918,7 @@ class TestValidation(BaseValidation): { "sample_packing": True, "eval_sample_packing": False, + "flash_attention": True, } ) | minimal_cfg @@ -928,6 +931,7 @@ class TestValidation(BaseValidation): { "sample_packing": False, "eval_table_size": 100, + "flash_attention": True, } ) | minimal_cfg @@ -941,6 +945,7 @@ class TestValidation(BaseValidation): "sample_packing": True, "eval_table_size": 100, "eval_sample_packing": False, + "flash_attention": True, } ) | minimal_cfg