Remove extra torch.compile call (#2904)
* debug
* debug
* debug
* moving validation code to transformers
* revert unneeded change
* add accelerator config to base trainer builder
* add back accumulated_cache_size_limit setting
* lint
This commit is contained in:
@@ -418,6 +418,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||
True
|
||||
)
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||
if self.cfg.torch_compile_backend:
|
||||
training_args_kwargs["torch_compile_backend"] = (
|
||||
@@ -426,6 +429,10 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.torch_compile_mode:
|
||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||
|
||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||
if self.cfg.accelerator_config:
|
||||
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
@@ -510,5 +517,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
self._configure_scheduler(training_args_kwargs)
|
||||
self._configure_optimizer(training_args_kwargs, trainer_kwargs)
|
||||
self._configure_torch_compile(training_args_kwargs)
|
||||
self._configure_accelerator_config(training_args_kwargs)
|
||||
|
||||
return training_args_kwargs, trainer_kwargs
|
||||
|
||||
@@ -310,11 +310,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.neftune_noise_alpha
|
||||
)
|
||||
|
||||
if self.cfg.accelerator_config:
|
||||
training_arguments_kwargs["accelerator_config"] = (
|
||||
self.cfg.accelerator_config
|
||||
)
|
||||
|
||||
if self.cfg.image_size:
|
||||
training_arguments_kwargs["image_size"] = self.cfg.image_size
|
||||
if self.cfg.image_resize_algorithm:
|
||||
|
||||
@@ -75,18 +75,6 @@ class AxolotlTrainer(
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
model = torch.compile(
|
||||
model,
|
||||
backend=self.args.torch_compile_backend,
|
||||
mode=self.args.torch_compile_mode,
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def _create_multipack_sampler(
|
||||
self, base_sampler: Sampler, dataset: Dataset
|
||||
) -> MultipackBatchSampler:
|
||||
|
||||
Reference in New Issue
Block a user