extract module for working with cfg
This commit is contained in:
committed by
Aman Gupta Karmani
parent
a13e45d548
commit
8cec513447
@@ -19,6 +19,7 @@ from transformers import GenerationConfig, TextStreamer
|
|||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import barrier, is_main_process
|
from axolotl.utils.distributed import barrier, is_main_process
|
||||||
@@ -29,7 +30,6 @@ from axolotl.utils.trainer import (
|
|||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
setup_trainer,
|
setup_trainer,
|
||||||
)
|
)
|
||||||
from axolotl.utils.validation import validate_config
|
|
||||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -44,27 +44,6 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
|||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
|
||||||
|
|
||||||
def choose_device(cfg):
|
|
||||||
def get_device():
|
|
||||||
try:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return f"cuda:{cfg.local_rank}"
|
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
return "mps"
|
|
||||||
|
|
||||||
raise SystemError("No CUDA/mps device found")
|
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
|
||||||
return "cpu"
|
|
||||||
|
|
||||||
cfg.device = get_device()
|
|
||||||
if cfg.device_map != "auto":
|
|
||||||
if cfg.device.startswith("cuda"):
|
|
||||||
cfg.device_map = {"": cfg.local_rank}
|
|
||||||
else:
|
|
||||||
cfg.device_map = {"": cfg.device}
|
|
||||||
|
|
||||||
|
|
||||||
def get_multi_line_input() -> Optional[str]:
|
def get_multi_line_input() -> Optional[str]:
|
||||||
print("Give me an instruction (Ctrl + D to finish): ")
|
print("Give me an instruction (Ctrl + D to finish): ")
|
||||||
instruction = ""
|
instruction = ""
|
||||||
@@ -194,31 +173,9 @@ def train(
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
# setup some derived config / hyperparams
|
normalize_config(cfg)
|
||||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
|
||||||
cfg.batch_size // cfg.micro_batch_size
|
|
||||||
)
|
|
||||||
cfg.batch_size = (
|
|
||||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
|
||||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
||||||
choose_device(cfg)
|
|
||||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
|
||||||
if cfg.ddp:
|
|
||||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
|
||||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
if cfg.device == "mps":
|
|
||||||
cfg.load_in_8bit = False
|
|
||||||
cfg.tf32 = False
|
|
||||||
if cfg.bf16:
|
|
||||||
cfg.fp16 = True
|
|
||||||
cfg.bf16 = False
|
|
||||||
|
|
||||||
if cfg.tf32:
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||||
|
|||||||
@@ -1,12 +1,60 @@
|
|||||||
"""Module for validating config files"""
|
"""Module for working with config dicts"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
def choose_device(cfg):
|
||||||
|
def get_device():
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return f"cuda:{cfg.local_rank}"
|
||||||
|
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
return "mps"
|
||||||
|
|
||||||
|
raise SystemError("No CUDA/mps device found")
|
||||||
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
|
cfg.device = get_device()
|
||||||
|
if cfg.device_map != "auto":
|
||||||
|
if cfg.device.startswith("cuda"):
|
||||||
|
cfg.device_map = {"": cfg.local_rank}
|
||||||
|
else:
|
||||||
|
cfg.device_map = {"": cfg.device}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_config(cfg):
|
||||||
|
# setup some derived config / hyperparams
|
||||||
|
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||||
|
cfg.batch_size // cfg.micro_batch_size
|
||||||
|
)
|
||||||
|
cfg.batch_size = (
|
||||||
|
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||||
|
)
|
||||||
|
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
choose_device(cfg)
|
||||||
|
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||||
|
if cfg.ddp:
|
||||||
|
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||||
|
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||||
|
|
||||||
|
if cfg.device == "mps":
|
||||||
|
cfg.load_in_8bit = False
|
||||||
|
cfg.tf32 = False
|
||||||
|
if cfg.bf16:
|
||||||
|
cfg.fp16 = True
|
||||||
|
cfg.bf16 = False
|
||||||
|
else:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -6,8 +6,8 @@ from typing import Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.validation import validate_config
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationTest(unittest.TestCase):
|
class ValidationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user