feat: validate sample packing requires flash_attention (#1465)
* feat: validate sample packing requires flash_attention * fix: check for sdp_attn per suggestion * feat: add FA to tests
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Module for pydantic models for configuration
|
Module for pydantic models for configuration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -655,6 +656,20 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_sample_packing_wo_flash(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("sample_packing")
|
||||||
|
and not data.get("flash_attention")
|
||||||
|
and not data.get("sdp_attention")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"sample_packing requires flash_attention or sdp_attention to be set to true"
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sample_packing_w_rl(cls, data):
|
def check_sample_packing_w_rl(cls, data):
|
||||||
|
|||||||
@@ -600,6 +600,7 @@ class TestValidation(BaseValidation):
|
|||||||
{
|
{
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": None,
|
"pad_to_sequence_len": None,
|
||||||
|
"flash_attention": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -901,6 +902,7 @@ class TestValidation(BaseValidation):
|
|||||||
{
|
{
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_table_size": 100,
|
"eval_table_size": 100,
|
||||||
|
"flash_attention": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -916,6 +918,7 @@ class TestValidation(BaseValidation):
|
|||||||
{
|
{
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
|
"flash_attention": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -928,6 +931,7 @@ class TestValidation(BaseValidation):
|
|||||||
{
|
{
|
||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"eval_table_size": 100,
|
"eval_table_size": 100,
|
||||||
|
"flash_attention": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -941,6 +945,7 @@ class TestValidation(BaseValidation):
|
|||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_table_size": 100,
|
"eval_table_size": 100,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
|
"flash_attention": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
|
|||||||
Reference in New Issue
Block a user