Logging config for colab (#2611)
* only configure logging on cli to play nicely with colab * allow reloading the config on the fly from a dict * make sure to use dict for yaml * reuse existing function for load * make cli args optional * mps fix and respect max_steps
This commit is contained in:
@@ -2,4 +2,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
configure_logging()
|
||||||
|
|||||||
@@ -8,9 +8,6 @@ from accelerate.commands.config import config_args
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@@ -158,7 +159,9 @@ def plugin_set_cfg(cfg: DictDefault):
|
|||||||
plugin_manager.cfg = cfg
|
plugin_manager.cfg = cfg
|
||||||
|
|
||||||
|
|
||||||
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
|
def load_cfg(
|
||||||
|
config: str | Path | DictDefault = Path("examples/"), **kwargs
|
||||||
|
) -> DictDefault:
|
||||||
"""
|
"""
|
||||||
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
||||||
various setup.
|
various setup.
|
||||||
@@ -170,13 +173,24 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
|
|||||||
Returns:
|
Returns:
|
||||||
`DictDefault` mapping configuration keys to values.
|
`DictDefault` mapping configuration keys to values.
|
||||||
"""
|
"""
|
||||||
config = check_remote_config(config)
|
if isinstance(config, (str, Path)):
|
||||||
if Path(config).is_dir():
|
config = check_remote_config(config)
|
||||||
config = choose_config(Path(config))
|
if Path(config).is_dir():
|
||||||
|
config = choose_config(Path(config))
|
||||||
|
|
||||||
# Load the config from the yaml file
|
# Load the config from the yaml file
|
||||||
with open(config, encoding="utf-8") as file:
|
with open(config, encoding="utf-8") as file:
|
||||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||||
|
|
||||||
|
cfg.axolotl_config_path = config
|
||||||
|
else:
|
||||||
|
cfg = config
|
||||||
|
with NamedTemporaryFile(
|
||||||
|
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||||
|
) as temp_file:
|
||||||
|
temp_file.write(yaml.dump(config.to_dict()))
|
||||||
|
temp_file.close()
|
||||||
|
cfg.axolotl_config_path = temp_file.name
|
||||||
|
|
||||||
# If there are any options passed in the cli, if it is something that seems valid
|
# If there are any options passed in the cli, if it is something that seems valid
|
||||||
# from the yaml, then overwrite the value
|
# from the yaml, then overwrite the value
|
||||||
@@ -190,8 +204,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
|
|||||||
else:
|
else:
|
||||||
cfg[k] = kwargs[k]
|
cfg[k] = kwargs[k]
|
||||||
|
|
||||||
cfg.axolotl_config_path = config
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
device_props = torch.cuda.get_device_properties("cuda")
|
device_props = torch.cuda.get_device_properties("cuda")
|
||||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||||
|
|||||||
@@ -20,11 +20,9 @@ from transformers import (
|
|||||||
ProcessorMixin,
|
ProcessorMixin,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
|||||||
def load_datasets(
|
def load_datasets(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
|
||||||
) -> TrainDatasetMeta:
|
) -> TrainDatasetMeta:
|
||||||
"""
|
"""
|
||||||
Loads one or more training or evaluation datasets, calling
|
Loads one or more training or evaluation datasets, calling
|
||||||
@@ -64,7 +64,8 @@ def load_datasets(
|
|||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||||
preprocess_iterable = (
|
preprocess_iterable = (
|
||||||
hasattr(cli_args, "iterable")
|
cli_args
|
||||||
|
and hasattr(cli_args, "iterable")
|
||||||
and cli_args.iterable is not None
|
and cli_args.iterable is not None
|
||||||
and cli_args.iterable
|
and cli_args.iterable
|
||||||
)
|
)
|
||||||
@@ -76,7 +77,7 @@ def load_datasets(
|
|||||||
preprocess_iterable=preprocess_iterable,
|
preprocess_iterable=preprocess_iterable,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if cli_args and (
|
||||||
cli_args.debug
|
cli_args.debug
|
||||||
or cfg.debug
|
or cfg.debug
|
||||||
or cli_args.debug_text_only
|
or cli_args.debug_text_only
|
||||||
|
|||||||
@@ -488,7 +488,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
# these are all the "standard" kwargs that are def used
|
# these are all the "standard" kwargs that are def used
|
||||||
training_arguments_kwargs["max_steps"] = (
|
training_arguments_kwargs["max_steps"] = (
|
||||||
total_num_steps if self.cfg.max_steps else -1
|
self.cfg.max_steps if self.cfg.max_steps else -1
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||||
training_arguments_kwargs["per_device_train_batch_size"] = (
|
training_arguments_kwargs["per_device_train_batch_size"] = (
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from accelerate.logging import get_logger
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.train import (
|
from axolotl.train import (
|
||||||
TrainDatasetMeta,
|
TrainDatasetMeta,
|
||||||
setup_model_and_tokenizer,
|
setup_model_and_tokenizer,
|
||||||
@@ -24,7 +23,6 @@ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|||||||
src_dir = os.path.join(project_root, "src")
|
src_dir = os.path.join(project_root, "src")
|
||||||
sys.path.insert(0, src_dir)
|
sys.path.insert(0, src_dir)
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,10 +12,8 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ from axolotl.core.trainers.mixins.sequence_parallel import (
|
|||||||
SequenceParallelContextManager,
|
SequenceParallelContextManager,
|
||||||
)
|
)
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
@@ -42,7 +41,6 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
BetterTransformer = None
|
BetterTransformer = None
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ def resolve_dtype(cfg):
|
|||||||
else:
|
else:
|
||||||
LOG.debug("bf16 support not detected, disabling for this configuration.")
|
LOG.debug("bf16 support not detected, disabling for this configuration.")
|
||||||
cfg.bf16 = False
|
cfg.bf16 = False
|
||||||
if cfg.fp16 is None:
|
if cfg.fp16 is None and not cfg.float16:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
|
||||||
if cfg.device == "mps":
|
if cfg.device == "mps":
|
||||||
|
|||||||
@@ -597,6 +597,8 @@ def prepare_optim_env(cfg):
|
|||||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
||||||
elif cfg.fp16:
|
elif cfg.fp16:
|
||||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
|
||||||
|
else:
|
||||||
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
|
||||||
|
|
||||||
|
|
||||||
def prepare_opinionated_env(cfg):
|
def prepare_opinionated_env(cfg):
|
||||||
|
|||||||
Reference in New Issue
Block a user