Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)

* added torch check for adopt, wip

* lint

* gonna put torch version checking somewhere else

* added ENVcapabilities class for torch version checking

* lint + pydantic

* ENVCapabilities -> EnvCapabilities

* forgot to git add v0_4_1/__init__.py

* removed redundancy

* add check if env_capabilities not specified

* make env_capabilities compulsory [skip e2e]

* fixup env_capabilities

* modified test_validation.py to accomodate env_capabilities

* adopt torch version test [skip e2e]

* raise error

* test correct torch version

* test torch version above requirement

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* removed unused is_totch_min

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Sunny Liu
2024-12-02 20:15:39 -05:00
committed by GitHub
parent 9f6d0b5587
commit d5f58b6509
10 changed files with 189 additions and 66 deletions

View File

@@ -229,7 +229,11 @@ def normalize_cfg_datasets(cfg):
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
def validate_config(
cfg: DictDefault,
capabilities: Optional[dict] = None,
env_capabilities: Optional[dict] = None,
):
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
AxolotlInputConfig = AxolotlInputConfigBase
@@ -239,14 +243,24 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
AxolotlInputConfig, # pylint: disable=invalid-name
) = merge_input_args()
if capabilities:
if capabilities or env_capabilities:
if (capabilities and not env_capabilities) or (
env_capabilities and not capabilities
):
raise ValueError(
"Both capabilities and env_capabilities must be provided or not provided."
)
return DictDefault(
dict(
AxolotlConfigWCapabilities(
**cfg.to_dict(), capabilities=capabilities
**cfg.to_dict(),
capabilities=capabilities,
env_capabilities=env_capabilities,
).model_dump(exclude_none=True)
)
)
return DictDefault(
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
)