From 555190868a4f39d9858124c642bc83ed5024f35e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 15 May 2025 15:49:37 +0700 Subject: [PATCH] fix: import path for trainer builder and submodules --- src/axolotl/core/{trainers => }/trainer_builder/__init__.py | 0 src/axolotl/core/{trainers => }/trainer_builder/base.py | 0 src/axolotl/core/{trainers => }/trainer_builder/rl.py | 2 +- src/axolotl/core/{trainers => }/trainer_builder/sft.py | 2 +- src/axolotl/train.py | 3 ++- src/axolotl/utils/callbacks/__init__.py | 2 +- src/axolotl/utils/callbacks/comet_.py | 2 +- src/axolotl/utils/callbacks/lisa.py | 2 +- src/axolotl/utils/callbacks/mlflow_.py | 2 +- src/axolotl/utils/trainer.py | 3 ++- tests/core/test_trainer_builder.py | 3 ++- tests/e2e/test_imports.py | 4 ++-- 12 files changed, 14 insertions(+), 11 deletions(-) rename src/axolotl/core/{trainers => }/trainer_builder/__init__.py (100%) rename src/axolotl/core/{trainers => }/trainer_builder/base.py (100%) rename src/axolotl/core/{trainers => }/trainer_builder/rl.py (99%) rename src/axolotl/core/{trainers => }/trainer_builder/sft.py (99%) diff --git a/src/axolotl/core/trainers/trainer_builder/__init__.py b/src/axolotl/core/trainer_builder/__init__.py similarity index 100% rename from src/axolotl/core/trainers/trainer_builder/__init__.py rename to src/axolotl/core/trainer_builder/__init__.py diff --git a/src/axolotl/core/trainers/trainer_builder/base.py b/src/axolotl/core/trainer_builder/base.py similarity index 100% rename from src/axolotl/core/trainers/trainer_builder/base.py rename to src/axolotl/core/trainer_builder/base.py diff --git a/src/axolotl/core/trainers/trainer_builder/rl.py b/src/axolotl/core/trainer_builder/rl.py similarity index 99% rename from src/axolotl/core/trainers/trainer_builder/rl.py rename to src/axolotl/core/trainer_builder/rl.py index 6b888e6fe..d0e44f8da 100644 --- a/src/axolotl/core/trainers/trainer_builder/rl.py +++ b/src/axolotl/core/trainer_builder/rl.py @@ -4,6 +4,7 @@ import inspect import logging from pathlib import Path +from axolotl.core.trainer_builder.base import TrainerBuilderBase from axolotl.core.trainers import ( AxolotlCPOTrainer, AxolotlKTOTrainer, @@ -12,7 +13,6 @@ from axolotl.core.trainers import ( from axolotl.core.trainers.dpo import DPOStrategy from axolotl.core.trainers.dpo.args import AxolotlDPOConfig from axolotl.core.trainers.grpo import GRPOStrategy -from axolotl.core.trainers.trainer_builder.base import TrainerBuilderBase from axolotl.core.training_args import ( AxolotlCPOConfig, AxolotlKTOConfig, diff --git a/src/axolotl/core/trainers/trainer_builder/sft.py b/src/axolotl/core/trainer_builder/sft.py similarity index 99% rename from src/axolotl/core/trainers/trainer_builder/sft.py rename to src/axolotl/core/trainer_builder/sft.py index 5caf075a3..c3f777ae2 100644 --- a/src/axolotl/core/trainers/trainer_builder/sft.py +++ b/src/axolotl/core/trainer_builder/sft.py @@ -15,6 +15,7 @@ from transformers import ( ) from trl.trainer.utils import RewardDataCollatorWithPadding +from axolotl.core.trainer_builder.base import TrainerBuilderBase from axolotl.core.trainers import ( AxolotlMambaTrainer, AxolotlPRMTrainer, @@ -22,7 +23,6 @@ from axolotl.core.trainers import ( AxolotlTrainer, ReLoRATrainer, ) -from axolotl.core.trainers.trainer_builder.base import TrainerBuilderBase from axolotl.core.training_args import ( AxolotlPRMConfig, AxolotlRewardConfig, diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 90ab10e9f..3aaba84ff 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -26,7 +26,8 @@ from axolotl.common.datasets import TrainDatasetMeta from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) -from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder +from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder +from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder from axolotl.integrations.base import PluginManager from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.dict import DictDefault diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 0e7b06093..a75a737a9 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -46,7 +46,7 @@ from axolotl.utils.distributed import ( from axolotl.utils.schemas.config import AxolotlInputConfig if TYPE_CHECKING: - from axolotl.core.trainer_builder import AxolotlTrainingArguments + from axolotl.core.training_args import AxolotlTrainingArguments IGNORE_INDEX = -100 diff --git a/src/axolotl/utils/callbacks/comet_.py b/src/axolotl/utils/callbacks/comet_.py index b29f997a8..4a8b6437c 100644 --- a/src/axolotl/utils/callbacks/comet_.py +++ b/src/axolotl/utils/callbacks/comet_.py @@ -9,7 +9,7 @@ from transformers import TrainerCallback, TrainerControl, TrainerState from axolotl.utils.distributed import is_main_process if TYPE_CHECKING: - from axolotl.core.trainer_builder import AxolotlTrainingArguments + from axolotl.core.training_args import AxolotlTrainingArguments LOG = logging.getLogger("axolotl.callbacks") diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py index e226471b1..ccf0c12bc 100644 --- a/src/axolotl/utils/callbacks/lisa.py +++ b/src/axolotl/utils/callbacks/lisa.py @@ -14,7 +14,7 @@ import numpy as np from transformers import TrainerCallback if TYPE_CHECKING: - from axolotl.core.trainer_builder import AxolotlTrainer + from axolotl.core.trainers import AxolotlTrainer LOG = logging.getLogger("axolotl.callbacks.lisa") diff --git a/src/axolotl/utils/callbacks/mlflow_.py b/src/axolotl/utils/callbacks/mlflow_.py index 47679001f..056fb51cc 100644 --- a/src/axolotl/utils/callbacks/mlflow_.py +++ b/src/axolotl/utils/callbacks/mlflow_.py @@ -11,7 +11,7 @@ from transformers import TrainerCallback, TrainerControl, TrainerState from axolotl.utils.distributed import is_main_process if TYPE_CHECKING: - from axolotl.core.trainer_builder import AxolotlTrainingArguments + from axolotl.core.training_args import AxolotlTrainingArguments LOG = logging.getLogger("axolotl.callbacks") diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 96f54b39d..17555c9e2 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -16,7 +16,8 @@ from datasets import IterableDataset, disable_caching, enable_caching from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.utils import is_torch_bf16_gpu_available -from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder +from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder +from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2 from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.environment import check_cuda_p2p_ib_support diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py index 824a7b4c7..35d8060b1 100644 --- a/tests/core/test_trainer_builder.py +++ b/tests/core/test_trainer_builder.py @@ -7,7 +7,8 @@ from pathlib import Path import pytest -from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder +from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder +from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer diff --git a/tests/e2e/test_imports.py b/tests/e2e/test_imports.py index fc0843479..8a2619632 100644 --- a/tests/e2e/test_imports.py +++ b/tests/e2e/test_imports.py @@ -11,11 +11,11 @@ class TestImports(unittest.TestCase): """ def test_import_causal_trainer(self): - from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401 + from axolotl.core.trainer_builder.sft import ( # pylint: disable=unused-import # noqa: F401 HFCausalTrainerBuilder, ) def test_import_rl_trainer(self): - from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401 + from axolotl.core.trainer_builder.rl import ( # pylint: disable=unused-import # noqa: F401 HFRLTrainerBuilder, )