extract module for working with cfg
This commit is contained in:
committed by
Aman Gupta Karmani
parent
a13e45d548
commit
8cec513447
@@ -1,12 +1,60 @@
|
||||
"""Module for validating config files"""
|
||||
"""Module for working with config dicts"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def choose_device(cfg):
|
||||
def get_device():
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
return f"cuda:{cfg.local_rank}"
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
|
||||
raise SystemError("No CUDA/mps device found")
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return "cpu"
|
||||
|
||||
cfg.device = get_device()
|
||||
if cfg.device_map != "auto":
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": cfg.local_rank}
|
||||
else:
|
||||
cfg.device_map = {"": cfg.device}
|
||||
|
||||
|
||||
def normalize_config(cfg):
|
||||
# setup some derived config / hyperparams
|
||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||
cfg.batch_size // cfg.micro_batch_size
|
||||
)
|
||||
cfg.batch_size = (
|
||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||
)
|
||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
choose_device(cfg)
|
||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||
if cfg.ddp:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
|
||||
if cfg.device == "mps":
|
||||
cfg.load_in_8bit = False
|
||||
cfg.tf32 = False
|
||||
if cfg.bf16:
|
||||
cfg.fp16 = True
|
||||
cfg.bf16 = False
|
||||
else:
|
||||
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
||||
|
||||
|
||||
def validate_config(cfg):
|
||||
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
||||
raise ValueError(
|
||||
Reference in New Issue
Block a user