dataloader defaults

This commit is contained in:
Wing Lian
2023-12-12 17:33:31 -05:00
parent 86487c2e96
commit 5bb4a782ce
2 changed files with 13 additions and 1 deletions

View File

@@ -23,7 +23,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import add_defaults, normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
@@ -301,6 +301,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
normalize_config(cfg)
add_defaults(cfg)
setup_wandb_env_vars(cfg)
return cfg

View File

@@ -41,6 +41,16 @@ def choose_device(cfg):
cfg.device_map = None
def add_defaults(cfg):
    """Fill in dataloader settings that the config left unspecified.

    Mutates ``cfg`` in place. A setting that already holds a non-``None``
    value is never overwritten.

    Defaults applied:
        * ``dataloader_num_workers``: the ``WORLD_SIZE`` environment
          variable as an int, or ``1`` when it is not set.
        * ``dataloader_prefetch_factor``: ``cfg.batch_size * 2``, but only
          when worker processes are enabled and ``cfg.batch_size`` is set.
        * ``dataloader_pin_memory``: ``True``.

    Args:
        cfg: config object exposing the settings above as attributes, where
            an unset option reads as ``None``.
    """
    # setup sane defaults if left unspecified
    if cfg.dataloader_num_workers is None:
        cfg.dataloader_num_workers = int(os.getenv("WORLD_SIZE", "1"))

    # A prefetch factor is only meaningful with worker processes: the
    # dataloader rejects one when num_workers == 0, so leave it unset in
    # that case. Also skip it when batch_size is unknown, since ``None * 2``
    # would raise a TypeError.
    if (
        cfg.dataloader_prefetch_factor is None
        and cfg.dataloader_num_workers > 0
        and cfg.batch_size is not None
    ):
        cfg.dataloader_prefetch_factor = cfg.batch_size * 2

    if cfg.dataloader_pin_memory is None:
        cfg.dataloader_pin_memory = True
def normalize_config(cfg):
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (