add check for broken fsdp+grad_accum

This commit is contained in:
Wing Lian
2024-11-14 10:12:57 -05:00
parent 2d7830fda6
commit 082a41af9d

View File

@@ -1402,6 +1402,16 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_grad_accum_4_46_2(cls, data):
if data.get("fsdp") and data.get("gradient_accumulation_steps") > 1:
if version("transformers") == "4.46.2":
raise ValueError(
"FSDP w/ gradient_accumulation_steps > 1 is broken with transformers==4.46.2. "
"Please use a lower value or switch to an older version of transformers."
)
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""