ADOPT optimizer integration (#2032) [skip ci]

* adopt integration

* stuff

* doc and test for ADOPT

* rearrangement

* fixed formatting

* hacking pre-commit

* chore: lint

* update module doc for adopt optimizer

* remove un-necessary example yaml for adopt optimizer

* skip test adopt if torch<2.5.1

* formatting

* use version.parse

* specifies required torch version for adopt_adamw

---------

Co-authored-by: sunny <sunnyliu19981005@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Sunny Liu
2024-11-13 17:10:17 -05:00
committed by GitHub
parent 659ee5d723
commit 1d7aee0ad2
6 changed files with 588 additions and 5 deletions

View File

@@ -6,11 +6,13 @@ import shutil
import tempfile
import unittest
from functools import wraps
from importlib.metadata import version
from pathlib import Path
import torch
# from importlib.metadata import version
from packaging import version
def with_temp_dir(test_func):
@wraps(test_func)
@@ -43,12 +45,24 @@ def require_torch_2_3_1(test_case):
"""
def is_min_2_3_1():
torch_version = version("torch")
return torch_version >= "2.3.1"
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.3.1")
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
def require_torch_2_5_1(test_case):
"""
Decorator marking a test that requires torch >= 2.3.1
"""
def is_min_2_5_1():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.5.1")
return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
def is_hopper():
compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0)