diff --git a/src/axolotl/utils/optimizers/soap/__init__.py b/src/axolotl/utils/optimizers/soap/__init__.py
index 338e35e63..4382d018f 100644
--- a/src/axolotl/utils/optimizers/soap/__init__.py
+++ b/src/axolotl/utils/optimizers/soap/__init__.py
@@ -21,7 +21,7 @@ class SOAP(optim.Optimizer):
         betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
             Adam's betas parameters (b1, b2).
         shampoo_beta (`float`, *optional*, defaults to -1):
-            If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
+            If >= 0, use this beta for the preconditioner (L and R in paper, state["GG"] below) moving average instead of betas[1].
         eps (`float`, *optional*, defaults to 1e-08):
             Adam's epsilon for numerical stability.
         weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
@@ -107,14 +107,17 @@ class SOAP(optim.Optimizer):
         return new_grad

     @torch.no_grad()
-    def step(self):
+    def step(self, closure=None):
         """
         Performs a single optimization step.

         Arguments:
             closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
         """
-        loss = None
+        if closure is None:
+            loss = None
+        else:
+            loss = closure()

         for group in self.param_groups:
             for p in group["params"]:
@@ -158,7 +161,7 @@ class SOAP(optim.Optimizer):
                     continue  # first step is skipped so that we never use the current gradients in the projection.

                 # Projecting gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
+                # i.e. projecting to the eigenbases of matrices in state["GG"]
                 grad_projected = self.project(
                     grad,
                     state,
@@ -173,7 +176,7 @@ class SOAP(optim.Optimizer):

                 # Decay the first and second moment running average coefficient
                 # In-place operations to update the averages at the same time
-                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
+                exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
                 exp_avg_sq.mul_(beta2).add_(
                     grad_projected.square(), alpha=(1.0 - beta2)
                 )
@@ -181,13 +184,14 @@ class SOAP(optim.Optimizer):
                 denom = exp_avg_sq.sqrt().add_(group["eps"])

                 # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
-                # i.e. projecting to the eigenbases of matrices in state['GG']
-                exp_avg_projected = self.project(
-                    exp_avg,
-                    state,
-                    merge_dims=group["merge_dims"],
-                    max_precond_dim=group["max_precond_dim"],
-                )
+                # i.e. projecting to the eigenbases of matrices in state["GG"]
+                # exp_avg_projected = self.project(
+                #     exp_avg,
+                #     state,
+                #     merge_dims=group["merge_dims"],
+                #     max_precond_dim=group["max_precond_dim"],
+                # )
+                exp_avg_projected = exp_avg

                 step_size = group["lr"]
                 if group["correct_bias"]:
@@ -307,6 +311,13 @@ class SOAP(optim.Optimizer):
         """
         Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
         """
+        if state["Q"] is not None:
+            state["exp_avg"] = self.project_back(
+                state["exp_avg"],
+                state,
+                merge_dims=merge_dims,
+                max_precond_dim=max_precond_dim,
+            )
         if grad.dim() == 1:
             if precondition_1d and grad.shape[0] <= max_precond_dim:
                 state["GG"][0].lerp_(
@@ -348,6 +359,15 @@ class SOAP(optim.Optimizer):
             state["Q"] = self.get_orthogonal_matrix_QR(
                 state, max_precond_dim, merge_dims
             )
+            # state["Q"] = self.get_fast_QR(state, max_precond_dim, merge_dims)
+
+            if state["step"] > 0:
+                state["exp_avg"] = self.project(
+                    state["exp_avg"],
+                    state,
+                    merge_dims=merge_dims,
+                    max_precond_dim=max_precond_dim,
+                )

     def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
         """