upstream updates for momentum change
This commit is contained in:
@@ -21,7 +21,7 @@ class SOAP(optim.Optimizer):
|
||||
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
|
||||
Adam's betas parameters (b1, b2).
|
||||
shampoo_beta (`float`, *optional*, defaults to -1):
|
||||
If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
|
||||
If >= 0, use this beta for the preconditioner (L and R in paper, state["GG"] below) moving average instead of betas[1].
|
||||
eps (`float`, *optional*, defaults to 1e-08):
|
||||
Adam's epsilon for numerical stability.
|
||||
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
|
||||
@@ -107,14 +107,17 @@ class SOAP(optim.Optimizer):
|
||||
return new_grad
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
def step(self, closure=None):
|
||||
"""
|
||||
Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is None:
|
||||
loss = None
|
||||
else:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
@@ -158,7 +161,7 @@ class SOAP(optim.Optimizer):
|
||||
continue # first step is skipped so that we never use the current gradients in the projection.
|
||||
|
||||
# Projecting gradients to the eigenbases of Shampoo's preconditioner
|
||||
# i.e. projecting to the eigenbases of matrices in state['GG']
|
||||
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
||||
grad_projected = self.project(
|
||||
grad,
|
||||
state,
|
||||
@@ -173,7 +176,7 @@ class SOAP(optim.Optimizer):
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# In-place operations to update the averages at the same time
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
|
||||
exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
|
||||
exp_avg_sq.mul_(beta2).add_(
|
||||
grad_projected.square(), alpha=(1.0 - beta2)
|
||||
)
|
||||
@@ -181,13 +184,14 @@ class SOAP(optim.Optimizer):
|
||||
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||
|
||||
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
||||
# i.e. projecting to the eigenbases of matrices in state['GG']
|
||||
exp_avg_projected = self.project(
|
||||
exp_avg,
|
||||
state,
|
||||
merge_dims=group["merge_dims"],
|
||||
max_precond_dim=group["max_precond_dim"],
|
||||
)
|
||||
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
||||
# exp_avg_projected = self.project(
|
||||
# exp_avg,
|
||||
# state,
|
||||
# merge_dims=group["merge_dims"],
|
||||
# max_precond_dim=group["max_precond_dim"],
|
||||
# )
|
||||
exp_avg_projected = exp_avg
|
||||
|
||||
step_size = group["lr"]
|
||||
if group["correct_bias"]:
|
||||
@@ -307,6 +311,13 @@ class SOAP(optim.Optimizer):
|
||||
"""
|
||||
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
|
||||
"""
|
||||
if state["Q"] is not None:
|
||||
state["exp_avg"] = self.project_back(
|
||||
state["exp_avg"],
|
||||
state,
|
||||
merge_dims=merge_dims,
|
||||
max_precond_dim=max_precond_dim,
|
||||
)
|
||||
if grad.dim() == 1:
|
||||
if precondition_1d and grad.shape[0] <= max_precond_dim:
|
||||
state["GG"][0].lerp_(
|
||||
@@ -348,6 +359,15 @@ class SOAP(optim.Optimizer):
|
||||
state["Q"] = self.get_orthogonal_matrix_QR(
|
||||
state, max_precond_dim, merge_dims
|
||||
)
|
||||
# state["Q"] = self.get_fast_QR(state, max_precond_dim, merge_dims)
|
||||
|
||||
if state["step"] > 0:
|
||||
state["exp_avg"] = self.project(
|
||||
state["exp_avg"],
|
||||
state,
|
||||
merge_dims=merge_dims,
|
||||
max_precond_dim=max_precond_dim,
|
||||
)
|
||||
|
||||
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user