override the entire create_optimzier method
This commit is contained in:
@@ -43,3 +43,4 @@ trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda50034822
|
|||||||
fastcore>=1.5.29
|
fastcore>=1.5.29
|
||||||
|
|
||||||
lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git@main
|
lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git@main
|
||||||
|
yacs
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import transformers
|
|||||||
from accelerate import FullyShardedDataParallelPlugin
|
from accelerate import FullyShardedDataParallelPlugin
|
||||||
from accelerate.utils import str_to_bool
|
from accelerate.utils import str_to_bool
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from torch import nn
|
||||||
from torch.distributed.fsdp import MixedPrecision
|
from torch.distributed.fsdp import MixedPrecision
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
@@ -270,25 +271,71 @@ class AxolotlTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if self.args.loraplus_lr_ratio is None:
|
|
||||||
return super().create_optimizer()
|
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
|
||||||
self.args,
|
|
||||||
opt_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
optimizer_grouped_parameters = [
|
||||||
opt_model,
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in opt_model.named_parameters()
|
||||||
|
if (n in decay_parameters and p.requires_grad)
|
||||||
|
],
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in opt_model.named_parameters()
|
||||||
|
if (n not in decay_parameters and p.requires_grad)
|
||||||
|
],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
(
|
||||||
optimizer_cls,
|
optimizer_cls,
|
||||||
optimizer_kwargs,
|
optimizer_kwargs,
|
||||||
loraplus_lr_ratio,
|
) = AxolotlTrainer.get_optimizer_cls_and_kwargs(self.args)
|
||||||
loraplus_lr_embedding,
|
|
||||||
)
|
if self.args.loraplus_lr_ratio:
|
||||||
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
loraplus_lr_embedding = getattr(
|
||||||
|
self.args, "loraplus_lr_embedding", None
|
||||||
|
)
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
|
optimizer_cls,
|
||||||
|
optimizer_kwargs,
|
||||||
|
loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
if optimizer_cls.__name__ == "Adam8bit":
|
||||||
|
import bitsandbytes
|
||||||
|
|
||||||
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
|
|
||||||
|
skipped = 0
|
||||||
|
for module in opt_model.modules():
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
skipped += sum(
|
||||||
|
{
|
||||||
|
p.data_ptr(): p.numel() for p in module.parameters()
|
||||||
|
}.values()
|
||||||
|
)
|
||||||
|
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||||
|
manager.register_module_override(
|
||||||
|
module, "weight", {"optim_bits": 32}
|
||||||
|
)
|
||||||
|
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
|
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
|||||||
Reference in New Issue
Block a user