Estimate total tokens; fix field loop
This commit is contained in:
@@ -195,7 +195,7 @@ class MultipackDistributedDataloader:
|
|||||||
for feature, array in concatenated.items()
|
for feature, array in concatenated.items()
|
||||||
}
|
}
|
||||||
chunked_data.append(chunk)
|
chunked_data.append(chunk)
|
||||||
yield self.collate_fn(chunked_data)
|
yield self.collate_fn(chunked_data)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
batches, _ = self.generate_batches()
|
batches, _ = self.generate_batches()
|
||||||
|
|||||||
@@ -155,30 +155,43 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
train_dataset = train_dataset.map(add_position_ids)
|
train_dataset = train_dataset.map(add_position_ids)
|
||||||
eval_dataset = eval_dataset.map(add_position_ids)
|
eval_dataset = eval_dataset.map(add_position_ids)
|
||||||
sampler = DistributedSampler(
|
if cfg.sample_packing_eff_est:
|
||||||
train_dataset,
|
total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
|
||||||
num_replicas=1,
|
total_num_steps = math.ceil(
|
||||||
rank=0,
|
total_num_tokens
|
||||||
seed=cfg.seed or 42,
|
/ cfg.sample_packing_eff_est
|
||||||
)
|
/ 2048
|
||||||
data_loader = MultipackDistributedDataloader(
|
* cfg.num_epochs
|
||||||
train_dataset,
|
/ cfg.batch_size
|
||||||
batch_size=cfg.micro_batch_size,
|
)
|
||||||
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
else:
|
||||||
collate_fn=DataCollatorForSeq2Seq(
|
sampler = DistributedSampler(
|
||||||
tokenizer,
|
train_dataset,
|
||||||
return_tensors="pt",
|
num_replicas=1,
|
||||||
padding="longest",
|
rank=0,
|
||||||
),
|
seed=cfg.seed or 42,
|
||||||
sampler=sampler,
|
)
|
||||||
)
|
data_loader = MultipackDistributedDataloader(
|
||||||
data_loader_len = len(data_loader)
|
train_dataset,
|
||||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
batch_size=cfg.micro_batch_size,
|
||||||
total_num_steps = int(
|
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
||||||
math.ceil(
|
collate_fn=DataCollatorForSeq2Seq(
|
||||||
data_loader_len * cfg.micro_batch_size * cfg.num_epochs / cfg.batch_size
|
tokenizer,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="longest",
|
||||||
|
),
|
||||||
|
sampler=sampler,
|
||||||
|
)
|
||||||
|
data_loader_len = len(data_loader)
|
||||||
|
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||||
|
total_num_steps = int(
|
||||||
|
math.ceil(
|
||||||
|
data_loader_len
|
||||||
|
* cfg.micro_batch_size
|
||||||
|
* cfg.num_epochs
|
||||||
|
/ cfg.batch_size
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user