Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)

* added torch check for adopt, wip

* lint

* gonna put torch version checking somewhere else

* added ENVcapabilities class for torch version checking

* lint + pydantic

* ENVCapabilities -> EnvCapabilities

* forgot to git add v0_4_1/__init__.py

* removed redundancy

* add check if env_capabilities not specified

* make env_capabilities compulsory [skip e2e]

* fixup env_capabilities

* modified test_validation.py to accomodate env_capabilities

* adopt torch version test [skip e2e]

* raise error

* test correct torch version

* test torch version above requirement

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* removed unused is_totch_min

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Sunny Liu
2024-12-02 20:15:39 -05:00
committed by bursteratom
parent cac785ec0e
commit d56260c8d5
10 changed files with 189 additions and 66 deletions

View File

@@ -409,7 +409,7 @@ lr_div_factor: # Learning rate div factor
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_apex_fused
# - adopt_adamw (only for torch version >= 2.5.1)
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - sgd

View File

@@ -100,8 +100,8 @@ def print_dep_versions():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {version[1]: <15}")
pkg_version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
print("*" * 40)
@@ -444,6 +444,9 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
},
)
prepare_optim_env(cfg)

View File

@@ -562,7 +562,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)

View File

@@ -229,7 +229,11 @@ def normalize_cfg_datasets(cfg):
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
def validate_config(
cfg: DictDefault,
capabilities: Optional[dict] = None,
env_capabilities: Optional[dict] = None,
):
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
AxolotlInputConfig = AxolotlInputConfigBase
@@ -239,14 +243,24 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
AxolotlInputConfig, # pylint: disable=invalid-name
) = merge_input_args()
if capabilities:
if capabilities or env_capabilities:
if (capabilities and not env_capabilities) or (
env_capabilities and not capabilities
):
raise ValueError(
"Both capabilities and env_capabilities must be provided or not provided."
)
return DictDefault(
dict(
AxolotlConfigWCapabilities(
**cfg.to_dict(), capabilities=capabilities
**cfg.to_dict(),
capabilities=capabilities,
env_capabilities=env_capabilities,
).model_dump(exclude_none=True)
)
)
return DictDefault(
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
)

View File

@@ -9,6 +9,7 @@ import os
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from packaging import version
from pydantic import (
BaseModel,
Field,
@@ -21,7 +22,7 @@ from transformers import SchedulerType
from transformers.training_args import OptimizerNames
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import GPUCapabilities
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
LOG = logging.getLogger("axolotl.utils.config.models.input")
@@ -1478,6 +1479,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
capabilities: GPUCapabilities
env_capabilities: EnvCapabilities
@model_validator(mode="after")
def check_bf16(self):
@@ -1552,3 +1554,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
)
return data
@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):
if (data.get("optimizer") is not None) and ("adopt" in data.get("optimizer")):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError(
"ADOPT optimizer is incompatible with torch version < 2.5.1"
)
return data

View File

@@ -12,3 +12,9 @@ class GPUCapabilities(BaseModel):
n_gpu: int = Field(default=1)
n_node: int = Field(default=1)
compute_capability: Optional[str] = Field(default=None)
class EnvCapabilities(BaseModel):
"""model to manage the environment capabilities statically"""
torch_version: Optional[str] = Field(default=None)

View File

@@ -6,21 +6,29 @@ Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeo
"""
# mypy: ignore-errors
# pylint: skip-file
# flake8: noqa
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union, cast
from typing import Callable, List, Optional, Tuple, Union, cast
import torch
from torch import Tensor
from torch.optim.optimizer import (
from torch.optim.optimizer import ( # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling,
DeviceDict,
Optimizer,
ParamsT,
_capturable_doc,
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_fused_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_stack_if_compiling,
_use_grad_for_differentiable,
_view_as_real,
)
@@ -35,8 +43,9 @@ class ADOPT(Optimizer):
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6,
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
weight_decay: float = 0.0,
decoupled: bool = False,
decouple: bool = False,
*,
foreach: Optional[bool] = None,
maximize: bool = False,
@@ -62,12 +71,14 @@ class ADOPT(Optimizer):
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
self.clip_lambda = clip_lambda
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
decoupled=decoupled,
decouple=decouple,
maximize=maximize,
foreach=foreach,
capturable=capturable,
@@ -219,8 +230,9 @@ class ADOPT(Optimizer):
beta1=beta1,
beta2=beta2,
lr=group["lr"],
clip_lambda=self.clip_lambda,
weight_decay=group["weight_decay"],
decoupled=group["decoupled"],
decouple=group["decouple"],
eps=group["eps"],
maximize=group["maximize"],
foreach=group["foreach"],
@@ -247,8 +259,9 @@ def _single_tensor_adopt(
beta1: float,
beta2: float,
lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float,
decoupled: bool,
decouple: bool,
eps: float,
maximize: bool,
capturable: bool,
@@ -276,14 +289,10 @@ def _single_tensor_adopt(
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
# update step
step_t += 1
step = step_t if capturable or differentiable else _get_value(step_t)
if weight_decay != 0:
if decoupled:
param.add_(param, alpha=-lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
if weight_decay != 0 and not decouple:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param):
grad = torch.view_as_real(grad)
@@ -293,20 +302,29 @@ def _single_tensor_adopt(
exp_avg_sq = torch.view_as_real(exp_avg_sq)
param = torch.view_as_real(param)
step = step_t if capturable or differentiable else _get_value(step_t)
if step == 1:
if step == 0:
exp_avg_sq.addcmul_(grad, grad.conj())
# update step
step_t += 1
continue
if weight_decay != 0 and decouple:
param.add_(param, alpha=-lr * weight_decay)
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
if step == 2:
exp_avg.addcdiv_(grad, denom)
else:
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
normed_grad = grad.div(denom)
if clip_lambda is not None:
clip = clip_lambda(step)
normed_grad.clamp_(-clip, clip)
exp_avg.lerp_(normed_grad, 1 - beta1)
param.add_(exp_avg, alpha=-lr)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
# update step
step_t += 1
def _multi_tensor_adopt(
params: List[Tensor],
@@ -321,8 +339,9 @@ def _multi_tensor_adopt(
beta1: float,
beta2: float,
lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float,
decoupled: bool,
decouple: bool,
eps: float,
maximize: bool,
capturable: bool,
@@ -376,6 +395,51 @@ def _multi_tensor_adopt(
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
if weight_decay != 0 and not decouple:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if device_state_steps[0] == 0:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(device_state_steps, 1)
continue
if weight_decay != 0 and decouple:
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
if clip_lambda is not None:
clip = clip_lambda(device_state_steps[0])
torch._foreach_maximum_(normed_grad, -clip)
torch._foreach_minimum_(normed_grad, clip)
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
@@ -387,41 +451,6 @@ def _multi_tensor_adopt(
else:
torch._foreach_add_(device_state_steps, 1)
if weight_decay != 0:
if decoupled:
torch._foreach_add_(
device_params, device_params, alpha=-lr * weight_decay
)
else:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if device_state_steps[0] == 1:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
continue
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
if device_state_steps[0] == 2:
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
else:
torch._foreach_mul_(device_exp_avgs, beta1)
torch._foreach_addcdiv_(
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
def adopt(
@@ -443,8 +472,9 @@ def adopt(
beta1: float,
beta2: float,
lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float,
decoupled: bool,
decouple: bool,
eps: float,
maximize: bool,
):
@@ -497,8 +527,9 @@ def adopt(
beta1=beta1,
beta2=beta2,
lr=lr,
clip_lambda=clip_lambda,
weight_decay=weight_decay,
decoupled=decoupled,
decouple=decouple,
eps=eps,
maximize=maximize,
capturable=capturable,

View File

@@ -53,7 +53,7 @@ def require_torch_2_3_1(test_case):
def require_torch_2_5_1(test_case):
"""
Decorator marking a test that requires torch >= 2.3.1
Decorator marking a test that requires torch >= 2.5.1
"""
def is_min_2_5_1():

View File

@@ -672,6 +672,9 @@ class TestValidation(BaseValidation):
{
"bf16": True,
"capabilities": {"bf16": False},
"env_capabilities": {
"torch_version": "2.5.1",
},
}
)
| minimal_cfg
@@ -1160,6 +1163,38 @@ class TestValidation(BaseValidation):
in self._caplog.records[0].message
)
def test_torch_version_adopt_req(self, minimal_cfg):
cfg = (
DictDefault(
{
"optimizer": "adopt_adamw",
}
)
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*ADOPT optimizer is incompatible with torch version*",
):
env_capabilities = {"torch_version": "2.3.0"}
capabilities = {"bf16": False}
_ = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)
env_capabilities = {"torch_version": "2.5.1"}
capabilities = {"bf16": False}
_ = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)
env_capabilities = {"torch_version": "2.5.2"}
capabilities = {"bf16": False}
_ = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)
class TestValidationCheckModelConfig(BaseValidation):
"""

View File

@@ -72,6 +72,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1,
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.5.1",
},
)
_check_config()
@@ -124,6 +127,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1,
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.5.1",
},
)
_check_config()
@@ -177,6 +183,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1,
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.5.1",
},
)
_check_config()
@@ -231,6 +240,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1,
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.5.1",
},
)
_check_config()