prepare plugins needs to happen so registration can occur to build the plugin args (#2119)

* prepare plugins needs to happen so registration can occur to build the plugin args

use yaml.dump

include dataset and more assertions

* attempt to manually register plugins rather than use fn

* fix fixture

* remove fixture

* move cli test to patched dir

* fix cce validation
This commit is contained in:
Wing Lian
2024-12-03 15:06:09 -05:00
committed by GitHub
parent 1ef70312ba
commit d87df2c776
4 changed files with 53 additions and 6 deletions

View File

@@ -432,6 +432,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
prepare_plugins(cfg)
cfg = validate_config(
cfg,
capabilities={
@@ -444,8 +446,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
},
)
prepare_plugins(cfg)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)

View File

@@ -33,7 +33,7 @@ class CutCrossEntropyArgs(BaseModel):
@model_validator(mode="before")
@classmethod
def check_dtype_is_half(cls, data):
if not (data.get("bf16") or data.get("fp16")):
if data.get("cut_cross_entropy") and not (data.get("bf16") or data.get("fp16")):
raise ValueError(
"Cut Cross Entropy requires fp16/bf16 training for backward pass. "
"Please set `bf16` or `fp16` to `True`."