feat: moved torch compile to base and refactor collator setting
This commit is contained in:
@@ -20,9 +20,11 @@ import importlib.util
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from contextlib import suppress
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import (
|
from transformers import (
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
@@ -40,6 +42,9 @@ from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
with suppress(ImportError):
|
||||||
|
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""Base class for trainer builder."""
|
"""Base class for trainer builder."""
|
||||||
@@ -452,4 +457,17 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.optim_target_modules:
|
if self.cfg.optim_target_modules:
|
||||||
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
|
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
|
||||||
|
|
||||||
|
# torch compile
|
||||||
|
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
|
||||||
|
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||||
|
True
|
||||||
|
)
|
||||||
|
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||||
|
if self.cfg.torch_compile_backend:
|
||||||
|
training_args_kwargs["torch_compile_backend"] = (
|
||||||
|
self.cfg.torch_compile_backend
|
||||||
|
)
|
||||||
|
if self.cfg.torch_compile_mode:
|
||||||
|
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||||
|
|
||||||
return training_args_kwargs
|
return training_args_kwargs
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Type, Union
|
from typing import Type, Union
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
DataCollatorWithFlattening,
|
DataCollatorWithFlattening,
|
||||||
@@ -53,11 +52,6 @@ from axolotl.utils.collators import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
|
|
||||||
try:
|
|
||||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -195,20 +189,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.greater_is_better:
|
if self.cfg.greater_is_better:
|
||||||
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
||||||
|
|
||||||
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
|
|
||||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
|
||||||
True
|
|
||||||
)
|
|
||||||
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
|
|
||||||
if self.cfg.torch_compile_backend:
|
|
||||||
training_arguments_kwargs["torch_compile_backend"] = (
|
|
||||||
self.cfg.torch_compile_backend
|
|
||||||
)
|
|
||||||
if self.cfg.torch_compile_mode:
|
|
||||||
training_arguments_kwargs["torch_compile_mode"] = (
|
|
||||||
self.cfg.torch_compile_mode
|
|
||||||
)
|
|
||||||
|
|
||||||
# DDP Config
|
# DDP Config
|
||||||
if self.cfg.ddp_timeout:
|
if self.cfg.ddp_timeout:
|
||||||
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
|
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
|
||||||
@@ -464,13 +444,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
collator = RewardDataCollatorWithPadding
|
collator = RewardDataCollatorWithPadding
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.flex_attention:
|
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
# supported multipack models, or non-flash-attention llama
|
||||||
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
if (
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
self.cfg.flex_attention
|
||||||
elif (
|
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
self.cfg.model_config_type in ["llama"]
|
or (
|
||||||
and self.cfg.flash_attention is not True
|
self.cfg.model_config_type in ["llama"]
|
||||||
|
and self.cfg.flash_attention is not True
|
||||||
|
)
|
||||||
):
|
):
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user