upstream updates for momentum change
This commit is contained in:
@@ -21,7 +21,7 @@ class SOAP(optim.Optimizer):
|
|||||||
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
|
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
|
||||||
Adam's betas parameters (b1, b2).
|
Adam's betas parameters (b1, b2).
|
||||||
shampoo_beta (`float`, *optional*, defaults to -1):
|
shampoo_beta (`float`, *optional*, defaults to -1):
|
||||||
If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
|
If >= 0, use this beta for the preconditioner (L and R in paper, state["GG"] below) moving average instead of betas[1].
|
||||||
eps (`float`, *optional*, defaults to 1e-08):
|
eps (`float`, *optional*, defaults to 1e-08):
|
||||||
Adam's epsilon for numerical stability.
|
Adam's epsilon for numerical stability.
|
||||||
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
|
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
|
||||||
@@ -107,14 +107,17 @@ class SOAP(optim.Optimizer):
|
|||||||
return new_grad
|
return new_grad
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def step(self):
|
def step(self, closure=None):
|
||||||
"""
|
"""
|
||||||
Performs a single optimization step.
|
Performs a single optimization step.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
|
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
|
||||||
"""
|
"""
|
||||||
loss = None
|
if closure is None:
|
||||||
|
loss = None
|
||||||
|
else:
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
for group in self.param_groups:
|
for group in self.param_groups:
|
||||||
for p in group["params"]:
|
for p in group["params"]:
|
||||||
@@ -158,7 +161,7 @@ class SOAP(optim.Optimizer):
|
|||||||
continue # first step is skipped so that we never use the current gradients in the projection.
|
continue # first step is skipped so that we never use the current gradients in the projection.
|
||||||
|
|
||||||
# Projecting gradients to the eigenbases of Shampoo's preconditioner
|
# Projecting gradients to the eigenbases of Shampoo's preconditioner
|
||||||
# i.e. projecting to the eigenbases of matrices in state['GG']
|
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
||||||
grad_projected = self.project(
|
grad_projected = self.project(
|
||||||
grad,
|
grad,
|
||||||
state,
|
state,
|
||||||
@@ -173,7 +176,7 @@ class SOAP(optim.Optimizer):
|
|||||||
|
|
||||||
# Decay the first and second moment running average coefficient
|
# Decay the first and second moment running average coefficient
|
||||||
# In-place operations to update the averages at the same time
|
# In-place operations to update the averages at the same time
|
||||||
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
|
exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
|
||||||
exp_avg_sq.mul_(beta2).add_(
|
exp_avg_sq.mul_(beta2).add_(
|
||||||
grad_projected.square(), alpha=(1.0 - beta2)
|
grad_projected.square(), alpha=(1.0 - beta2)
|
||||||
)
|
)
|
||||||
@@ -181,13 +184,14 @@ class SOAP(optim.Optimizer):
|
|||||||
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||||
|
|
||||||
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
||||||
# i.e. projecting to the eigenbases of matrices in state['GG']
|
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
||||||
exp_avg_projected = self.project(
|
# exp_avg_projected = self.project(
|
||||||
exp_avg,
|
# exp_avg,
|
||||||
state,
|
# state,
|
||||||
merge_dims=group["merge_dims"],
|
# merge_dims=group["merge_dims"],
|
||||||
max_precond_dim=group["max_precond_dim"],
|
# max_precond_dim=group["max_precond_dim"],
|
||||||
)
|
# )
|
||||||
|
exp_avg_projected = exp_avg
|
||||||
|
|
||||||
step_size = group["lr"]
|
step_size = group["lr"]
|
||||||
if group["correct_bias"]:
|
if group["correct_bias"]:
|
||||||
@@ -307,6 +311,13 @@ class SOAP(optim.Optimizer):
|
|||||||
"""
|
"""
|
||||||
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
|
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
|
||||||
"""
|
"""
|
||||||
|
if state["Q"] is not None:
|
||||||
|
state["exp_avg"] = self.project_back(
|
||||||
|
state["exp_avg"],
|
||||||
|
state,
|
||||||
|
merge_dims=merge_dims,
|
||||||
|
max_precond_dim=max_precond_dim,
|
||||||
|
)
|
||||||
if grad.dim() == 1:
|
if grad.dim() == 1:
|
||||||
if precondition_1d and grad.shape[0] <= max_precond_dim:
|
if precondition_1d and grad.shape[0] <= max_precond_dim:
|
||||||
state["GG"][0].lerp_(
|
state["GG"][0].lerp_(
|
||||||
@@ -348,6 +359,15 @@ class SOAP(optim.Optimizer):
|
|||||||
state["Q"] = self.get_orthogonal_matrix_QR(
|
state["Q"] = self.get_orthogonal_matrix_QR(
|
||||||
state, max_precond_dim, merge_dims
|
state, max_precond_dim, merge_dims
|
||||||
)
|
)
|
||||||
|
# state["Q"] = self.get_fast_QR(state, max_precond_dim, merge_dims)
|
||||||
|
|
||||||
|
if state["step"] > 0:
|
||||||
|
state["exp_avg"] = self.project(
|
||||||
|
state["exp_avg"],
|
||||||
|
state,
|
||||||
|
merge_dims=merge_dims,
|
||||||
|
max_precond_dim=max_precond_dim,
|
||||||
|
)
|
||||||
|
|
||||||
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
|
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user