various bugfixes

This commit is contained in:
Wing Lian
2023-04-19 17:04:34 -04:00
parent 2624bc2f11
commit 94f5e415a3
6 changed files with 63 additions and 10 deletions

View File

@@ -159,7 +159,7 @@ def train(
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.world_size != 1
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.gradient_accumulation_steps = (