From 66774011c4082905cf1ccf5e8748993a2b3b6724 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 18 Jul 2023 11:30:07 -0400 Subject: [PATCH] est total tokens, fix field loop --- src/axolotl/utils/dataloader.py | 2 +- src/axolotl/utils/trainer.py | 59 ++++++++++++++++++++------------- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index 2a95749d2..a9b1f5e89 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -195,7 +195,7 @@ class MultipackDistributedDataloader: for feature, array in concatenated.items() } chunked_data.append(chunk) - yield self.collate_fn(chunked_data) + yield self.collate_fn(chunked_data) def __len__(self): batches, _ = self.generate_batches() diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 85a32e914..4a2fc1e6d 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -155,30 +155,43 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): if cfg.sample_packing: train_dataset = train_dataset.map(add_position_ids) eval_dataset = eval_dataset.map(add_position_ids) - sampler = DistributedSampler( - train_dataset, - num_replicas=1, - rank=0, - seed=cfg.seed or 42, - ) - data_loader = MultipackDistributedDataloader( - train_dataset, - batch_size=cfg.micro_batch_size, - seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len, - collate_fn=DataCollatorForSeq2Seq( - tokenizer, - return_tensors="pt", - padding="longest", - ), - sampler=sampler, - ) - data_loader_len = len(data_loader) - LOG.info(f"data_loader_len: {data_loader_len}") - total_num_steps = int( - math.ceil( - data_loader_len * cfg.micro_batch_size * cfg.num_epochs / cfg.batch_size + if cfg.sample_packing_eff_est: + total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset) + total_num_steps = math.ceil( + total_num_tokens + / cfg.sample_packing_eff_est + / 2048 + * cfg.num_epochs + / cfg.batch_size + ) + else: + sampler = DistributedSampler( + train_dataset, + num_replicas=1, + rank=0, + seed=cfg.seed or 42, + ) + data_loader = MultipackDistributedDataloader( + train_dataset, + batch_size=cfg.micro_batch_size, + seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len, + collate_fn=DataCollatorForSeq2Seq( + tokenizer, + return_tensors="pt", + padding="longest", + ), + sampler=sampler, + ) + data_loader_len = len(data_loader) + LOG.info(f"data_loader_len: {data_loader_len}") + total_num_steps = int( + math.ceil( + data_loader_len + * cfg.micro_batch_size + * cfg.num_epochs + / cfg.batch_size + ) ) - ) else: total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)