Compare commits

...

2 Commits

Author SHA1 Message Date
Wing Lian
60763b2e61 fix missing return 2024-11-14 10:14:13 -05:00
Wing Lian
082a41af9d add check for broken fsdp+grad_accum 2024-11-14 10:12:57 -05:00

View File

@@ -1402,6 +1402,17 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_grad_accum_4_46_2(cls, data):
    """Reject FSDP combined with gradient accumulation on transformers 4.46.2.

    This is a ``mode="before"`` validator, so ``data`` is the raw input
    mapping; ``gradient_accumulation_steps`` may therefore be absent or
    ``None`` when this check runs.

    Args:
        data: raw config mapping being validated.

    Returns:
        ``data``, unmodified.

    Raises:
        ValueError: if ``fsdp`` is set, ``gradient_accumulation_steps`` is
            greater than 1, and the installed transformers version is
            exactly 4.46.2.
    """
    # A missing or null value means "no accumulation"; fall back to 1 so the
    # comparison below cannot fail with `None > 1` (TypeError).
    grad_accum_steps = data.get("gradient_accumulation_steps") or 1
    if data.get("fsdp") and grad_accum_steps > 1:
        # Only look up the installed version when the config is actually
        # affected by the broken combination.
        if version("transformers") == "4.46.2":
            raise ValueError(
                "FSDP w/ gradient_accumulation_steps > 1 is broken with transformers==4.46.2. "
                "Please use a lower value or switch to an older version of transformers."
            )
    return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""