batch flattening with xformers too
This commit is contained in:
@@ -475,8 +475,14 @@ class AxolotlInputConfig(
|
|||||||
def check_batch_flattening_fa(cls, data):
|
def check_batch_flattening_fa(cls, data):
|
||||||
if data.get("batch_flattening"):
|
if data.get("batch_flattening"):
|
||||||
batch_flattening_auto = data.get("batch_flattening") == "auto"
|
batch_flattening_auto = data.get("batch_flattening") == "auto"
|
||||||
if not data.get("flash_attention") and not batch_flattening_auto:
|
if (
|
||||||
raise ValueError("batch_flattening requires flash attention")
|
not data.get("flash_attention")
|
||||||
|
and not data.get("xformers_attention")
|
||||||
|
and not batch_flattening_auto
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"batch_flattening requires flash attention or xformers"
|
||||||
|
)
|
||||||
if data.get("sample_packing") and not batch_flattening_auto:
|
if data.get("sample_packing") and not batch_flattening_auto:
|
||||||
raise ValueError("batch_flattening not compatible with sample_packing")
|
raise ValueError("batch_flattening not compatible with sample_packing")
|
||||||
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
|
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
|
||||||
|
|||||||
Reference in New Issue
Block a user