Estimate total tokens; fix field loop

This commit is contained in:
Wing Lian
2023-07-18 11:30:07 -04:00
parent 41d4992029
commit 66774011c4
2 changed files with 37 additions and 24 deletions

View File

@@ -195,7 +195,7 @@ class MultipackDistributedDataloader:
for feature, array in concatenated.items()
}
chunked_data.append(chunk)
yield self.collate_fn(chunked_data)
yield self.collate_fn(chunked_data)
def __len__(self):
batches, _ = self.generate_batches()

View File

@@ -155,30 +155,43 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids)
eval_dataset = eval_dataset.map(add_position_ids)
sampler = DistributedSampler(
train_dataset,
num_replicas=1,
rank=0,
seed=cfg.seed or 42,
)
data_loader = MultipackDistributedDataloader(
train_dataset,
batch_size=cfg.micro_batch_size,
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
collate_fn=DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding="longest",
),
sampler=sampler,
)
data_loader_len = len(data_loader)
LOG.info(f"data_loader_len: {data_loader_len}")
total_num_steps = int(
math.ceil(
data_loader_len * cfg.micro_batch_size * cfg.num_epochs / cfg.batch_size
if cfg.sample_packing_eff_est:
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
total_num_steps = math.ceil(
total_num_tokens
/ cfg.sample_packing_eff_est
/ 2048
* cfg.num_epochs
/ cfg.batch_size
)
else:
sampler = DistributedSampler(
train_dataset,
num_replicas=1,
rank=0,
seed=cfg.seed or 42,
)
data_loader = MultipackDistributedDataloader(
train_dataset,
batch_size=cfg.micro_batch_size,
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
collate_fn=DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding="longest",
),
sampler=sampler,
)
data_loader_len = len(data_loader)
LOG.info(f"data_loader_len: {data_loader_len}")
total_num_steps = int(
math.ceil(
data_loader_len
* cfg.micro_batch_size
* cfg.num_epochs
/ cfg.batch_size
)
)
)
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)