prepare plugins needs to happen so registration can occur to build the plugin args (#2119)
* prepare plugins needs to happen so registration can occur to build the plugin args use yaml.dump include dataset and more assertions * attempt to manually register plugins rather than use fn * fix fixture * remove fixture * move cli test to patched dir * fix cce validation
This commit is contained in:
@@ -16,11 +16,11 @@ if v < V("2.4.0"):
|
|||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
||||||
cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")
|
|
||||||
|
|
||||||
UNINSTALL_PREFIX = ""
|
UNINSTALL_PREFIX = ""
|
||||||
if cce_spec and not cce_spec_transformers:
|
if cce_spec:
|
||||||
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
|
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
|
||||||
|
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
|
||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
|
|||||||
@@ -432,6 +432,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
except: # pylint: disable=bare-except # noqa: E722
|
except: # pylint: disable=bare-except # noqa: E722
|
||||||
gpu_version = None
|
gpu_version = None
|
||||||
|
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
|
||||||
cfg = validate_config(
|
cfg = validate_config(
|
||||||
cfg,
|
cfg,
|
||||||
capabilities={
|
capabilities={
|
||||||
@@ -444,8 +446,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
prepare_plugins(cfg)
|
|
||||||
|
|
||||||
prepare_optim_env(cfg)
|
prepare_optim_env(cfg)
|
||||||
|
|
||||||
prepare_opinionated_env(cfg)
|
prepare_opinionated_env(cfg)
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class CutCrossEntropyArgs(BaseModel):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_dtype_is_half(cls, data):
|
def check_dtype_is_half(cls, data):
|
||||||
if not (data.get("bf16") or data.get("fp16")):
|
if data.get("cut_cross_entropy") and not (data.get("bf16") or data.get("fp16")):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cut Cross Entropy requires fp16/bf16 training for backward pass. "
|
"Cut Cross Entropy requires fp16/bf16 training for backward pass. "
|
||||||
"Please set `bf16` or `fp16` to `True`."
|
"Please set `bf16` or `fp16` to `True`."
|
||||||
|
|||||||
47
tests/e2e/patched/test_cli_integrations.py
Normal file
47
tests/e2e/patched/test_cli_integrations.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""
|
||||||
|
test cases to make sure the plugin args are loaded from the config file
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from axolotl.cli import load_cfg
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
class TestPluginArgs:
|
||||||
|
"""
|
||||||
|
test class for plugin args loaded from the config file
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_liger_plugin_args(self, temp_dir):
|
||||||
|
test_cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"learning_rate": 0.000001,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"plugins": ["axolotl.integrations.liger.LigerPlugin"],
|
||||||
|
"liger_layer_norm": True,
|
||||||
|
"liger_rope": True,
|
||||||
|
"liger_rms_norm": False,
|
||||||
|
"liger_glu_activation": True,
|
||||||
|
"liger_fused_linear_cross_entropy": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(test_cfg.to_dict()))
|
||||||
|
cfg = load_cfg(str(Path(temp_dir) / "config.yaml"))
|
||||||
|
assert cfg.liger_layer_norm is True
|
||||||
|
assert cfg.liger_rope is True
|
||||||
|
assert cfg.liger_rms_norm is False
|
||||||
|
assert cfg.liger_glu_activation is True
|
||||||
|
assert cfg.liger_fused_linear_cross_entropy is True
|
||||||
Reference in New Issue
Block a user