add check for broken fsdp+grad_accum
This commit is contained in:
@@ -1402,6 +1402,16 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_fsdp_grad_accum_4_46_2(cls, data):
    """Reject FSDP combined with gradient accumulation on transformers 4.46.2.

    Args:
        data: The raw, not-yet-validated config payload; it is read with
            ``.get`` so it is expected to be dict-like.

    Returns:
        The same ``data`` object, unmodified, so validation can continue.

    Raises:
        ValueError: If ``fsdp`` is set, ``gradient_accumulation_steps`` is
            greater than 1, and the installed ``transformers`` distribution
            reports version ``4.46.2``.
    """
    # A missing or null `gradient_accumulation_steps` is treated as 1 (no
    # accumulation) so the comparison below cannot raise
    # `TypeError: '>' not supported between 'NoneType' and 'int'`.
    grad_accum_steps = data.get("gradient_accumulation_steps") or 1
    if data.get("fsdp") and grad_accum_steps > 1:
        # The installed version is only looked up once the risky
        # combination has actually been requested.
        if version("transformers") == "4.46.2":
            raise ValueError(
                "FSDP w/ gradient_accumulation_steps > 1 is broken with transformers==4.46.2. "
                "Please use a lower value or switch to an older version of transformers."
            )
    # A "before" validator must hand the payload back; returning None here
    # would replace the whole config input with None for later validation.
    return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
|
||||
Reference in New Issue
Block a user