Compare commits
2 Commits
sdpa-cp
...
optimizer-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a6056e35de | ||
|
|
09c685fd2c |
@@ -422,6 +422,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.torch_compile_mode:
|
||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||
|
||||
if self.cfg.compile_optimizer:
|
||||
training_args_kwargs["compile_optimizer"] = True
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
||||
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
@@ -58,6 +59,42 @@ class AxolotlGRPOTrainer(
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
def get_train_dataloader(self):
|
||||
if self.train_dataset is None:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
|
||||
train_dataset = self.train_dataset
|
||||
data_collator = self.data_collator
|
||||
if isinstance(train_dataset, datasets.Dataset):
|
||||
train_dataset = self._remove_unused_columns(
|
||||
train_dataset, description="training"
|
||||
)
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(
|
||||
data_collator, description="training"
|
||||
)
|
||||
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size
|
||||
* self.args.steps_per_generation, # < this is the change
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
"persistent_workers": self.args.dataloader_persistent_workers,
|
||||
}
|
||||
|
||||
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_train_sampler()
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = partial(
|
||||
seed_worker,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
rank=self.args.process_index,
|
||||
)
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
||||
|
||||
|
||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Module for Axolotl trainer optimizer mixin"""
|
||||
|
||||
import torch
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from transformers.trainer import Trainer
|
||||
@@ -185,12 +186,12 @@ class OptimizerMixin(Trainer):
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
LOG.info(f"skipped: {skipped / 2 ** 20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
@@ -199,6 +200,11 @@ class OptimizerMixin(Trainer):
|
||||
|
||||
return self.optimizer
|
||||
|
||||
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
||||
super().create_optimizer_and_scheduler(num_training_steps)
|
||||
if self.args.compile_optimizer:
|
||||
self.optimizer.step = torch.compile(self.optimizer.step)
|
||||
|
||||
|
||||
class OptimizerInitMixin:
|
||||
"""
|
||||
|
||||
@@ -141,6 +141,10 @@ class AxolotlTrainingMixins:
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
compile_optimizer: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Whether to compile the optimizer for faster training."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
|
||||
@@ -275,6 +275,7 @@ class AxolotlInputConfig(
|
||||
torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = (
|
||||
None
|
||||
)
|
||||
compile_optimizer: bool | None = None
|
||||
|
||||
max_steps: int | None = None
|
||||
warmup_steps: int | None = None
|
||||
|
||||
Reference in New Issue
Block a user