Commit: allow batch flattening with xformers attention as well (previously it required flash attention only).
This commit is contained in:
@@ -475,8 +475,14 @@ class AxolotlInputConfig(
|
||||
def check_batch_flattening_fa(cls, data):
|
||||
if data.get("batch_flattening"):
|
||||
batch_flattening_auto = data.get("batch_flattening") == "auto"
|
||||
if not data.get("flash_attention") and not batch_flattening_auto:
|
||||
raise ValueError("batch_flattening requires flash attention")
|
||||
if (
|
||||
not data.get("flash_attention")
|
||||
and not data.get("xformers_attention")
|
||||
and not batch_flattening_auto
|
||||
):
|
||||
raise ValueError(
|
||||
"batch_flattening requires flash attention or xformers"
|
||||
)
|
||||
if data.get("sample_packing") and not batch_flattening_auto:
|
||||
raise ValueError("batch_flattening not compatible with sample_packing")
|
||||
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
|
||||
|
||||
[Page chrome: "Reference in New Issue" / "Block a user" — repository UI links, not part of the diff.]