Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)
* added torch check for adopt, wip * lint * gonna put torch version checking somewhere else * added ENVcapabilities class for torch version checking * lint + pydantic * ENVCapabilities -> EnvCapabilities * forgot to git add v0_4_1/__init__.py * removed redundancy * add check if env_capabilities not specified * make env_capabilities compulsory [skip e2e] * fixup env_capabilities * modified test_validation.py to accommodate env_capabilities * adopt torch version test [skip e2e] * raise error * test correct torch version * test torch version above requirement * Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py Co-authored-by: Wing Lian <wing.lian@gmail.com> * removed unused is_torch_min --------- Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -409,7 +409,7 @@ lr_div_factor: # Learning rate div factor
|
|||||||
# - adamw_torch_fused
|
# - adamw_torch_fused
|
||||||
# - adamw_torch_xla
|
# - adamw_torch_xla
|
||||||
# - adamw_apex_fused
|
# - adamw_apex_fused
|
||||||
# - adopt_adamw (only for torch version >= 2.5.1)
|
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||||
# - adafactor
|
# - adafactor
|
||||||
# - adamw_anyprecision
|
# - adamw_anyprecision
|
||||||
# - sgd
|
# - sgd
|
||||||
|
|||||||
@@ -100,8 +100,8 @@ def print_dep_versions():
|
|||||||
print("*" * 40)
|
print("*" * 40)
|
||||||
print("**** Axolotl Dependency Versions *****")
|
print("**** Axolotl Dependency Versions *****")
|
||||||
for pkg in packages:
|
for pkg in packages:
|
||||||
version = _is_package_available(pkg, return_version=True)
|
pkg_version = _is_package_available(pkg, return_version=True)
|
||||||
print(f"{pkg: >{max_len}}: {version[1]: <15}")
|
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
|
||||||
print("*" * 40)
|
print("*" * 40)
|
||||||
|
|
||||||
|
|
||||||
@@ -444,6 +444,9 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
"compute_capability": gpu_version,
|
"compute_capability": gpu_version,
|
||||||
},
|
},
|
||||||
|
env_capabilities={
|
||||||
|
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
prepare_optim_env(cfg)
|
prepare_optim_env(cfg)
|
||||||
|
|||||||
@@ -562,7 +562,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
ADOPT(
|
ADOPT(
|
||||||
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
|
optimizer_grouped_parameters,
|
||||||
|
decouple=True,
|
||||||
|
**optimizer_kwargs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -229,7 +229,11 @@ def normalize_cfg_datasets(cfg):
|
|||||||
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
|
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
def validate_config(
|
||||||
|
cfg: DictDefault,
|
||||||
|
capabilities: Optional[dict] = None,
|
||||||
|
env_capabilities: Optional[dict] = None,
|
||||||
|
):
|
||||||
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
|
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
|
||||||
AxolotlInputConfig = AxolotlInputConfigBase
|
AxolotlInputConfig = AxolotlInputConfigBase
|
||||||
|
|
||||||
@@ -239,14 +243,24 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
|||||||
AxolotlInputConfig, # pylint: disable=invalid-name
|
AxolotlInputConfig, # pylint: disable=invalid-name
|
||||||
) = merge_input_args()
|
) = merge_input_args()
|
||||||
|
|
||||||
if capabilities:
|
if capabilities or env_capabilities:
|
||||||
|
if (capabilities and not env_capabilities) or (
|
||||||
|
env_capabilities and not capabilities
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Both capabilities and env_capabilities must be provided or not provided."
|
||||||
|
)
|
||||||
|
|
||||||
return DictDefault(
|
return DictDefault(
|
||||||
dict(
|
dict(
|
||||||
AxolotlConfigWCapabilities(
|
AxolotlConfigWCapabilities(
|
||||||
**cfg.to_dict(), capabilities=capabilities
|
**cfg.to_dict(),
|
||||||
|
capabilities=capabilities,
|
||||||
|
env_capabilities=env_capabilities,
|
||||||
).model_dump(exclude_none=True)
|
).model_dump(exclude_none=True)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return DictDefault(
|
return DictDefault(
|
||||||
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import os
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
Field,
|
Field,
|
||||||
@@ -21,7 +22,7 @@ from transformers import SchedulerType
|
|||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
from transformers.utils.import_utils import is_torch_npu_available
|
from transformers.utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
from axolotl.utils.config.models.internals import GPUCapabilities
|
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||||
|
|
||||||
@@ -1477,6 +1478,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||||
|
|
||||||
capabilities: GPUCapabilities
|
capabilities: GPUCapabilities
|
||||||
|
env_capabilities: EnvCapabilities
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_bf16(self):
|
def check_bf16(self):
|
||||||
@@ -1551,3 +1553,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
|
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_adopt_torch_version(cls, data):
|
||||||
|
if (data.get("optimizer") is not None) and ("adopt" in data.get("optimizer")):
|
||||||
|
env_capabilities = data.get("env_capabilities", {})
|
||||||
|
torch_version = env_capabilities.get("torch_version")
|
||||||
|
|
||||||
|
if torch_version is None:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||||
|
|
||||||
|
if version.parse(torch_version) < version.parse("2.5.1"):
|
||||||
|
raise ValueError(
|
||||||
|
"ADOPT optimizer is incompatible with torch version < 2.5.1"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|||||||
@@ -12,3 +12,9 @@ class GPUCapabilities(BaseModel):
|
|||||||
n_gpu: int = Field(default=1)
|
n_gpu: int = Field(default=1)
|
||||||
n_node: int = Field(default=1)
|
n_node: int = Field(default=1)
|
||||||
compute_capability: Optional[str] = Field(default=None)
|
compute_capability: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class EnvCapabilities(BaseModel):
|
||||||
|
"""model to manage the environment capabilities statically"""
|
||||||
|
|
||||||
|
torch_version: Optional[str] = Field(default=None)
|
||||||
|
|||||||
@@ -6,21 +6,29 @@ Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeo
|
|||||||
"""
|
"""
|
||||||
# mypy: ignore-errors
|
# mypy: ignore-errors
|
||||||
# pylint: skip-file
|
# pylint: skip-file
|
||||||
|
# flake8: noqa
|
||||||
# mypy: allow-untyped-decorators
|
# mypy: allow-untyped-decorators
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
from typing import List, Optional, Tuple, Union, cast
|
from typing import Callable, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim.optimizer import (
|
from torch.optim.optimizer import ( # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling,
|
||||||
|
DeviceDict,
|
||||||
Optimizer,
|
Optimizer,
|
||||||
ParamsT,
|
ParamsT,
|
||||||
|
_capturable_doc,
|
||||||
_default_to_fused_or_foreach,
|
_default_to_fused_or_foreach,
|
||||||
_device_dtype_check_for_fused,
|
_device_dtype_check_for_fused,
|
||||||
|
_differentiable_doc,
|
||||||
_disable_dynamo_if_unsupported,
|
_disable_dynamo_if_unsupported,
|
||||||
|
_foreach_doc,
|
||||||
|
_fused_doc,
|
||||||
_get_capturable_supported_devices,
|
_get_capturable_supported_devices,
|
||||||
_get_scalar_dtype,
|
_get_scalar_dtype,
|
||||||
_get_value,
|
_get_value,
|
||||||
|
_maximize_doc,
|
||||||
|
_stack_if_compiling,
|
||||||
_use_grad_for_differentiable,
|
_use_grad_for_differentiable,
|
||||||
_view_as_real,
|
_view_as_real,
|
||||||
)
|
)
|
||||||
@@ -35,8 +43,9 @@ class ADOPT(Optimizer):
|
|||||||
lr: Union[float, Tensor] = 1e-3,
|
lr: Union[float, Tensor] = 1e-3,
|
||||||
betas: Tuple[float, float] = (0.9, 0.9999),
|
betas: Tuple[float, float] = (0.9, 0.9999),
|
||||||
eps: float = 1e-6,
|
eps: float = 1e-6,
|
||||||
|
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
|
||||||
weight_decay: float = 0.0,
|
weight_decay: float = 0.0,
|
||||||
decoupled: bool = False,
|
decouple: bool = False,
|
||||||
*,
|
*,
|
||||||
foreach: Optional[bool] = None,
|
foreach: Optional[bool] = None,
|
||||||
maximize: bool = False,
|
maximize: bool = False,
|
||||||
@@ -62,12 +71,14 @@ class ADOPT(Optimizer):
|
|||||||
if not 0.0 <= weight_decay:
|
if not 0.0 <= weight_decay:
|
||||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||||
|
|
||||||
|
self.clip_lambda = clip_lambda
|
||||||
|
|
||||||
defaults = dict(
|
defaults = dict(
|
||||||
lr=lr,
|
lr=lr,
|
||||||
betas=betas,
|
betas=betas,
|
||||||
eps=eps,
|
eps=eps,
|
||||||
weight_decay=weight_decay,
|
weight_decay=weight_decay,
|
||||||
decoupled=decoupled,
|
decouple=decouple,
|
||||||
maximize=maximize,
|
maximize=maximize,
|
||||||
foreach=foreach,
|
foreach=foreach,
|
||||||
capturable=capturable,
|
capturable=capturable,
|
||||||
@@ -219,8 +230,9 @@ class ADOPT(Optimizer):
|
|||||||
beta1=beta1,
|
beta1=beta1,
|
||||||
beta2=beta2,
|
beta2=beta2,
|
||||||
lr=group["lr"],
|
lr=group["lr"],
|
||||||
|
clip_lambda=self.clip_lambda,
|
||||||
weight_decay=group["weight_decay"],
|
weight_decay=group["weight_decay"],
|
||||||
decoupled=group["decoupled"],
|
decouple=group["decouple"],
|
||||||
eps=group["eps"],
|
eps=group["eps"],
|
||||||
maximize=group["maximize"],
|
maximize=group["maximize"],
|
||||||
foreach=group["foreach"],
|
foreach=group["foreach"],
|
||||||
@@ -247,8 +259,9 @@ def _single_tensor_adopt(
|
|||||||
beta1: float,
|
beta1: float,
|
||||||
beta2: float,
|
beta2: float,
|
||||||
lr: Union[float, Tensor],
|
lr: Union[float, Tensor],
|
||||||
|
clip_lambda: Optional[Callable[[int], float]],
|
||||||
weight_decay: float,
|
weight_decay: float,
|
||||||
decoupled: bool,
|
decouple: bool,
|
||||||
eps: float,
|
eps: float,
|
||||||
maximize: bool,
|
maximize: bool,
|
||||||
capturable: bool,
|
capturable: bool,
|
||||||
@@ -276,14 +289,10 @@ def _single_tensor_adopt(
|
|||||||
and param.device.type in capturable_supported_devices
|
and param.device.type in capturable_supported_devices
|
||||||
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||||
|
|
||||||
# update step
|
step = step_t if capturable or differentiable else _get_value(step_t)
|
||||||
step_t += 1
|
|
||||||
|
|
||||||
if weight_decay != 0:
|
if weight_decay != 0 and not decouple:
|
||||||
if decoupled:
|
grad = grad.add(param, alpha=weight_decay)
|
||||||
param.add_(param, alpha=-lr * weight_decay)
|
|
||||||
else:
|
|
||||||
grad = grad.add(param, alpha=weight_decay)
|
|
||||||
|
|
||||||
if torch.is_complex(param):
|
if torch.is_complex(param):
|
||||||
grad = torch.view_as_real(grad)
|
grad = torch.view_as_real(grad)
|
||||||
@@ -293,20 +302,29 @@ def _single_tensor_adopt(
|
|||||||
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
||||||
param = torch.view_as_real(param)
|
param = torch.view_as_real(param)
|
||||||
|
|
||||||
step = step_t if capturable or differentiable else _get_value(step_t)
|
if step == 0:
|
||||||
if step == 1:
|
|
||||||
exp_avg_sq.addcmul_(grad, grad.conj())
|
exp_avg_sq.addcmul_(grad, grad.conj())
|
||||||
|
# update step
|
||||||
|
step_t += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if weight_decay != 0 and decouple:
|
||||||
|
param.add_(param, alpha=-lr * weight_decay)
|
||||||
|
|
||||||
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
||||||
if step == 2:
|
normed_grad = grad.div(denom)
|
||||||
exp_avg.addcdiv_(grad, denom)
|
if clip_lambda is not None:
|
||||||
else:
|
clip = clip_lambda(step)
|
||||||
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
|
normed_grad.clamp_(-clip, clip)
|
||||||
|
|
||||||
|
exp_avg.lerp_(normed_grad, 1 - beta1)
|
||||||
|
|
||||||
param.add_(exp_avg, alpha=-lr)
|
param.add_(exp_avg, alpha=-lr)
|
||||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
||||||
|
|
||||||
|
# update step
|
||||||
|
step_t += 1
|
||||||
|
|
||||||
|
|
||||||
def _multi_tensor_adopt(
|
def _multi_tensor_adopt(
|
||||||
params: List[Tensor],
|
params: List[Tensor],
|
||||||
@@ -321,8 +339,9 @@ def _multi_tensor_adopt(
|
|||||||
beta1: float,
|
beta1: float,
|
||||||
beta2: float,
|
beta2: float,
|
||||||
lr: Union[float, Tensor],
|
lr: Union[float, Tensor],
|
||||||
|
clip_lambda: Optional[Callable[[int], float]],
|
||||||
weight_decay: float,
|
weight_decay: float,
|
||||||
decoupled: bool,
|
decouple: bool,
|
||||||
eps: float,
|
eps: float,
|
||||||
maximize: bool,
|
maximize: bool,
|
||||||
capturable: bool,
|
capturable: bool,
|
||||||
@@ -376,6 +395,51 @@ def _multi_tensor_adopt(
|
|||||||
if maximize:
|
if maximize:
|
||||||
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
||||||
|
|
||||||
|
if weight_decay != 0 and not decouple:
|
||||||
|
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||||
|
if maximize:
|
||||||
|
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||||
|
else:
|
||||||
|
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||||
|
device_grads, device_params, alpha=weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
if device_state_steps[0] == 0:
|
||||||
|
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
||||||
|
|
||||||
|
# Update steps
|
||||||
|
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||||
|
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||||
|
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||||
|
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||||
|
torch._foreach_add_(
|
||||||
|
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
torch._foreach_add_(device_state_steps, 1)
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
|
if weight_decay != 0 and decouple:
|
||||||
|
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
|
||||||
|
|
||||||
|
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||||
|
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
|
||||||
|
|
||||||
|
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
|
||||||
|
if clip_lambda is not None:
|
||||||
|
clip = clip_lambda(device_state_steps[0])
|
||||||
|
torch._foreach_maximum_(normed_grad, -clip)
|
||||||
|
torch._foreach_minimum_(normed_grad, clip)
|
||||||
|
|
||||||
|
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
|
||||||
|
|
||||||
|
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
||||||
|
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||||
|
torch._foreach_addcmul_(
|
||||||
|
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
||||||
|
)
|
||||||
|
|
||||||
# Update steps
|
# Update steps
|
||||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||||
@@ -387,41 +451,6 @@ def _multi_tensor_adopt(
|
|||||||
else:
|
else:
|
||||||
torch._foreach_add_(device_state_steps, 1)
|
torch._foreach_add_(device_state_steps, 1)
|
||||||
|
|
||||||
if weight_decay != 0:
|
|
||||||
if decoupled:
|
|
||||||
torch._foreach_add_(
|
|
||||||
device_params, device_params, alpha=-lr * weight_decay
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
|
||||||
if maximize:
|
|
||||||
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
|
||||||
else:
|
|
||||||
device_grads = torch._foreach_add( # type: ignore[assignment]
|
|
||||||
device_grads, device_params, alpha=weight_decay
|
|
||||||
)
|
|
||||||
|
|
||||||
if device_state_steps[0] == 1:
|
|
||||||
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
|
||||||
continue
|
|
||||||
|
|
||||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
|
||||||
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
|
|
||||||
|
|
||||||
if device_state_steps[0] == 2:
|
|
||||||
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
|
|
||||||
else:
|
|
||||||
torch._foreach_mul_(device_exp_avgs, beta1)
|
|
||||||
torch._foreach_addcdiv_(
|
|
||||||
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
|
|
||||||
)
|
|
||||||
|
|
||||||
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
|
||||||
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
|
||||||
torch._foreach_addcmul_(
|
|
||||||
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
|
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
|
||||||
def adopt(
|
def adopt(
|
||||||
@@ -443,8 +472,9 @@ def adopt(
|
|||||||
beta1: float,
|
beta1: float,
|
||||||
beta2: float,
|
beta2: float,
|
||||||
lr: Union[float, Tensor],
|
lr: Union[float, Tensor],
|
||||||
|
clip_lambda: Optional[Callable[[int], float]],
|
||||||
weight_decay: float,
|
weight_decay: float,
|
||||||
decoupled: bool,
|
decouple: bool,
|
||||||
eps: float,
|
eps: float,
|
||||||
maximize: bool,
|
maximize: bool,
|
||||||
):
|
):
|
||||||
@@ -497,8 +527,9 @@ def adopt(
|
|||||||
beta1=beta1,
|
beta1=beta1,
|
||||||
beta2=beta2,
|
beta2=beta2,
|
||||||
lr=lr,
|
lr=lr,
|
||||||
|
clip_lambda=clip_lambda,
|
||||||
weight_decay=weight_decay,
|
weight_decay=weight_decay,
|
||||||
decoupled=decoupled,
|
decouple=decouple,
|
||||||
eps=eps,
|
eps=eps,
|
||||||
maximize=maximize,
|
maximize=maximize,
|
||||||
capturable=capturable,
|
capturable=capturable,
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ def require_torch_2_3_1(test_case):
|
|||||||
|
|
||||||
def require_torch_2_5_1(test_case):
|
def require_torch_2_5_1(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires torch >= 2.3.1
|
Decorator marking a test that requires torch >= 2.5.1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def is_min_2_5_1():
|
def is_min_2_5_1():
|
||||||
|
|||||||
@@ -672,6 +672,9 @@ class TestValidation(BaseValidation):
|
|||||||
{
|
{
|
||||||
"bf16": True,
|
"bf16": True,
|
||||||
"capabilities": {"bf16": False},
|
"capabilities": {"bf16": False},
|
||||||
|
"env_capabilities": {
|
||||||
|
"torch_version": "2.5.1",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -1160,6 +1163,38 @@ class TestValidation(BaseValidation):
|
|||||||
in self._caplog.records[0].message
|
in self._caplog.records[0].message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_torch_version_adopt_req(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"optimizer": "adopt_adamw",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*ADOPT optimizer is incompatible with torch version*",
|
||||||
|
):
|
||||||
|
env_capabilities = {"torch_version": "2.3.0"}
|
||||||
|
capabilities = {"bf16": False}
|
||||||
|
_ = validate_config(
|
||||||
|
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||||
|
)
|
||||||
|
|
||||||
|
env_capabilities = {"torch_version": "2.5.1"}
|
||||||
|
capabilities = {"bf16": False}
|
||||||
|
_ = validate_config(
|
||||||
|
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||||
|
)
|
||||||
|
|
||||||
|
env_capabilities = {"torch_version": "2.5.2"}
|
||||||
|
capabilities = {"bf16": False}
|
||||||
|
_ = validate_config(
|
||||||
|
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestValidationCheckModelConfig(BaseValidation):
|
class TestValidationCheckModelConfig(BaseValidation):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -72,6 +72,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
"n_gpu": 1,
|
"n_gpu": 1,
|
||||||
"compute_capability": "8.0",
|
"compute_capability": "8.0",
|
||||||
},
|
},
|
||||||
|
env_capabilities={
|
||||||
|
"torch_version": "2.5.1",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
_check_config()
|
_check_config()
|
||||||
@@ -124,6 +127,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
"n_gpu": 1,
|
"n_gpu": 1,
|
||||||
"compute_capability": "8.0",
|
"compute_capability": "8.0",
|
||||||
},
|
},
|
||||||
|
env_capabilities={
|
||||||
|
"torch_version": "2.5.1",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
_check_config()
|
_check_config()
|
||||||
@@ -177,6 +183,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
"n_gpu": 1,
|
"n_gpu": 1,
|
||||||
"compute_capability": "8.0",
|
"compute_capability": "8.0",
|
||||||
},
|
},
|
||||||
|
env_capabilities={
|
||||||
|
"torch_version": "2.5.1",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
_check_config()
|
_check_config()
|
||||||
@@ -231,6 +240,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
"n_gpu": 1,
|
"n_gpu": 1,
|
||||||
"compute_capability": "8.0",
|
"compute_capability": "8.0",
|
||||||
},
|
},
|
||||||
|
env_capabilities={
|
||||||
|
"torch_version": "2.5.1",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
_check_config()
|
_check_config()
|
||||||
|
|||||||
Reference in New Issue
Block a user