Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)
* added torch check for adopt, wip * lint * gonna put torch version checking somewhere else * added ENVcapabilities class for torch version checking * lint + pydantic * ENVCapabilities -> EnvCapabilities * forgot to git add v0_4_1/__init__.py * removed redundancy * add check if env_capabilities not specified * make env_capabilities compulsory [skip e2e] * fixup env_capabilities * modified test_validation.py to accomodate env_capabilities * adopt torch version test [skip e2e] * raise error * test correct torch version * test torch version above requirement * Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py Co-authored-by: Wing Lian <wing.lian@gmail.com> * removed unused is_totch_min --------- Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -100,8 +100,8 @@ def print_dep_versions():
|
||||
print("*" * 40)
|
||||
print("**** Axolotl Dependency Versions *****")
|
||||
for pkg in packages:
|
||||
version = _is_package_available(pkg, return_version=True)
|
||||
print(f"{pkg: >{max_len}}: {version[1]: <15}")
|
||||
pkg_version = _is_package_available(pkg, return_version=True)
|
||||
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
|
||||
print("*" * 40)
|
||||
|
||||
|
||||
@@ -444,6 +444,9 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
env_capabilities={
|
||||
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
},
|
||||
)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
Reference in New Issue
Block a user