fix: import path for trainer builder and submodules
This commit is contained in:
@@ -4,6 +4,7 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.core.trainer_builder.base import TrainerBuilderBase
|
||||||
from axolotl.core.trainers import (
|
from axolotl.core.trainers import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
@@ -12,7 +13,6 @@ from axolotl.core.trainers import (
|
|||||||
from axolotl.core.trainers.dpo import DPOStrategy
|
from axolotl.core.trainers.dpo import DPOStrategy
|
||||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
from axolotl.core.trainers.trainer_builder.base import TrainerBuilderBase
|
|
||||||
from axolotl.core.training_args import (
|
from axolotl.core.training_args import (
|
||||||
AxolotlCPOConfig,
|
AxolotlCPOConfig,
|
||||||
AxolotlKTOConfig,
|
AxolotlKTOConfig,
|
||||||
@@ -15,6 +15,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
|
from axolotl.core.trainer_builder.base import TrainerBuilderBase
|
||||||
from axolotl.core.trainers import (
|
from axolotl.core.trainers import (
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
AxolotlPRMTrainer,
|
AxolotlPRMTrainer,
|
||||||
@@ -22,7 +23,6 @@ from axolotl.core.trainers import (
|
|||||||
AxolotlTrainer,
|
AxolotlTrainer,
|
||||||
ReLoRATrainer,
|
ReLoRATrainer,
|
||||||
)
|
)
|
||||||
from axolotl.core.trainers.trainer_builder.base import TrainerBuilderBase
|
|
||||||
from axolotl.core.training_args import (
|
from axolotl.core.training_args import (
|
||||||
AxolotlPRMConfig,
|
AxolotlPRMConfig,
|
||||||
AxolotlRewardConfig,
|
AxolotlRewardConfig,
|
||||||
@@ -26,7 +26,8 @@ from axolotl.common.datasets import TrainDatasetMeta
|
|||||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
|
||||||
|
from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from axolotl.utils.distributed import (
|
|||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||||
|
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from transformers import TrainerCallback, TrainerControl, TrainerState
|
|||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import numpy as np
|
|||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainer
|
from axolotl.core.trainers import AxolotlTrainer
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks.lisa")
|
LOG = logging.getLogger("axolotl.callbacks.lisa")
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from transformers import TrainerCallback, TrainerControl, TrainerState
|
|||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,8 @@ from datasets import IterableDataset, disable_caching, enable_caching
|
|||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
|
||||||
|
from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
|
||||||
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
|
||||||
|
from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
|
|||||||
@@ -11,11 +11,11 @@ class TestImports(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def test_import_causal_trainer(self):
|
def test_import_causal_trainer(self):
|
||||||
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
|
from axolotl.core.trainer_builder.sft import ( # pylint: disable=unused-import # noqa: F401
|
||||||
HFCausalTrainerBuilder,
|
HFCausalTrainerBuilder,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_import_rl_trainer(self):
|
def test_import_rl_trainer(self):
|
||||||
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
|
from axolotl.core.trainer_builder.rl import ( # pylint: disable=unused-import # noqa: F401
|
||||||
HFRLTrainerBuilder,
|
HFRLTrainerBuilder,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user