From 8cec513447adea66e46eacb1f6e94a440daa96c4 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Sun, 13 Aug 2023 01:22:20 +0000
Subject: [PATCH] extract module for working with cfg

---
 scripts/finetune.py                            | 47 +----------------
 .../utils/{validation.py => config.py}         | 50 ++++++++++++++++++-
 tests/test_validation.py                       |  2 +-
 3 files changed, 52 insertions(+), 47 deletions(-)
 rename src/axolotl/utils/{validation.py => config.py} (78%)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index a7fee5ec8..da08fda0b 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -19,6 +19,7 @@ from transformers import GenerationConfig, TextStreamer
 
 from axolotl.logging_config import configure_logging
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import barrier, is_main_process
@@ -29,7 +30,6 @@ from axolotl.utils.trainer import (
     process_datasets_for_packing,
     setup_trainer,
 )
-from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -44,27 +44,6 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
 
-def choose_device(cfg):
-    def get_device():
-        try:
-            if torch.cuda.is_available():
-                return f"cuda:{cfg.local_rank}"
-
-            if torch.backends.mps.is_available():
-                return "mps"
-
-            raise SystemError("No CUDA/mps device found")
-        except Exception:  # pylint: disable=broad-exception-caught
-            return "cpu"
-
-    cfg.device = get_device()
-    if cfg.device_map != "auto":
-        if cfg.device.startswith("cuda"):
-            cfg.device_map = {"": cfg.local_rank}
-        else:
-            cfg.device_map = {"": cfg.device}
-
-
 def get_multi_line_input() -> Optional[str]:
     print("Give me an instruction (Ctrl + D to finish): ")
     instruction = ""
@@ -194,31 +173,9 @@ def train(
 
     validate_config(cfg)
 
-    # setup some derived config / hyperparams
-    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
-        cfg.batch_size // cfg.micro_batch_size
-    )
-    cfg.batch_size = (
-        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
-    )
-    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
-    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
-    choose_device(cfg)
-    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
-    if cfg.ddp:
-        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
-        cfg.batch_size = cfg.batch_size * cfg.world_size
+    normalize_config(cfg)
 
     setup_wandb_env_vars(cfg)
-    if cfg.device == "mps":
-        cfg.load_in_8bit = False
-        cfg.tf32 = False
-        if cfg.bf16:
-            cfg.fp16 = True
-            cfg.bf16 = False
-
-    if cfg.tf32:
-        torch.backends.cuda.matmul.allow_tf32 = True
 
     # load the tokenizer first
     tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/config.py
similarity index 78%
rename from src/axolotl/utils/validation.py
rename to src/axolotl/utils/config.py
index 97d70c4c8..e69bffa7a 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/config.py
@@ -1,12 +1,60 @@
-"""Module for validating config files"""
+"""Module for working with config dicts"""
 
 import logging
+import os
 
 import torch
 
 LOG = logging.getLogger("axolotl")
 
 
+def choose_device(cfg):
+    def get_device():
+        try:
+            if torch.cuda.is_available():
+                return f"cuda:{cfg.local_rank}"
+
+            if torch.backends.mps.is_available():
+                return "mps"
+
+            raise SystemError("No CUDA/mps device found")
+        except Exception:  # pylint: disable=broad-exception-caught
+            return "cpu"
+
+    cfg.device = get_device()
+    if cfg.device_map != "auto":
+        if cfg.device.startswith("cuda"):
+            cfg.device_map = {"": cfg.local_rank}
+        else:
+            cfg.device_map = {"": cfg.device}
+
+
+def normalize_config(cfg):
+    # setup some derived config / hyperparams
+    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
+        cfg.batch_size // cfg.micro_batch_size
+    )
+    cfg.batch_size = (
+        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
+    )
+    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
+    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    choose_device(cfg)
+    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
+    if cfg.ddp:
+        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
+        cfg.batch_size = cfg.batch_size * cfg.world_size
+
+    if cfg.device == "mps":
+        cfg.load_in_8bit = False
+        cfg.tf32 = False
+        if cfg.bf16:
+            cfg.fp16 = True
+            cfg.bf16 = False
+    else:
+        torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
+
+
 def validate_config(cfg):
     if cfg.max_packed_sequence_len and cfg.sample_packing:
         raise ValueError(
diff --git a/tests/test_validation.py b/tests/test_validation.py
index e956d7b40..48b122f9a 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -6,8 +6,8 @@ from typing import Optional
 
 import pytest
 
+from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.validation import validate_config
 
 
 class ValidationTest(unittest.TestCase):