Merge pull request #134 from OpenAccess-AI-Collective/gas-batch-fix
fix batch size calculation
This commit is contained in:
@@ -163,15 +163,17 @@ def train(
|
|||||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||||
cfg.batch_size // cfg.micro_batch_size
|
cfg.batch_size // cfg.micro_batch_size
|
||||||
)
|
)
|
||||||
|
cfg.batch_size = (
|
||||||
|
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||||
|
)
|
||||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
choose_device(cfg)
|
choose_device(cfg)
|
||||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||||
if cfg.ddp:
|
if cfg.ddp:
|
||||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||||
cfg.gradient_accumulation_steps = (
|
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||||
cfg.gradient_accumulation_steps // cfg.world_size
|
|
||||||
)
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
if cfg.device == "mps":
|
if cfg.device == "mps":
|
||||||
cfg.load_in_8bit = False
|
cfg.load_in_8bit = False
|
||||||
|
|||||||
@@ -233,6 +233,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
else:
|
else:
|
||||||
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
|
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
|
||||||
|
raise ValueError(f"unhandled prompt tokenization strategy: {d.type}")
|
||||||
logging.info("tokenizing, merging, and shuffling master dataset")
|
logging.info("tokenizing, merging, and shuffling master dataset")
|
||||||
|
|
||||||
samples: List[int] = []
|
samples: List[int] = []
|
||||||
|
|||||||
Reference in New Issue
Block a user