upstream updates for momentum change

This commit is contained in:
Wing Lian
2025-03-24 03:39:42 -04:00
committed by Wing Lian
parent 64fe284765
commit 76d26366ad

View File

@@ -21,7 +21,7 @@ class SOAP(optim.Optimizer):
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`): betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
Adam's betas parameters (b1, b2). Adam's betas parameters (b1, b2).
shampoo_beta (`float`, *optional*, defaults to -1): shampoo_beta (`float`, *optional*, defaults to -1):
If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1]. If >= 0, use this beta for the preconditioner (L and R in paper, state["GG"] below) moving average instead of betas[1].
eps (`float`, *optional*, defaults to 1e-08): eps (`float`, *optional*, defaults to 1e-08):
Adam's epsilon for numerical stability. Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient. weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
@@ -107,14 +107,17 @@ class SOAP(optim.Optimizer):
return new_grad return new_grad
@torch.no_grad() @torch.no_grad()
def step(self): def step(self, closure=None):
""" """
Performs a single optimization step. Performs a single optimization step.
Arguments: Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
""" """
loss = None if closure is None:
loss = None
else:
loss = closure()
for group in self.param_groups: for group in self.param_groups:
for p in group["params"]: for p in group["params"]:
@@ -158,7 +161,7 @@ class SOAP(optim.Optimizer):
continue # first step is skipped so that we never use the current gradients in the projection. continue # first step is skipped so that we never use the current gradients in the projection.
# Projecting gradients to the eigenbases of Shampoo's preconditioner # Projecting gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state['GG'] # i.e. projecting to the eigenbases of matrices in state["GG"]
grad_projected = self.project( grad_projected = self.project(
grad, grad,
state, state,
@@ -173,7 +176,7 @@ class SOAP(optim.Optimizer):
# Decay the first and second moment running average coefficient # Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time # In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).add_( exp_avg_sq.mul_(beta2).add_(
grad_projected.square(), alpha=(1.0 - beta2) grad_projected.square(), alpha=(1.0 - beta2)
) )
@@ -181,13 +184,14 @@ class SOAP(optim.Optimizer):
denom = exp_avg_sq.sqrt().add_(group["eps"]) denom = exp_avg_sq.sqrt().add_(group["eps"])
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state['GG'] # i.e. projecting to the eigenbases of matrices in state["GG"]
exp_avg_projected = self.project( # exp_avg_projected = self.project(
exp_avg, # exp_avg,
state, # state,
merge_dims=group["merge_dims"], # merge_dims=group["merge_dims"],
max_precond_dim=group["max_precond_dim"], # max_precond_dim=group["max_precond_dim"],
) # )
exp_avg_projected = exp_avg
step_size = group["lr"] step_size = group["lr"]
if group["correct_bias"]: if group["correct_bias"]:
@@ -307,6 +311,13 @@ class SOAP(optim.Optimizer):
""" """
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper). Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
""" """
if state["Q"] is not None:
state["exp_avg"] = self.project_back(
state["exp_avg"],
state,
merge_dims=merge_dims,
max_precond_dim=max_precond_dim,
)
if grad.dim() == 1: if grad.dim() == 1:
if precondition_1d and grad.shape[0] <= max_precond_dim: if precondition_1d and grad.shape[0] <= max_precond_dim:
state["GG"][0].lerp_( state["GG"][0].lerp_(
@@ -348,6 +359,15 @@ class SOAP(optim.Optimizer):
state["Q"] = self.get_orthogonal_matrix_QR( state["Q"] = self.get_orthogonal_matrix_QR(
state, max_precond_dim, merge_dims state, max_precond_dim, merge_dims
) )
# state["Q"] = self.get_fast_QR(state, max_precond_dim, merge_dims)
if state["step"] > 0:
state["exp_avg"] = self.project(
state["exp_avg"],
state,
merge_dims=merge_dims,
max_precond_dim=max_precond_dim,
)
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000): def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
""" """