diff --git a/src/axolotl/core/trainer_builder/base.py b/src/axolotl/core/trainer_builder/base.py
index 9f6dd2561..2a14db56e 100644
--- a/src/axolotl/core/trainer_builder/base.py
+++ b/src/axolotl/core/trainer_builder/base.py
@@ -20,9 +20,11 @@ import importlib.util
 import logging
 import sys
 from abc import abstractmethod
+from contextlib import suppress
 from pathlib import Path
 from typing import Any, Dict
 
+import torch
 from transformers import (
     TrainerCallback,
 )
@@ -40,6 +42,9 @@ from axolotl.utils.schemas.enums import CustomSupportedOptimizers
 
 LOG = logging.getLogger(__name__)
 
+with suppress(ImportError):
+    import torch._dynamo  # pylint: disable=ungrouped-imports
+
 
 class TrainerBuilderBase(abc.ABC):
     """Base class for trainer builder."""
@@ -452,4 +457,17 @@ class TrainerBuilderBase(abc.ABC):
         if self.cfg.optim_target_modules:
             training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
 
+        # torch compile
+        if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
+            torch._dynamo.config.suppress_errors = (  # pylint: disable=protected-access
+                True
+            )
+            training_args_kwargs["torch_compile"] = self.cfg.torch_compile
+            if self.cfg.torch_compile_backend:
+                training_args_kwargs["torch_compile_backend"] = (
+                    self.cfg.torch_compile_backend
+                )
+            if self.cfg.torch_compile_mode:
+                training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
+
         return training_args_kwargs
diff --git a/src/axolotl/core/trainer_builder/sft.py b/src/axolotl/core/trainer_builder/sft.py
index 3fea0e4d8..22b4e45fa 100644
--- a/src/axolotl/core/trainer_builder/sft.py
+++ b/src/axolotl/core/trainer_builder/sft.py
@@ -7,7 +7,6 @@ import os
 from pathlib import Path
 from typing import Type, Union
 
-import torch
 import transformers
 from transformers import (
     DataCollatorWithFlattening,
@@ -53,11 +52,6 @@ from axolotl.utils.collators import (
 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
 
-try:
-    import torch._dynamo  # pylint: disable=ungrouped-imports
-except ImportError:
-    pass
-
 LOG = logging.getLogger(__name__)
 
 
@@ -195,20 +189,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.greater_is_better:
             training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
 
-        if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
-            torch._dynamo.config.suppress_errors = (  # pylint: disable=protected-access
-                True
-            )
-            training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
-            if self.cfg.torch_compile_backend:
-                training_arguments_kwargs["torch_compile_backend"] = (
-                    self.cfg.torch_compile_backend
-                )
-            if self.cfg.torch_compile_mode:
-                training_arguments_kwargs["torch_compile_mode"] = (
-                    self.cfg.torch_compile_mode
-                )
-
         # DDP Config
         if self.cfg.ddp_timeout:
             training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
@@ -464,13 +444,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.reward_model:
             collator = RewardDataCollatorWithPadding
         elif use_batch_sampler_collator:
-            if self.cfg.flex_attention:
-                collator = V2BatchSamplerDataCollatorForSeq2Seq
-            elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
-                collator = V2BatchSamplerDataCollatorForSeq2Seq
-            elif (
-                self.cfg.model_config_type in ["llama"]
-                and self.cfg.flash_attention is not True
+            # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
+            # supported multipack models, or non-flash-attention llama
+            if (
+                self.cfg.flex_attention
+                or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
+                or (
+                    self.cfg.model_config_type in ["llama"]
+                    and self.cfg.flash_attention is not True
+                )
             ):
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
             else: