Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)
* added torch check for adopt, wip
* lint
* gonna put torch version checking somewhere else
* added ENVCapabilities class for torch version checking
* lint + pydantic
* ENVCapabilities -> EnvCapabilities
* forgot to git add v0_4_1/__init__.py
* removed redundancy
* add check if env_capabilities not specified
* make env_capabilities compulsory [skip e2e]
* fixup env_capabilities
* modified test_validation.py to accommodate env_capabilities
* adopt torch version test [skip e2e]
* raise error
* test correct torch version
* test torch version above requirement
* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* removed unused is_totch_min

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
@@ -409,7 +409,7 @@ lr_div_factor: # Learning rate div factor
 #  - adamw_torch_fused
 #  - adamw_torch_xla
 #  - adamw_apex_fused
-#  - adopt_adamw  (only for torch version >= 2.5.1)
+#  - adopt_adamw  (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
 #  - adafactor
 #  - adamw_anyprecision
 #  - sgd
@@ -100,8 +100,8 @@ def print_dep_versions():
     print("*" * 40)
     print("**** Axolotl Dependency Versions *****")
     for pkg in packages:
-        version = _is_package_available(pkg, return_version=True)
-        print(f"{pkg: >{max_len}}: {version[1]: <15}")
+        pkg_version = _is_package_available(pkg, return_version=True)
+        print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
     print("*" * 40)
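Side note on the tuple indexing above: with return_version=True, transformers' internal import helper returns a (found, version) pair, which is why the renamed pkg_version[1] picks out the version string. A standalone sketch:

# Illustration only (not part of this diff); _is_package_available is transformers' internal API.
from transformers.utils.import_utils import _is_package_available

found, torch_ver = _is_package_available("torch", return_version=True)
print(found, torch_ver)  # e.g. True 2.5.1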
@@ -444,6 +444,9 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
             "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
             "compute_capability": gpu_version,
         },
+        env_capabilities={
+            "torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
+        },
     )
 
     prepare_optim_env(cfg)
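The split("+", maxsplit=1)[0] above strips torch's local build suffix so only the release portion is passed along as torch_version. A standalone sketch of what this normalization yields (the version strings are examples):

raw = "2.5.1+cu124"                   # what torch.__version__ may look like
base = raw.split("+", maxsplit=1)[0]  # -> "2.5.1"
print(base)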
@@ -562,7 +562,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                 ADOPT(
-                    optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
+                    optimizer_grouped_parameters,
+                    decouple=True,
+                    **optimizer_kwargs,
                 )
             )
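For reviewers, a minimal usage sketch of the optimizer under the renamed keyword; the import path is assumed from this PR's layout and is not part of the diff:

import torch

from axolotl.utils.optimizers.adopt import ADOPT  # module path assumed

model = torch.nn.Linear(4, 2)
opt = ADOPT(model.parameters(), lr=1e-3, weight_decay=0.01, decouple=True)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
opt.step()  # step 0 only seeds the second-moment state (see the optimizer changes below)
opt.zero_grad()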
@@ -229,7 +229,11 @@ def normalize_cfg_datasets(cfg):
             cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
 
 
-def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
+def validate_config(
+    cfg: DictDefault,
+    capabilities: Optional[dict] = None,
+    env_capabilities: Optional[dict] = None,
+):
     AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
     AxolotlInputConfig = AxolotlInputConfigBase
@@ -239,14 +243,24 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
         AxolotlInputConfig,  # pylint: disable=invalid-name
     ) = merge_input_args()
 
-    if capabilities:
+    if capabilities or env_capabilities:
+        if (capabilities and not env_capabilities) or (
+            env_capabilities and not capabilities
+        ):
+            raise ValueError(
+                "Both capabilities and env_capabilities must be provided or not provided."
+            )
+
         return DictDefault(
             dict(
                 AxolotlConfigWCapabilities(
-                    **cfg.to_dict(), capabilities=capabilities
+                    **cfg.to_dict(),
+                    capabilities=capabilities,
+                    env_capabilities=env_capabilities,
                 ).model_dump(exclude_none=True)
             )
         )
 
     return DictDefault(
         dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
     )
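The calling convention after this change, sketched; the config values are illustrative and a real cfg needs the usual required fields:

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"optimizer": "adopt_adamw"})  # plus the usual required fields
validated = validate_config(
    cfg,
    capabilities={"bf16": False, "n_gpu": 1, "compute_capability": "8.0"},
    env_capabilities={"torch_version": "2.5.1"},  # pass both dicts or neither
)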
@@ -9,6 +9,7 @@ import os
 from enum import Enum
 from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
 
+from packaging import version
 from pydantic import (
     BaseModel,
     Field,
@@ -21,7 +22,7 @@ from transformers import SchedulerType
 from transformers.training_args import OptimizerNames
 from transformers.utils.import_utils import is_torch_npu_available
 
-from axolotl.utils.config.models.internals import GPUCapabilities
+from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
 
 LOG = logging.getLogger("axolotl.utils.config.models.input")
@@ -1478,6 +1479,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""
 
     capabilities: GPUCapabilities
+    env_capabilities: EnvCapabilities
 
     @model_validator(mode="after")
     def check_bf16(self):
@@ -1552,3 +1554,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
                 "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
             )
         return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_adopt_torch_version(cls, data):
+        if (data.get("optimizer") is not None) and ("adopt" in data.get("optimizer")):
+            env_capabilities = data.get("env_capabilities", {})
+            torch_version = env_capabilities.get("torch_version")
+
+            if torch_version is None:
+                import torch
+
+                torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+
+            if version.parse(torch_version) < version.parse("2.5.1"):
+                raise ValueError(
+                    "ADOPT optimizer is incompatible with torch version < 2.5.1"
+                )
+        return data
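The validator compares with packaging.version rather than raw strings, which matters once minor versions reach double digits. Standalone illustration:

from packaging import version

print("2.10.0" > "2.5.1")                                # False -- lexicographic comparison
print(version.parse("2.10.0") > version.parse("2.5.1"))  # True  -- semantic comparison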
@@ -12,3 +12,9 @@ class GPUCapabilities(BaseModel):
     n_gpu: int = Field(default=1)
     n_node: int = Field(default=1)
     compute_capability: Optional[str] = Field(default=None)
+
+
+class EnvCapabilities(BaseModel):
+    """model to manage the environment capabilities statically"""
+
+    torch_version: Optional[str] = Field(default=None)
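Usage sketch for the new model, mirroring how the capability dicts are validated elsewhere in this PR (illustration only):

from axolotl.utils.config.models.internals import EnvCapabilities

caps = EnvCapabilities(**{"torch_version": "2.5.1"})
print(caps.model_dump())  # {'torch_version': '2.5.1'}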
@@ -6,21 +6,29 @@ Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeo
 """
 # mypy: ignore-errors
 # pylint: skip-file
 # flake8: noqa
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
-from typing import List, Optional, Tuple, Union, cast
+from typing import Callable, List, Optional, Tuple, Union, cast
 
 import torch
 from torch import Tensor
-from torch.optim.optimizer import (
+from torch.optim.optimizer import (  # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling,
+    DeviceDict,
     Optimizer,
     ParamsT,
+    _capturable_doc,
     _default_to_fused_or_foreach,
     _device_dtype_check_for_fused,
+    _differentiable_doc,
     _disable_dynamo_if_unsupported,
+    _foreach_doc,
+    _fused_doc,
     _get_capturable_supported_devices,
     _get_scalar_dtype,
     _get_value,
+    _maximize_doc,
+    _stack_if_compiling,
     _use_grad_for_differentiable,
     _view_as_real,
 )
@@ -35,8 +43,9 @@ class ADOPT(Optimizer):
         lr: Union[float, Tensor] = 1e-3,
         betas: Tuple[float, float] = (0.9, 0.9999),
         eps: float = 1e-6,
+        clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
         weight_decay: float = 0.0,
-        decoupled: bool = False,
+        decouple: bool = False,
         *,
         foreach: Optional[bool] = None,
         maximize: bool = False,
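The new clip_lambda default bounds the normalized gradient by step**0.25, so early steps are clamped tightly and the bound relaxes as training progresses. Standalone illustration of the schedule:

clip_lambda = lambda step: step**0.25
print([round(clip_lambda(s), 3) for s in (1, 10, 100, 1000)])
# [1.0, 1.778, 3.162, 5.623]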
@@ -62,12 +71,14 @@ class ADOPT(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
+        self.clip_lambda = clip_lambda
+
         defaults = dict(
             lr=lr,
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
-            decoupled=decoupled,
+            decouple=decouple,
             maximize=maximize,
             foreach=foreach,
             capturable=capturable,
@@ -219,8 +230,9 @@ class ADOPT(Optimizer):
             beta1=beta1,
             beta2=beta2,
             lr=group["lr"],
+            clip_lambda=self.clip_lambda,
             weight_decay=group["weight_decay"],
-            decoupled=group["decoupled"],
+            decouple=group["decouple"],
             eps=group["eps"],
             maximize=group["maximize"],
             foreach=group["foreach"],
@@ -247,8 +259,9 @@ def _single_tensor_adopt(
     beta1: float,
     beta2: float,
     lr: Union[float, Tensor],
+    clip_lambda: Optional[Callable[[int], float]],
     weight_decay: float,
-    decoupled: bool,
+    decouple: bool,
     eps: float,
     maximize: bool,
     capturable: bool,
@@ -276,14 +289,10 @@ def _single_tensor_adopt(
                 and param.device.type in capturable_supported_devices
             ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
 
-        # update step
-        step_t += 1
-        step = step_t if capturable or differentiable else _get_value(step_t)
-
-        if weight_decay != 0:
-            if decoupled:
-                param.add_(param, alpha=-lr * weight_decay)
-            else:
-                grad = grad.add(param, alpha=weight_decay)
+        if weight_decay != 0 and not decouple:
+            grad = grad.add(param, alpha=weight_decay)
 
         if torch.is_complex(param):
             grad = torch.view_as_real(grad)
@@ -293,20 +302,29 @@ def _single_tensor_adopt(
             exp_avg_sq = torch.view_as_real(exp_avg_sq)
             param = torch.view_as_real(param)
 
-        if step == 1:
+        step = step_t if capturable or differentiable else _get_value(step_t)
+        if step == 0:
             exp_avg_sq.addcmul_(grad, grad.conj())
+            # update step
+            step_t += 1
             continue
 
+        if weight_decay != 0 and decouple:
+            param.add_(param, alpha=-lr * weight_decay)
+
         denom = torch.clamp(exp_avg_sq.sqrt(), eps)
-        if step == 2:
-            exp_avg.addcdiv_(grad, denom)
-        else:
-            exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
+        normed_grad = grad.div(denom)
+        if clip_lambda is not None:
+            clip = clip_lambda(step)
+            normed_grad.clamp_(-clip, clip)
+
+        exp_avg.lerp_(normed_grad, 1 - beta1)
 
         param.add_(exp_avg, alpha=-lr)
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+
+        # update step
+        step_t += 1
 
 
 def _multi_tensor_adopt(
     params: List[Tensor],
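To summarize the reordered update above, here is a condensed standalone restatement of one parameter's step (state bookkeeping and the capturable/complex paths omitted; mirrors the order of operations in this hunk):

import torch

def adopt_step(param, grad, exp_avg, exp_avg_sq, step, *, lr=1e-3,
               beta1=0.9, beta2=0.9999, eps=1e-6,
               clip_lambda=lambda s: s**0.25):
    if step == 0:
        # step 0 only seeds the second moment; no parameter update yet
        exp_avg_sq.addcmul_(grad, grad)
        return
    denom = torch.clamp(exp_avg_sq.sqrt(), eps)
    normed_grad = grad.div(denom)
    if clip_lambda is not None:
        clip = clip_lambda(step)           # new: clip the normalized gradient
        normed_grad.clamp_(-clip, clip)
    exp_avg.lerp_(normed_grad, 1 - beta1)  # replaces the old step==2 special case
    param.add_(exp_avg, alpha=-lr)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)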
@@ -321,8 +339,9 @@ def _multi_tensor_adopt(
     beta1: float,
     beta2: float,
     lr: Union[float, Tensor],
+    clip_lambda: Optional[Callable[[int], float]],
     weight_decay: float,
-    decoupled: bool,
+    decouple: bool,
     eps: float,
     maximize: bool,
     capturable: bool,
@@ -376,6 +395,51 @@ def _multi_tensor_adopt(
         if maximize:
             device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]
 
+        if weight_decay != 0 and not decouple:
+            # Re-use the intermediate memory (device_grads) already allocated for maximize
+            if maximize:
+                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
+            else:
+                device_grads = torch._foreach_add(  # type: ignore[assignment]
+                    device_grads, device_params, alpha=weight_decay
+                )
+
+        if device_state_steps[0] == 0:
+            torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
+
+            # Update steps
+            # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
+            # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
+            # wrapped it once now. The alpha is required to assure we go to the right overload.
+            if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
+                torch._foreach_add_(
+                    device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
+                )
+            else:
+                torch._foreach_add_(device_state_steps, 1)
+
+            continue
+
+        if weight_decay != 0 and decouple:
+            torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+
+        exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
+        torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
+
+        normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
+        if clip_lambda is not None:
+            clip = clip_lambda(device_state_steps[0])
+            torch._foreach_maximum_(normed_grad, -clip)
+            torch._foreach_minimum_(normed_grad, clip)
+
+        torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
+
+        torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
+        torch._foreach_mul_(device_exp_avg_sqs, beta2)
+        torch._foreach_addcmul_(
+            device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
+        )
+
         # Update steps
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
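Context for the block above: the torch._foreach_* helpers used throughout are private PyTorch ops that apply one horizontally fused kernel across a list of tensors instead of a Python loop. A standalone sketch of the equivalence:

import torch

tensors = [torch.zeros(2), torch.zeros(3)]
torch._foreach_add_(tensors, 1)  # one fused call...
for t in tensors:                # ...equivalent to this per-tensor loop
    t.add_(1)
print(tensors)                   # each element is now 2.0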
@@ -387,41 +451,6 @@ def _multi_tensor_adopt(
         else:
             torch._foreach_add_(device_state_steps, 1)
 
-        if weight_decay != 0:
-            if decoupled:
-                torch._foreach_add_(
-                    device_params, device_params, alpha=-lr * weight_decay
-                )
-            else:
-                # Re-use the intermediate memory (device_grads) already allocated for maximize
-                if maximize:
-                    torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
-                else:
-                    device_grads = torch._foreach_add(  # type: ignore[assignment]
-                        device_grads, device_params, alpha=weight_decay
-                    )
-
-        if device_state_steps[0] == 1:
-            torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
-            continue
-
-        exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
-        exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
-
-        if device_state_steps[0] == 2:
-            torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
-        else:
-            torch._foreach_mul_(device_exp_avgs, beta1)
-            torch._foreach_addcdiv_(
-                device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
-            )
-
-        torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
-        torch._foreach_mul_(device_exp_avg_sqs, beta2)
-        torch._foreach_addcmul_(
-            device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
-        )
 
 
 @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
 def adopt(
@@ -443,8 +472,9 @@ def adopt(
     beta1: float,
     beta2: float,
     lr: Union[float, Tensor],
+    clip_lambda: Optional[Callable[[int], float]],
     weight_decay: float,
-    decoupled: bool,
+    decouple: bool,
     eps: float,
     maximize: bool,
 ):
@@ -497,8 +527,9 @@ def adopt(
         beta1=beta1,
         beta2=beta2,
         lr=lr,
+        clip_lambda=clip_lambda,
         weight_decay=weight_decay,
-        decoupled=decoupled,
+        decouple=decouple,
         eps=eps,
         maximize=maximize,
         capturable=capturable,
@@ -53,7 +53,7 @@ def require_torch_2_3_1(test_case):
 
 
 def require_torch_2_5_1(test_case):
     """
-    Decorator marking a test that requires torch >= 2.3.1
+    Decorator marking a test that requires torch >= 2.5.1
     """
 
     def is_min_2_5_1():
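The helper whose docstring is fixed here follows the usual version-gating pattern for tests; a standalone sketch of such a decorator (a generalized illustration, not the repo's exact helper):

import unittest

import torch
from packaging import version

def require_torch_min(min_version):
    """Skip the decorated test unless torch >= min_version."""
    def decorator(test_case):
        ok = version.parse(torch.__version__) >= version.parse(min_version)
        return unittest.skipUnless(ok, f"test requires torch>={min_version}")(test_case)
    return decorator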
@@ -672,6 +672,9 @@ class TestValidation(BaseValidation):
                 {
                     "bf16": True,
                     "capabilities": {"bf16": False},
+                    "env_capabilities": {
+                        "torch_version": "2.5.1",
+                    },
                 }
             )
             | minimal_cfg
@@ -1160,6 +1163,38 @@ class TestValidation(BaseValidation):
                 in self._caplog.records[0].message
             )
 
+    def test_torch_version_adopt_req(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "optimizer": "adopt_adamw",
+                }
+            )
+            | minimal_cfg
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*ADOPT optimizer is incompatible with torch version*",
+        ):
+            env_capabilities = {"torch_version": "2.3.0"}
+            capabilities = {"bf16": False}
+            _ = validate_config(
+                cfg, capabilities=capabilities, env_capabilities=env_capabilities
+            )
+
+        env_capabilities = {"torch_version": "2.5.1"}
+        capabilities = {"bf16": False}
+        _ = validate_config(
+            cfg, capabilities=capabilities, env_capabilities=env_capabilities
+        )
+
+        env_capabilities = {"torch_version": "2.5.2"}
+        capabilities = {"bf16": False}
+        _ = validate_config(
+            cfg, capabilities=capabilities, env_capabilities=env_capabilities
+        )
+
+
 class TestValidationCheckModelConfig(BaseValidation):
     """
@@ -72,6 +72,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
                     "n_gpu": 1,
                     "compute_capability": "8.0",
                 },
+                env_capabilities={
+                    "torch_version": "2.5.1",
+                },
             )
 
         _check_config()
@@ -124,6 +127,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
                     "n_gpu": 1,
                     "compute_capability": "8.0",
                 },
+                env_capabilities={
+                    "torch_version": "2.5.1",
+                },
             )
 
         _check_config()
@@ -177,6 +183,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
                     "n_gpu": 1,
                     "compute_capability": "8.0",
                 },
+                env_capabilities={
+                    "torch_version": "2.5.1",
+                },
             )
 
         _check_config()
@@ -231,6 +240,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
                     "n_gpu": 1,
                     "compute_capability": "8.0",
                 },
+                env_capabilities={
+                    "torch_version": "2.5.1",
+                },
             )
 
         _check_config()