diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index b20e4f085..8955eca3e 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -2,4 +2,7 @@
 import os
 
+from axolotl.logging_config import configure_logging
+
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+configure_logging()
 
diff --git a/src/axolotl/cli/checks.py b/src/axolotl/cli/checks.py
index cc3ed0d9f..47348240e 100644
--- a/src/axolotl/cli/checks.py
+++ b/src/axolotl/cli/checks.py
@@ -8,9 +8,6 @@ from accelerate.commands.config import config_args
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
 
-from axolotl.logging_config import configure_logging
-
-configure_logging()
 
 LOG = logging.getLogger(__name__)
 
diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py
index 64bf402b9..8f1fe7185 100644
--- a/src/axolotl/cli/config.py
+++ b/src/axolotl/cli/config.py
@@ -5,6 +5,7 @@
 import logging
 import os
 import tempfile
 from pathlib import Path
+from tempfile import NamedTemporaryFile
 from typing import Union
 from urllib.parse import urlparse
@@ -158,7 +159,9 @@ def plugin_set_cfg(cfg: DictDefault):
     plugin_manager.cfg = cfg
 
 
-def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
+def load_cfg(
+    config: str | Path | DictDefault = Path("examples/"), **kwargs
+) -> DictDefault:
     """
     Loads the `axolotl` configuration stored at `config`, validates it, and performs
     various setup.
@@ -170,13 +173,24 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
     Returns:
         `DictDefault` mapping configuration keys to values.
     """
-    config = check_remote_config(config)
-    if Path(config).is_dir():
-        config = choose_config(Path(config))
+    if isinstance(config, (str, Path)):
+        config = check_remote_config(config)
+        if Path(config).is_dir():
+            config = choose_config(Path(config))
 
-    # Load the config from the yaml file
-    with open(config, encoding="utf-8") as file:
-        cfg: DictDefault = DictDefault(yaml.safe_load(file))
+        # Load the config from the yaml file
+        with open(config, encoding="utf-8") as file:
+            cfg: DictDefault = DictDefault(yaml.safe_load(file))
+
+        cfg.axolotl_config_path = config
+    else:
+        cfg = config
+        with NamedTemporaryFile(
+            mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
+        ) as temp_file:
+            temp_file.write(yaml.dump(config.to_dict()))
+            temp_file.close()
+        cfg.axolotl_config_path = temp_file.name
 
     # If there are any options passed in the cli, if it is something that seems valid
     # from the yaml, then overwrite the value
@@ -190,8 +204,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
         else:
             cfg[k] = kwargs[k]
 
-    cfg.axolotl_config_path = config
-
     try:
         device_props = torch.cuda.get_device_properties("cuda")
         gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py
index 7cc4d2744..ee00db39d 100644
--- a/src/axolotl/cli/utils.py
+++ b/src/axolotl/cli/utils.py
@@ -20,11 +20,9 @@ from transformers import (
     ProcessorMixin,
 )
 
-from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_processor, load_tokenizer
 
-configure_logging()
 LOG = logging.getLogger(__name__)
 
 
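The `load_cfg` change above lets callers pass an in-memory `DictDefault` instead of a path: the mapping is dumped to a throwaway YAML file so that `cfg.axolotl_config_path` still points at a real file on disk. A minimal standalone sketch of that round-trip, with a plain `dict` standing in for `DictDefault` and only PyYAML assumed (the config keys are illustrative):

```python
# Sketch of the temp-file round-trip used by the new load_cfg branch.
import os
from tempfile import NamedTemporaryFile

import yaml

config = {"base_model": "meta-llama/Llama-3.2-1B", "micro_batch_size": 2}

# delete=False keeps the file alive after the context manager exits, so the
# path can be stashed on the config and re-read later (e.g. by a subprocess).
with NamedTemporaryFile(
    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
    temp_file.write(yaml.dump(config))

with open(temp_file.name, encoding="utf-8") as file:
    assert yaml.safe_load(file) == config

os.unlink(temp_file.name)  # cleanup for this sketch; axolotl keeps the file
```
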
diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py
index 3e712f772..2ab405ef1 100644
--- a/src/axolotl/common/datasets.py
+++ b/src/axolotl/common/datasets.py
@@ -47,7 +47,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
 def load_datasets(
     *,
     cfg: DictDefault,
-    cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
+    cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
 ) -> TrainDatasetMeta:
     """
     Loads one or more training or evaluation datasets, calling
@@ -64,7 +64,8 @@ def load_datasets(
     tokenizer = load_tokenizer(cfg)
     processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
     preprocess_iterable = (
-        hasattr(cli_args, "iterable")
+        cli_args
+        and hasattr(cli_args, "iterable")
         and cli_args.iterable is not None
         and cli_args.iterable
     )
@@ -76,7 +77,7 @@ def load_datasets(
         preprocess_iterable=preprocess_iterable,
     )
 
-    if (
+    if cli_args and (
         cli_args.debug
         or cfg.debug
         or cli_args.debug_text_only
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 358058f69..31ee3cccf 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -488,7 +488,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         # these are all the "standard" kwargs that are def used
         training_arguments_kwargs["max_steps"] = (
-            total_num_steps if self.cfg.max_steps else -1
+            self.cfg.max_steps if self.cfg.max_steps else -1
         )
         training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
         training_arguments_kwargs["per_device_train_batch_size"] = (
diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py
index a6a192bc7..6d6813730 100644
--- a/src/axolotl/evaluate.py
+++ b/src/axolotl/evaluate.py
@@ -11,7 +11,6 @@ from accelerate.logging import get_logger
 from datasets import Dataset
 from transformers.trainer import Trainer
 
-from axolotl.logging_config import configure_logging
 from axolotl.train import (
     TrainDatasetMeta,
     setup_model_and_tokenizer,
@@ -24,7 +23,6 @@
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
 
-configure_logging()
 LOG = get_logger(__name__)
diff --git a/src/axolotl/monkeypatch/attention/ring_attn/patch.py b/src/axolotl/monkeypatch/attention/ring_attn/patch.py
index b5587ddca..fa03bd174 100644
--- a/src/axolotl/monkeypatch/attention/ring_attn/patch.py
+++ b/src/axolotl/monkeypatch/attention/ring_attn/patch.py
@@ -12,10 +12,8 @@
 import torch
 import torch.distributed as dist
 from accelerate.logging import get_logger
 
-from axolotl.logging_config import configure_logging
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
-configure_logging()
 LOG = get_logger(__name__)
 
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 7896239de..2768a8111 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -30,7 +30,6 @@ from axolotl.core.trainers.mixins.sequence_parallel import (
     SequenceParallelContextManager,
 )
 from axolotl.integrations.base import PluginManager
-from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed
 from axolotl.utils.freeze import freeze_layers_except
@@ -42,7 +41,6 @@ try:
 except ImportError:
     BetterTransformer = None
 
-configure_logging()
 LOG = get_logger(__name__)
 
 
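With `cli_args` now optional in `load_datasets`, every access is guarded by a truthiness check before attribute lookup, so programmatic callers can omit CLI args entirely. A self-contained sketch of that guard pattern; the `PreprocessArgs` dataclass below is a stand-in for illustration, not axolotl's actual CLI args class:

```python
# Sketch of the Optional-cli_args guard used in load_datasets.
from dataclasses import dataclass
from typing import Optional


@dataclass
class PreprocessArgs:
    iterable: Optional[bool] = None
    debug: bool = False


def wants_iterable(cli_args: Optional[PreprocessArgs]) -> bool:
    # `cli_args and ...` short-circuits on None, so attribute access is never
    # attempted when no CLI args were supplied.
    return bool(
        cli_args
        and hasattr(cli_args, "iterable")
        and cli_args.iterable is not None
        and cli_args.iterable
    )


assert wants_iterable(None) is False
assert wants_iterable(PreprocessArgs(iterable=True)) is True
```
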
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index e5ea44aa0..35e742a89 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -67,7 +67,7 @@ def resolve_dtype(cfg):
     else:
         LOG.debug("bf16 support not detected, disabling for this configuration.")
         cfg.bf16 = False
-        if cfg.fp16 is None:
+        if cfg.fp16 is None and not cfg.float16:
             cfg.fp16 = True
 
     if cfg.device == "mps":
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 69aaabfa6..96f54b39d 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -597,6 +597,8 @@ def prepare_optim_env(cfg):
         os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
     elif cfg.fp16:
         os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
+    else:
+        os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
 
 
 def prepare_opinionated_env(cfg):
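The `prepare_optim_env` change pins `ACCELERATE_MIXED_PRECISION` to an explicit `"no"` when neither `bf16` nor `fp16` is enabled, rather than leaving whatever value was already in the environment. A standalone sketch of that precedence; the helper function below is illustrative, not axolotl's API:

```python
# Sketch of the mixed-precision precedence set in prepare_optim_env.
import os


def set_mixed_precision(bf16: bool, fp16: bool) -> None:
    # bf16 wins over fp16; anything else now writes an explicit "no" so a
    # stale value from an earlier run cannot leak into this one.
    if bf16:
        os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
    elif fp16:
        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
    else:
        os.environ["ACCELERATE_MIXED_PRECISION"] = "no"


set_mixed_precision(bf16=False, fp16=False)
assert os.environ["ACCELERATE_MIXED_PRECISION"] == "no"
```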