Logging config for colab (#2611)

* only configure logging on cli to play nicely with colab

* allow reloading the config on the fly from a dict

* make sure to use dict for yaml

* reuse existing function for load

* make cli args optional

* mps fix and respect max_steps
This commit is contained in:
Wing Lian
2025-05-01 12:58:00 -04:00
committed by GitHub
parent 996fc124e5
commit fee3c13bb5
11 changed files with 32 additions and 25 deletions

View File

@@ -2,4 +2,7 @@
import os import os
from axolotl.logging_config import configure_logging
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
configure_logging()

View File

@@ -8,9 +8,6 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
from axolotl.logging_config import configure_logging
configure_logging()
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)

View File

@@ -5,6 +5,7 @@ import logging
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Union from typing import Union
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -158,7 +159,9 @@ def plugin_set_cfg(cfg: DictDefault):
plugin_manager.cfg = cfg plugin_manager.cfg = cfg
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault: def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
""" """
Loads the `axolotl` configuration stored at `config`, validates it, and performs Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup. various setup.
@@ -170,13 +173,24 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
Returns: Returns:
`DictDefault` mapping configuration keys to values. `DictDefault` mapping configuration keys to values.
""" """
config = check_remote_config(config) if isinstance(config, (str, Path)):
if Path(config).is_dir(): config = check_remote_config(config)
config = choose_config(Path(config)) if Path(config).is_dir():
config = choose_config(Path(config))
# Load the config from the yaml file # Load the config from the yaml file
with open(config, encoding="utf-8") as file: with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file)) cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
else:
cfg = config
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
temp_file.write(yaml.dump(config.to_dict()))
temp_file.close()
cfg.axolotl_config_path = temp_file.name
# If there are any options passed in the cli, if it is something that seems valid # If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value # from the yaml, then overwrite the value
@@ -190,8 +204,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
else: else:
cfg[k] = kwargs[k] cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try: try:
device_props = torch.cuda.get_device_properties("cuda") device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)

View File

@@ -20,11 +20,9 @@ from transformers import (
ProcessorMixin, ProcessorMixin,
) )
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.models import load_model, load_processor, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)

View File

@@ -47,7 +47,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
def load_datasets( def load_datasets(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs], cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
""" """
Loads one or more training or evaluation datasets, calling Loads one or more training or evaluation datasets, calling
@@ -64,7 +64,8 @@ def load_datasets(
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = ( preprocess_iterable = (
hasattr(cli_args, "iterable") cli_args
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None and cli_args.iterable is not None
and cli_args.iterable and cli_args.iterable
) )
@@ -76,7 +77,7 @@ def load_datasets(
preprocess_iterable=preprocess_iterable, preprocess_iterable=preprocess_iterable,
) )
if ( if cli_args and (
cli_args.debug cli_args.debug
or cfg.debug or cfg.debug
or cli_args.debug_text_only or cli_args.debug_text_only

View File

@@ -488,7 +488,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# these are all the "standard" kwargs that are def used # these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = ( training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1 self.cfg.max_steps if self.cfg.max_steps else -1
) )
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs["per_device_train_batch_size"] = ( training_arguments_kwargs["per_device_train_batch_size"] = (

View File

@@ -11,7 +11,6 @@ from accelerate.logging import get_logger
from datasets import Dataset from datasets import Dataset
from transformers.trainer import Trainer from transformers.trainer import Trainer
from axolotl.logging_config import configure_logging
from axolotl.train import ( from axolotl.train import (
TrainDatasetMeta, TrainDatasetMeta,
setup_model_and_tokenizer, setup_model_and_tokenizer,
@@ -24,7 +23,6 @@ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src") src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)

View File

@@ -12,10 +12,8 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate.logging import get_logger from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)

View File

@@ -30,7 +30,6 @@ from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager, SequenceParallelContextManager,
) )
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except
@@ -42,7 +41,6 @@ try:
except ImportError: except ImportError:
BetterTransformer = None BetterTransformer = None
configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)

View File

@@ -67,7 +67,7 @@ def resolve_dtype(cfg):
else: else:
LOG.debug("bf16 support not detected, disabling for this configuration.") LOG.debug("bf16 support not detected, disabling for this configuration.")
cfg.bf16 = False cfg.bf16 = False
if cfg.fp16 is None: if cfg.fp16 is None and not cfg.float16:
cfg.fp16 = True cfg.fp16 = True
if cfg.device == "mps": if cfg.device == "mps":

View File

@@ -597,6 +597,8 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16" os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16: elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16" os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
else:
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
def prepare_opinionated_env(cfg): def prepare_opinionated_env(cfg):