From 5bb4a782ce2d18469e59068e1c74c487c97e4fa3 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 12 Dec 2023 17:33:31 -0500
Subject: [PATCH] dataloader defaults

---
 src/axolotl/cli/__init__.py |  4 +++-
 src/axolotl/utils/config.py | 17 +++++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 8ca4f7fe5..b71c73c74 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -23,7 +23,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
 from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
-from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.config import add_defaults, normalize_config, validate_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
@@ -301,6 +301,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
 
     normalize_config(cfg)
 
+    add_defaults(cfg)
+
     setup_wandb_env_vars(cfg)
     return cfg
 
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 74da66928..af809b6bf 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -41,6 +41,23 @@ def choose_device(cfg):
         cfg.device_map = None
 
 
+def add_defaults(cfg):
+    """Fill in dataloader settings that the config left unspecified.
+
+    Mutates ``cfg`` in place. Reads ``cfg.batch_size``, so call it once that
+    value is resolved (``load_cfg`` calls it right after ``normalize_config``).
+    """
+    # setup sane defaults if left unspecified
+    if cfg.dataloader_num_workers is None:
+        cfg.dataloader_num_workers = int(os.getenv("WORLD_SIZE", "1"))
+    # a prefetch factor is only valid with worker processes: torch's DataLoader
+    # raises if prefetch_factor is set while num_workers is 0, so leave it unset
+    if cfg.dataloader_prefetch_factor is None and cfg.dataloader_num_workers > 0:
+        cfg.dataloader_prefetch_factor = cfg.batch_size * 2
+    if cfg.dataloader_pin_memory is None:
+        cfg.dataloader_pin_memory = True
+
+
 def normalize_config(cfg):
     # setup some derived config / hyperparams
     cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (