feat: moved torch compile to base and refactor collator setting

This commit is contained in:
NanoCode012
2025-05-22 17:56:58 +07:00
parent 0fc6499461
commit e55d64f709
2 changed files with 27 additions and 27 deletions

View File

@@ -20,9 +20,11 @@ import importlib.util
import logging import logging
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from contextlib import suppress
from pathlib import Path from pathlib import Path
from typing import Any, Dict from typing import Any, Dict
import torch
from transformers import ( from transformers import (
TrainerCallback, TrainerCallback,
) )
@@ -40,6 +42,9 @@ from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
with suppress(ImportError):
import torch._dynamo # pylint: disable=ungrouped-imports
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
"""Base class for trainer builder.""" """Base class for trainer builder."""
@@ -452,4 +457,17 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.optim_target_modules: if self.cfg.optim_target_modules:
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
# torch compile
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = (
self.cfg.torch_compile_backend
)
if self.cfg.torch_compile_mode:
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
return training_args_kwargs return training_args_kwargs

View File

@@ -7,7 +7,6 @@ import os
from pathlib import Path from pathlib import Path
from typing import Type, Union from typing import Type, Union
import torch
import transformers import transformers
from transformers import ( from transformers import (
DataCollatorWithFlattening, DataCollatorWithFlattening,
@@ -53,11 +52,6 @@ from axolotl.utils.collators import (
) )
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
pass
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -195,20 +189,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.greater_is_better: if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_arguments_kwargs["torch_compile_backend"] = (
self.cfg.torch_compile_backend
)
if self.cfg.torch_compile_mode:
training_arguments_kwargs["torch_compile_mode"] = (
self.cfg.torch_compile_mode
)
# DDP Config # DDP Config
if self.cfg.ddp_timeout: if self.cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
@@ -464,13 +444,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model: if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator: elif use_batch_sampler_collator:
if self.cfg.flex_attention: # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
collator = V2BatchSamplerDataCollatorForSeq2Seq # supported multipack models, or non-flash-attention llama
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: if (
collator = V2BatchSamplerDataCollatorForSeq2Seq self.cfg.flex_attention
elif ( or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
self.cfg.model_config_type in ["llama"] or (
and self.cfg.flash_attention is not True self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
)
): ):
collator = V2BatchSamplerDataCollatorForSeq2Seq collator = V2BatchSamplerDataCollatorForSeq2Seq
else: else: