Logging config for colab (#2611)

* only configure logging on cli to play nicely with colab

* allow reloading the config on the fly from a dict

* make sure to use dict for yaml

* reuse existing function for load

* make cli args optional

* mps fix and respect max_steps
This commit is contained in:
Wing Lian
2025-05-01 12:58:00 -04:00
parent bc6dfa6899
commit b50d35bec9
11 changed files with 32 additions and 25 deletions

View File

@@ -2,4 +2,7 @@
import os
from axolotl.logging_config import configure_logging
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
configure_logging()

View File

@@ -8,9 +8,6 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from axolotl.logging_config import configure_logging
configure_logging()
LOG = logging.getLogger(__name__)

View File

@@ -5,6 +5,7 @@ import logging
import os
import tempfile
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Union
from urllib.parse import urlparse
@@ -158,7 +159,9 @@ def plugin_set_cfg(cfg: DictDefault):
plugin_manager.cfg = cfg
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
"""
Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup.
@@ -170,13 +173,24 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
Returns:
`DictDefault` mapping configuration keys to values.
"""
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
if isinstance(config, (str, Path)):
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
else:
cfg = config
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
temp_file.write(yaml.dump(config.to_dict()))
temp_file.close()
cfg.axolotl_config_path = temp_file.name
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
@@ -190,8 +204,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)

View File

@@ -20,11 +20,9 @@ from transformers import (
ProcessorMixin,
)
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__)

View File

@@ -47,7 +47,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
def load_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
@@ -64,7 +64,8 @@ def load_datasets(
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = (
hasattr(cli_args, "iterable")
cli_args
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
@@ -76,7 +77,7 @@ def load_datasets(
preprocess_iterable=preprocess_iterable,
)
if (
if cli_args and (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only

View File

@@ -488,7 +488,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1
self.cfg.max_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs["per_device_train_batch_size"] = (

View File

@@ -11,7 +11,6 @@ from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.logging_config import configure_logging
from axolotl.train import (
TrainDatasetMeta,
setup_model_and_tokenizer,
@@ -24,7 +23,6 @@ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger(__name__)

View File

@@ -12,10 +12,8 @@ import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging()
LOG = get_logger(__name__)

View File

@@ -30,7 +30,6 @@ from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
@@ -42,7 +41,6 @@ try:
except ImportError:
BetterTransformer = None
configure_logging()
LOG = get_logger(__name__)

View File

@@ -67,7 +67,7 @@ def resolve_dtype(cfg):
else:
LOG.debug("bf16 support not detected, disabling for this configuration.")
cfg.bf16 = False
if cfg.fp16 is None:
if cfg.fp16 is None and not cfg.float16:
cfg.fp16 = True
if cfg.device == "mps":

View File

@@ -597,6 +597,8 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
else:
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
def prepare_opinionated_env(cfg):