add check for broken fsdp+grad_accum
This commit is contained in:
@@ -1402,6 +1402,16 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_fsdp_grad_accum_4_46_2(cls, data):
|
||||||
|
if data.get("fsdp") and data.get("gradient_accumulation_steps") > 1:
|
||||||
|
if version("transformers") == "4.46.2":
|
||||||
|
raise ValueError(
|
||||||
|
"FSDP w/ gradient_accumulation_steps > 1 is broken with transformers==4.46.2. "
|
||||||
|
"Please use a lower value or switch to an older version of transformers."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||||
|
|||||||
Reference in New Issue
Block a user