feat: add support and end-to-end tests for multiple custom optimizers… (#3457) [skip ci]

* feat: add support and end-to-end tests for multiple custom optimizers including Optimi AdamW, ADOPT AdamW, Muon, Dion, Schedule-Free AdamW, CAME PyTorch, and Flash AdamW.

* feat: Add standalone flashoptim integration test and E2E tests for various custom optimizers including FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, FlashLion, optimi_adamw, adopt_adamw, muon, dion, and schedule_free_adamw.

* feat: introduce Pydantic schema validation for dataset, attention, and training configurations.

* feat: add e2e tests for custom optimizers including optimi_adamw, adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.

* test: add e2e tests for custom optimizers including optimi_adamw, adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.

* test: fix assertion in flash optimizers test to compare class names directly

* fix: address PR review - reuse require_torch_2_7_0 decorator, remove fsdp_config.version check, extract shared FSDP version helper, remove unused imports and optim_args

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Avaya Aggarwal
2026-03-20 17:54:44 +05:30
committed by GitHub
parent 5a5cf30b26
commit 1bcfc08c90
4 changed files with 115 additions and 3 deletions

View File

@@ -353,6 +353,30 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adamw":
from flashoptim import FlashAdamW
optimizer_cls = FlashAdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adam":
from flashoptim import FlashAdam
optimizer_cls = FlashAdam
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_sgd":
from flashoptim import FlashSGD
optimizer_cls = FlashSGD
elif self.cfg.optimizer == "flash_sgdw":
from flashoptim import FlashSGDW
optimizer_cls = FlashSGDW
elif self.cfg.optimizer == "flash_lion":
from flashoptim import FlashLion
optimizer_cls = FlashLion
if "betas" in adam_kwargs:
optimizer_kwargs["betas"] = adam_kwargs["betas"]
else:
raise ValueError(
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."

View File

@@ -87,6 +87,11 @@ class CustomSupportedOptimizers(str, Enum):
came_pytorch = "came_pytorch"
muon = "muon"
dion = "dion"
flash_adamw = "flash_adamw"
flash_adam = "flash_adam"
flash_sgd = "flash_sgd"
flash_sgdw = "flash_sgdw"
flash_lion = "flash_lion"
class RingAttnFunc(str, Enum):

View File

@@ -790,6 +790,14 @@ class OptimizationValidationMixin:
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@staticmethod
def _resolve_fsdp_version(data):
"""Resolve FSDP version from top-level fsdp_version or fsdp_config.fsdp_version."""
fsdp_version = data.get("fsdp_version")
if fsdp_version is None:
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
return fsdp_version
@model_validator(mode="before")
@classmethod
def check_muon_deepspeed_fsdp(cls, data):
@@ -799,15 +807,32 @@ class OptimizationValidationMixin:
"Muon optimizer is currently incompatible with DeepSpeed"
)
if data.get("fsdp") or data.get("fsdp_config"):
fsdp_version = data.get("fsdp_version")
if fsdp_version is None:
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
fsdp_version = cls._resolve_fsdp_version(data)
if str(fsdp_version) != "2":
raise ValueError(
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
)
return data
@model_validator(mode="before")
@classmethod
def check_flashoptim_deepspeed_fsdp(cls, data):
optimizer = data.get("optimizer") or ""
if str(optimizer).startswith("flash_"):
if data.get("deepspeed"):
raise ValueError(
f"{optimizer} optimizer is incompatible with DeepSpeed. "
"Flash optimizers only support DDP and FSDP2."
)
if data.get("fsdp") or data.get("fsdp_config"):
fsdp_version = cls._resolve_fsdp_version(data)
if str(fsdp_version) != "2":
raise ValueError(
f"{optimizer} optimizer is only compatible with FSDP2. "
"Set fsdp_version: 2 to use flash optimizers with FSDP."
)
return data
@model_validator(mode="before")
@classmethod
def check_batch_flattening_fa(cls, data):