feat: validate sample packing requires flash_attention (#1465)

* feat: validate sample packing requires flash_attention

* fix: check for sdp_attn per suggestion

* feat: add FA to tests
This commit is contained in:
NanoCode012
2024-04-05 12:47:32 +09:00
committed by GitHub
parent 05b0b7e8ca
commit bf4cd67252
2 changed files with 20 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
""" """
Module for pydantic models for configuration Module for pydantic models for configuration
""" """
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
import logging import logging
@@ -655,6 +656,20 @@ class AxolotlInputConfig(
return data return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):
    """Ensure sample packing is only enabled alongside a supported attention option.

    Registered as a ``mode="before"`` validator; ``data`` is read via ``.get``
    (presumably the raw config mapping -- confirm against the caller).

    Returns:
        ``data`` unchanged when the check passes.

    Raises:
        ValueError: if ``sample_packing`` is truthy while both
            ``flash_attention`` and ``sdp_attention`` are falsy or absent.
    """
    # Nothing to validate unless sample packing was requested.
    if not data.get("sample_packing"):
        return data

    # Either attention option satisfies the requirement.
    if data.get("flash_attention") or data.get("sdp_attention"):
        return data

    raise ValueError(
        "sample_packing requires flash_attention or sdp_attention to be set to true"
    )
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_sample_packing_w_rl(cls, data): def check_sample_packing_w_rl(cls, data):

View File

@@ -600,6 +600,7 @@ class TestValidation(BaseValidation):
{ {
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": None, "pad_to_sequence_len": None,
"flash_attention": True,
} }
) )
| minimal_cfg | minimal_cfg
@@ -901,6 +902,7 @@ class TestValidation(BaseValidation):
{ {
"sample_packing": True, "sample_packing": True,
"eval_table_size": 100, "eval_table_size": 100,
"flash_attention": True,
} }
) )
| minimal_cfg | minimal_cfg
@@ -916,6 +918,7 @@ class TestValidation(BaseValidation):
{ {
"sample_packing": True, "sample_packing": True,
"eval_sample_packing": False, "eval_sample_packing": False,
"flash_attention": True,
} }
) )
| minimal_cfg | minimal_cfg
@@ -928,6 +931,7 @@ class TestValidation(BaseValidation):
{ {
"sample_packing": False, "sample_packing": False,
"eval_table_size": 100, "eval_table_size": 100,
"flash_attention": True,
} }
) )
| minimal_cfg | minimal_cfg
@@ -941,6 +945,7 @@ class TestValidation(BaseValidation):
"sample_packing": True, "sample_packing": True,
"eval_table_size": 100, "eval_table_size": 100,
"eval_sample_packing": False, "eval_sample_packing": False,
"flash_attention": True,
} }
) )
| minimal_cfg | minimal_cfg