fix: import path for trainer builder and submodules

This commit is contained in:
NanoCode012
2025-05-15 15:49:37 +07:00
parent a1832953c4
commit 555190868a
12 changed files with 14 additions and 11 deletions

View File

@@ -4,6 +4,7 @@ import inspect
import logging
from pathlib import Path
from axolotl.core.trainer_builder.base import TrainerBuilderBase
from axolotl.core.trainers import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
@@ -12,7 +13,6 @@ from axolotl.core.trainers import (
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.trainer_builder.base import TrainerBuilderBase
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,

View File

@@ -15,6 +15,7 @@ from transformers import (
)
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainer_builder.base import TrainerBuilderBase
from axolotl.core.trainers import (
AxolotlMambaTrainer,
AxolotlPRMTrainer,
@@ -22,7 +23,6 @@ from axolotl.core.trainers import (
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.trainers.trainer_builder.base import TrainerBuilderBase
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,

View File

@@ -26,7 +26,8 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
from axolotl.integrations.base import PluginManager
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault

View File

@@ -46,7 +46,7 @@ from axolotl.utils.distributed import (
from axolotl.utils.schemas.config import AxolotlInputConfig
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
from axolotl.core.training_args import AxolotlTrainingArguments
IGNORE_INDEX = -100

View File

@@ -9,7 +9,7 @@ from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
from axolotl.core.training_args import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")

View File

@@ -14,7 +14,7 @@ import numpy as np
from transformers import TrainerCallback
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainer
from axolotl.core.trainers import AxolotlTrainer
LOG = logging.getLogger("axolotl.callbacks.lisa")

View File

@@ -11,7 +11,7 @@ from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
from axolotl.core.training_args import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")

View File

@@ -16,7 +16,8 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support

View File

@@ -7,7 +7,8 @@ from pathlib import Path
import pytest
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer

View File

@@ -11,11 +11,11 @@ class TestImports(unittest.TestCase):
"""
def test_import_causal_trainer(self):
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
from axolotl.core.trainer_builder.sft import ( # pylint: disable=unused-import # noqa: F401
HFCausalTrainerBuilder,
)
def test_import_rl_trainer(self):
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
from axolotl.core.trainer_builder.rl import ( # pylint: disable=unused-import # noqa: F401
HFRLTrainerBuilder,
)