upstream updates for momentum change

Author: Wing Lian
Date: 2025-03-24 03:39:42 -04:00
Committed by: Wing Lian
Parent: 64fe284765
Commit: 76d26366ad

@@ -21,7 +21,7 @@ class SOAP(optim.Optimizer):
         betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
             Adam's betas parameters (b1, b2).
         shampoo_beta (`float`, *optional*, defaults to -1):
-            If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
+            If >= 0, use this beta for the preconditioner (L and R in paper, state["GG"] below) moving average instead of betas[1].
         eps (`float`, *optional*, defaults to 1e-08):
             Adam's epsilon for numerical stability.
         weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
@@ -107,14 +107,17 @@ class SOAP(optim.Optimizer):
         return new_grad

     @torch.no_grad()
-    def step(self):
+    def step(self, closure=None):
         """
         Performs a single optimization step.

         Arguments:
             closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
         """
-        loss = None
+        if closure is None:
+            loss = None
+        else:
+            loss = closure()

         for group in self.param_groups:
             for p in group["params"]:
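Note: the new `closure` argument follows the standard `torch.optim.Optimizer.step` convention, so SOAP can now be driven by training loops that re-evaluate the loss inside the step. A minimal usage sketch, assuming a hypothetical `model`, `loss_fn`, `inputs`, and `targets` (none of which appear in this diff):

    def closure():
        # Re-evaluate the model and return the loss, as the docstring describes.
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        return loss

    optimizer.step(closure)  # with closure=None the behavior is unchanged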
@@ -158,7 +161,7 @@ class SOAP(optim.Optimizer):
                    continue # first step is skipped so that we never use the current gradients in the projection.

                # Projecting gradients to the eigenbases of Shampoo's preconditioner
-               # i.e. projecting to the eigenbases of matrices in state['GG']
+               # i.e. projecting to the eigenbases of matrices in state["GG"]
                grad_projected = self.project(
                    grad,
                    state,
@@ -173,7 +176,7 @@ class SOAP(optim.Optimizer):

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
-               exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
+               exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).add_(
                    grad_projected.square(), alpha=(1.0 - beta2)
                )
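With this change both Adam moments accumulate the projected gradient, so `exp_avg` and `exp_avg_sq` now live in the same eigenbasis. A self-contained sketch of the two in-place EMA updates (illustrative shapes only; `grad_projected` stands in for the output of `self.project`):

    import torch

    beta1, beta2 = 0.95, 0.95
    grad_projected = torch.randn(4, 4)  # gradient already rotated into the eigenbasis
    exp_avg = torch.zeros_like(grad_projected)
    exp_avg_sq = torch.zeros_like(grad_projected)

    # Same in-place EMA updates as in the diff: both moments track grad_projected.
    exp_avg.mul_(beta1).add_(grad_projected, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=1.0 - beta2)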
@@ -181,13 +184,14 @@ class SOAP(optim.Optimizer):
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-               # i.e. projecting to the eigenbases of matrices in state['GG']
-               exp_avg_projected = self.project(
-                   exp_avg,
-                   state,
-                   merge_dims=group["merge_dims"],
-                   max_precond_dim=group["max_precond_dim"],
-               )
+               # i.e. projecting to the eigenbases of matrices in state["GG"]
+               # exp_avg_projected = self.project(
+               #     exp_avg,
+               #     state,
+               #     merge_dims=group["merge_dims"],
+               #     max_precond_dim=group["max_precond_dim"],
+               # )
+               exp_avg_projected = exp_avg

                step_size = group["lr"]
                if group["correct_bias"]:
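Skipping the per-step projection is sound because projection is linear: while `Q` is held fixed, an EMA of projected gradients equals the projection of the EMA. A quick numerical check of that identity, using a single random orthogonal `Q` in place of SOAP's per-dimension bases (an assumption made for brevity):

    import torch

    Q, _ = torch.linalg.qr(torch.randn(8, 8))  # random orthogonal basis
    beta1 = 0.95
    ema_plain = torch.zeros(8, 8)
    ema_projected = torch.zeros(8, 8)
    for _ in range(10):
        g = torch.randn(8, 8)
        ema_plain.mul_(beta1).add_(g, alpha=1.0 - beta1)
        ema_projected.mul_(beta1).add_(Q.T @ g @ Q, alpha=1.0 - beta1)

    # Projecting the EMA afterwards matches accumulating projected gradients.
    assert torch.allclose(Q.T @ ema_plain @ Q, ema_projected, atol=1e-5)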
@@ -307,6 +311,13 @@ class SOAP(optim.Optimizer):
        """
        Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
        """
+       if state["Q"] is not None:
+           state["exp_avg"] = self.project_back(
+               state["exp_avg"],
+               state,
+               merge_dims=merge_dims,
+               max_precond_dim=max_precond_dim,
+           )
        if grad.dim() == 1:
            if precondition_1d and grad.shape[0] <= max_precond_dim:
                state["GG"][0].lerp_(
@@ -348,6 +359,15 @@ class SOAP(optim.Optimizer):
            state["Q"] = self.get_orthogonal_matrix_QR(
                state, max_precond_dim, merge_dims
            )
+           # state["Q"] = self.get_fast_QR(state, max_precond_dim, merge_dims)
+           if state["step"] > 0:
+               state["exp_avg"] = self.project(
+                   state["exp_avg"],
+                   state,
+                   merge_dims=merge_dims,
+                   max_precond_dim=max_precond_dim,
+               )

    def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """