Merge pull request #134 from OpenAccess-AI-Collective/gas-batch-fix

fix batch size calculation
This commit is contained in:
Wing Lian
2023-05-31 14:24:48 -04:00
committed by GitHub
2 changed files with 6 additions and 3 deletions

View File

@@ -163,15 +163,17 @@ def train(
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or ( cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size cfg.batch_size // cfg.micro_batch_size
) )
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg) choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp: if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))} cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.gradient_accumulation_steps = ( cfg.batch_size = cfg.batch_size * cfg.world_size
cfg.gradient_accumulation_steps // cfg.world_size
)
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)
if cfg.device == "mps": if cfg.device == "mps":
cfg.load_in_8bit = False cfg.load_in_8bit = False

View File

@@ -233,6 +233,7 @@ def load_tokenized_prepared_datasets(
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
else: else:
logging.error(f"unhandled prompt tokenization strategy: {d.type}") logging.error(f"unhandled prompt tokenization strategy: {d.type}")
raise ValueError(f"unhandled prompt tokenization strategy: {d.type}")
logging.info("tokenizing, merging, and shuffling master dataset") logging.info("tokenizing, merging, and shuffling master dataset")
samples: List[int] = [] samples: List[int] = []