Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)

* added torch check for adopt, wip

* lint

* gonna put torch version checking somewhere else

* added ENVcapabilities class for torch version checking

* lint + pydantic

* ENVCapabilities -> EnvCapabilities

* forgot to git add v0_4_1/__init__.py

* removed redundancy

* add check if env_capabilities not specified

* make env_capabilities compulsory [skip e2e]

* fixup env_capabilities

* modified test_validation.py to accomodate env_capabilities

* adopt torch version test [skip e2e]

* raise error

* test correct torch version

* test torch version above requirement

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* removed unused is_totch_min

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Sunny Liu
2024-12-02 20:15:39 -05:00
committed by GitHub
parent 9f6d0b5587
commit d5f58b6509
10 changed files with 189 additions and 66 deletions

View File

@@ -672,6 +672,9 @@ class TestValidation(BaseValidation):
{
"bf16": True,
"capabilities": {"bf16": False},
"env_capabilities": {
"torch_version": "2.5.1",
},
}
)
| minimal_cfg
@@ -1160,6 +1163,38 @@ class TestValidation(BaseValidation):
in self._caplog.records[0].message
)
def test_torch_version_adopt_req(self, minimal_cfg):
cfg = (
DictDefault(
{
"optimizer": "adopt_adamw",
}
)
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*ADOPT optimizer is incompatible with torch version*",
):
env_capabilities = {"torch_version": "2.3.0"}
capabilities = {"bf16": False}
_ = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)
env_capabilities = {"torch_version": "2.5.1"}
capabilities = {"bf16": False}
_ = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)
env_capabilities = {"torch_version": "2.5.2"}
capabilities = {"bf16": False}
_ = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)
class TestValidationCheckModelConfig(BaseValidation):
"""