Fix total_num_steps (#1566)
* Fix `total_num_steps` * Fix total_num_steps * lint
This commit is contained in:
@@ -330,7 +330,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
/ cfg.sample_packing_eff_est
|
/ cfg.sample_packing_eff_est
|
||||||
/ cfg.sequence_len
|
/ cfg.sequence_len
|
||||||
// cfg.batch_size
|
// cfg.batch_size
|
||||||
// int(os.environ.get("WORLD_SIZE", 1))
|
|
||||||
)
|
)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
@@ -359,18 +358,14 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
train_dataset.remove_columns(["length"]),
|
train_dataset.remove_columns(["length"]),
|
||||||
batch_sampler=sampler,
|
batch_sampler=sampler,
|
||||||
)
|
)
|
||||||
data_loader_len = len(data_loader) // cfg.batch_size
|
data_loader_len = len(data_loader) // (
|
||||||
|
cfg.world_size * cfg.gradient_accumulation_steps
|
||||||
|
)
|
||||||
actual_eff = sampler.efficiency()
|
actual_eff = sampler.efficiency()
|
||||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||||
# on the agreed on value for sample_packing_eff_est
|
# on the agreed on value for sample_packing_eff_est
|
||||||
total_num_steps = int(
|
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
||||||
math.floor(
|
|
||||||
data_loader_len
|
|
||||||
* cfg.num_epochs
|
|
||||||
/ int(os.environ.get("WORLD_SIZE", 1))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||||
@@ -391,12 +386,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
len(train_dataset)
|
|
||||||
* cfg.num_epochs
|
|
||||||
/ int(os.environ.get("WORLD_SIZE", 1))
|
|
||||||
/ cfg.batch_size
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
||||||
return total_num_steps
|
return total_num_steps
|
||||||
|
|||||||
Reference in New Issue
Block a user