fix: import path for trainer builder and submodules
This commit is contained in:
@@ -4,6 +4,7 @@ import inspect
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.core.trainer_builder.base import TrainerBuilderBase
|
||||
from axolotl.core.trainers import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
@@ -12,7 +13,6 @@ from axolotl.core.trainers import (
|
||||
from axolotl.core.trainers.dpo import DPOStrategy
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.core.trainers.trainer_builder.base import TrainerBuilderBase
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlCPOConfig,
|
||||
AxolotlKTOConfig,
|
||||
@@ -15,6 +15,7 @@ from transformers import (
|
||||
)
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
from axolotl.core.trainer_builder.base import TrainerBuilderBase
|
||||
from axolotl.core.trainers import (
|
||||
AxolotlMambaTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
@@ -22,7 +23,6 @@ from axolotl.core.trainers import (
|
||||
AxolotlTrainer,
|
||||
ReLoRATrainer,
|
||||
)
|
||||
from axolotl.core.trainers.trainer_builder.base import TrainerBuilderBase
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlPRMConfig,
|
||||
AxolotlRewardConfig,
|
||||
@@ -26,7 +26,8 @@ from axolotl.common.datasets import TrainDatasetMeta
|
||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
)
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
|
||||
from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -46,7 +46,7 @@ from axolotl.utils.distributed import (
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
@@ -9,7 +9,7 @@ from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import numpy as np
|
||||
from transformers import TrainerCallback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainer
|
||||
from axolotl.core.trainers import AxolotlTrainer
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks.lisa")
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
@@ -16,7 +16,8 @@ from datasets import IterableDataset, disable_caching, enable_caching
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
|
||||
from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
|
||||
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
||||
from axolotl.utils.distributed import reduce_and_broadcast
|
||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||
|
||||
@@ -7,7 +7,8 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
|
||||
from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
@@ -11,11 +11,11 @@ class TestImports(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def test_import_causal_trainer(self):
|
||||
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.core.trainer_builder.sft import ( # pylint: disable=unused-import # noqa: F401
|
||||
HFCausalTrainerBuilder,
|
||||
)
|
||||
|
||||
def test_import_rl_trainer(self):
|
||||
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.core.trainer_builder.rl import ( # pylint: disable=unused-import # noqa: F401
|
||||
HFRLTrainerBuilder,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user