feat: validate sample packing requires flash_attention (#1465)

* feat: validate sample packing requires flash_attention

* fix: check for sdp_attn per suggestion

* feat: add FA to tests
This commit is contained in:
NanoCode012
2024-04-05 12:47:32 +09:00
committed by GitHub
parent 05b0b7e8ca
commit bf4cd67252
2 changed files with 20 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
"""
Module for pydantic models for configuration
"""
# pylint: disable=too-many-lines
import logging
@@ -655,6 +656,20 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):
    """Ensure an attention implementation backs sample packing.

    Sample packing only works when either flash attention or SDP
    attention is enabled, so reject configs that turn it on without one.
    """
    packing_requested = bool(data.get("sample_packing"))
    attn_available = bool(data.get("flash_attention")) or bool(
        data.get("sdp_attention")
    )
    if packing_requested and not attn_available:
        raise ValueError(
            "sample_packing requires flash_attention or sdp_attention to be set to true"
        )
    return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_rl(cls, data):

View File

@@ -600,6 +600,7 @@ class TestValidation(BaseValidation):
{
"sample_packing": True,
"pad_to_sequence_len": None,
"flash_attention": True,
}
)
| minimal_cfg
@@ -901,6 +902,7 @@ class TestValidation(BaseValidation):
{
"sample_packing": True,
"eval_table_size": 100,
"flash_attention": True,
}
)
| minimal_cfg
@@ -916,6 +918,7 @@ class TestValidation(BaseValidation):
{
"sample_packing": True,
"eval_sample_packing": False,
"flash_attention": True,
}
)
| minimal_cfg
@@ -928,6 +931,7 @@ class TestValidation(BaseValidation):
{
"sample_packing": False,
"eval_table_size": 100,
"flash_attention": True,
}
)
| minimal_cfg
@@ -941,6 +945,7 @@ class TestValidation(BaseValidation):
"sample_packing": True,
"eval_table_size": 100,
"eval_sample_packing": False,
"flash_attention": True,
}
)
| minimal_cfg