feat: validate sample packing requires flash_attention (#1465)

* feat: validate sample packing requires flash_attention

* fix: check for sdp_attn per suggestion

* feat: add FA to tests
This commit is contained in:
NanoCode012
2024-04-05 12:47:32 +09:00
committed by GitHub
parent 05b0b7e8ca
commit bf4cd67252
2 changed files with 20 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
"""
Module for pydantic models for configuration
"""
# pylint: disable=too-many-lines
import logging
@@ -655,6 +656,20 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):
    """Reject sample packing when neither supported attention option is enabled.

    Runs as a ``mode="before"`` validator, so ``data`` is the raw input
    (anything exposing ``.get`` -- presumably the config mapping; confirm
    against callers).

    Args:
        data: Raw configuration values looked up by key.

    Returns:
        ``data`` unchanged when the combination is acceptable.

    Raises:
        ValueError: If ``sample_packing`` is truthy while both
            ``flash_attention`` and ``sdp_attention`` are falsy or absent.
    """
    # Nothing to enforce unless sample packing was requested.
    if not data.get("sample_packing"):
        return data

    # Either attention option satisfies the requirement.
    if data.get("flash_attention") or data.get("sdp_attention"):
        return data

    raise ValueError(
        "sample_packing requires flash_attention or sdp_attention to be set to true"
    )
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_rl(cls, data):