prepare plugins needs to happen so registration can occur to build the plugin args (#2119)
* prepare plugins needs to happen so registration can occur to build the plugin args use yaml.dump include dataset and more assertions * attempt to manually register plugins rather than use fn * fix fixture * remove fixture * move cli test to patched dir * fix cce validation
This commit is contained in:
@@ -432,6 +432,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
except: # pylint: disable=bare-except # noqa: E722
|
||||
gpu_version = None
|
||||
|
||||
prepare_plugins(cfg)
|
||||
|
||||
cfg = validate_config(
|
||||
cfg,
|
||||
capabilities={
|
||||
@@ -444,8 +446,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
},
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
prepare_opinionated_env(cfg)
|
||||
|
||||
@@ -33,7 +33,7 @@ class CutCrossEntropyArgs(BaseModel):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_dtype_is_half(cls, data):
|
||||
if not (data.get("bf16") or data.get("fp16")):
|
||||
if data.get("cut_cross_entropy") and not (data.get("bf16") or data.get("fp16")):
|
||||
raise ValueError(
|
||||
"Cut Cross Entropy requires fp16/bf16 training for backward pass. "
|
||||
"Please set `bf16` or `fp16` to `True`."
|
||||
|
||||
Reference in New Issue
Block a user