Compare commits
1 Commits
llama-flas
...
20231212-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5bb4a782ce |
@@ -23,7 +23,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.config import add_defaults, normalize_config, validate_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
@@ -301,6 +301,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
add_defaults(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
return cfg
|
||||
|
||||
|
||||
@@ -41,6 +41,16 @@ def choose_device(cfg):
|
||||
cfg.device_map = None
|
||||
|
||||
|
||||
def add_defaults(cfg):
|
||||
# setup sane defaults if left unspecified
|
||||
if cfg.dataloader_num_workers is None:
|
||||
cfg.dataloader_num_workers = int(os.getenv("WORLD_SIZE", "1"))
|
||||
if cfg.dataloader_prefetch_factor is None:
|
||||
cfg.dataloader_prefetch_factor = cfg.batch_size * 2
|
||||
if cfg.dataloader_pin_memory is None:
|
||||
cfg.dataloader_pin_memory = True
|
||||
|
||||
|
||||
def normalize_config(cfg):
|
||||
# setup some derived config / hyperparams
|
||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||
|
||||
Reference in New Issue
Block a user