est total tokens, fix field loop
This commit is contained in:
@@ -195,7 +195,7 @@ class MultipackDistributedDataloader:
|
||||
for feature, array in concatenated.items()
|
||||
}
|
||||
chunked_data.append(chunk)
|
||||
yield self.collate_fn(chunked_data)
|
||||
yield self.collate_fn(chunked_data)
|
||||
|
||||
def __len__(self):
|
||||
batches, _ = self.generate_batches()
|
||||
|
||||
@@ -155,30 +155,43 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
if cfg.sample_packing:
|
||||
train_dataset = train_dataset.map(add_position_ids)
|
||||
eval_dataset = eval_dataset.map(add_position_ids)
|
||||
sampler = DistributedSampler(
|
||||
train_dataset,
|
||||
num_replicas=1,
|
||||
rank=0,
|
||||
seed=cfg.seed or 42,
|
||||
)
|
||||
data_loader = MultipackDistributedDataloader(
|
||||
train_dataset,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
||||
collate_fn=DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
),
|
||||
sampler=sampler,
|
||||
)
|
||||
data_loader_len = len(data_loader)
|
||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
total_num_steps = int(
|
||||
math.ceil(
|
||||
data_loader_len * cfg.micro_batch_size * cfg.num_epochs / cfg.batch_size
|
||||
if cfg.sample_packing_eff_est:
|
||||
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
|
||||
total_num_steps = math.ceil(
|
||||
total_num_tokens
|
||||
/ cfg.sample_packing_eff_est
|
||||
/ 2048
|
||||
* cfg.num_epochs
|
||||
/ cfg.batch_size
|
||||
)
|
||||
else:
|
||||
sampler = DistributedSampler(
|
||||
train_dataset,
|
||||
num_replicas=1,
|
||||
rank=0,
|
||||
seed=cfg.seed or 42,
|
||||
)
|
||||
data_loader = MultipackDistributedDataloader(
|
||||
train_dataset,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
||||
collate_fn=DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
),
|
||||
sampler=sampler,
|
||||
)
|
||||
data_loader_len = len(data_loader)
|
||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
total_num_steps = int(
|
||||
math.ceil(
|
||||
data_loader_len
|
||||
* cfg.micro_batch_size
|
||||
* cfg.num_epochs
|
||||
/ cfg.batch_size
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
|
||||
Reference in New Issue
Block a user