Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)
* added torch check for adopt, wip * lint * gonna put torch version checking somewhere else * added ENVcapabilities class for torch version checking * lint + pydantic * ENVCapabilities -> EnvCapabilities * forgot to git add v0_4_1/__init__.py * removed redundancy * add check if env_capabilities not specified * make env_capabilities compulsory [skip e2e] * fixup env_capabilities * modified test_validation.py to accomodate env_capabilities * adopt torch version test [skip e2e] * raise error * test correct torch version * test torch version above requirement * Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py Co-authored-by: Wing Lian <wing.lian@gmail.com> * removed unused is_totch_min --------- Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -672,6 +672,9 @@ class TestValidation(BaseValidation):
|
||||
{
|
||||
"bf16": True,
|
||||
"capabilities": {"bf16": False},
|
||||
"env_capabilities": {
|
||||
"torch_version": "2.5.1",
|
||||
},
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
@@ -1160,6 +1163,38 @@ class TestValidation(BaseValidation):
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_torch_version_adopt_req(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"optimizer": "adopt_adamw",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*ADOPT optimizer is incompatible with torch version*",
|
||||
):
|
||||
env_capabilities = {"torch_version": "2.3.0"}
|
||||
capabilities = {"bf16": False}
|
||||
_ = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
env_capabilities = {"torch_version": "2.5.1"}
|
||||
capabilities = {"bf16": False}
|
||||
_ = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
env_capabilities = {"torch_version": "2.5.2"}
|
||||
capabilities = {"bf16": False}
|
||||
_ = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
|
||||
class TestValidationCheckModelConfig(BaseValidation):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user