Compare commits
1 Commits
chore/docs
...
20231212-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5bb4a782ce |
@@ -23,7 +23,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
|||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import add_defaults, normalize_config, validate_config
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
@@ -301,6 +301,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
|
add_defaults(cfg)
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,16 @@ def choose_device(cfg):
|
|||||||
cfg.device_map = None
|
cfg.device_map = None
|
||||||
|
|
||||||
|
|
||||||
|
def add_defaults(cfg):
|
||||||
|
# setup sane defaults if left unspecified
|
||||||
|
if cfg.dataloader_num_workers is None:
|
||||||
|
cfg.dataloader_num_workers = int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
if cfg.dataloader_prefetch_factor is None:
|
||||||
|
cfg.dataloader_prefetch_factor = cfg.batch_size * 2
|
||||||
|
if cfg.dataloader_pin_memory is None:
|
||||||
|
cfg.dataloader_pin_memory = True
|
||||||
|
|
||||||
|
|
||||||
def normalize_config(cfg):
|
def normalize_config(cfg):
|
||||||
# setup some derived config / hyperparams
|
# setup some derived config / hyperparams
|
||||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||||
|
|||||||
Reference in New Issue
Block a user