Compare commits

1 Commit

main ... optimizer-

| Author | SHA1 | Date |
|---|---|---|
|  | a6056e35de |  |
@@ -422,6 +422,9 @@ class TrainerBuilderBase(abc.ABC):
         if self.cfg.torch_compile_mode:
             training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
 
+        if self.cfg.compile_optimizer:
+            training_args_kwargs["compile_optimizer"] = True
+
     def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
         if self.cfg.gradient_checkpointing:
             training_args_kwargs["gradient_checkpointing"] = (
@@ -1,5 +1,6 @@
 """Module for Axolotl trainer optimizer mixin"""
 
+import torch
 from peft.optimizers import create_loraplus_optimizer
 from torch import nn
 from transformers.trainer import Trainer
@@ -185,12 +186,12 @@ class OptimizerMixin(Trainer):
                                 p.data_ptr(): p.numel() for p in module.parameters()
                             }.values()
                         )
-                        LOG.info(f"skipped {module}: {skipped/2**20}M params")
+                        LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
                         manager.register_module_override(
                             module, "weight", {"optim_bits": 32}
                         )
                         LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
-                LOG.info(f"skipped: {skipped/2**20}M params")
+                LOG.info(f"skipped: {skipped / 2 ** 20}M params")
 
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
@@ -199,6 +200,11 @@ class OptimizerMixin(Trainer):
 
         return self.optimizer
 
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
+        super().create_optimizer_and_scheduler(num_training_steps)
+        if self.args.compile_optimizer:
+            self.optimizer.step = torch.compile(self.optimizer.step)
+
 
 class OptimizerInitMixin:
     """
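For context, the core of the change above is wrapping the optimizer's bound `step` method with `torch.compile`. Below is a minimal standalone sketch of the same pattern on a toy model; it is not part of this diff, and the model, optimizer, and shapes are made up for illustration (requires PyTorch 2.x):

```python
import torch
from torch import nn

# Toy model and optimizer; any torch.optim optimizer exposes .step the same way.
model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Same pattern as the PR: replace the bound step method with a compiled wrapper.
# The first call triggers compilation; later calls reuse the compiled version.
optimizer.step = torch.compile(optimizer.step)

for _ in range(3):
    loss = model(torch.randn(4, 16)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```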
@@ -141,6 +141,10 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "absolute learning rate for the embedding layers."},
     )
+    compile_optimizer: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Whether to compile the optimizer for faster training."},
+    )
     qlora: bool = field(
         default=False,
         metadata={"help": "whether this is a qlora training"},
@@ -275,6 +275,7 @@ class AxolotlInputConfig(
     torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = (
         None
     )
+    compile_optimizer: bool | None = None
 
     max_steps: int | None = None
     warmup_steps: int | None = None
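With the input-config field in place, the option would presumably be switched on next to `torch_compile_mode` in the training config (e.g. `compile_optimizer: true` in the YAML, given how `AxolotlInputConfig` fields map to config keys). A rough way to sanity-check whether compiling the step helps on a given setup is to time a few eager vs. compiled optimizer steps. The sketch below is illustrative only: the sizes are toy, the helper name is made up, and the speedup is mainly expected on GPU, where `torch.compile` can fuse the optimizer's per-parameter update loop.

```python
import time

import torch
from torch import nn


def time_optimizer_steps(compile_step: bool, iters: int = 50) -> float:
    """Time `iters` training steps, optionally with a compiled optimizer.step."""
    model = nn.Linear(256, 256)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    if compile_step:
        opt.step = torch.compile(opt.step)
    data = torch.randn(32, 256)

    def step():
        model(data).sum().backward()
        opt.step()
        opt.zero_grad()

    for _ in range(5):  # warm-up so compile time is not counted
        step()
    start = time.perf_counter()
    for _ in range(iters):
        step()
    return time.perf_counter() - start


print(f"eager:    {time_optimizer_steps(False):.3f}s")
print(f"compiled: {time_optimizer_steps(True):.3f}s")
```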