From c06a6be915720cf0c4ba41097c1a94ccc35aae2d Mon Sep 17 00:00:00 2001 From: Sunny Date: Tue, 14 Jan 2025 00:22:05 -0500 Subject: [PATCH] flex_attn sample packing WIP --- src/axolotl/utils/samplers/multipack.py | 11 ++++++++--- src/axolotl/utils/trainer.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index db14a6819..74ce1265b 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -160,19 +160,24 @@ class MultipackBatchSampler(BatchSampler): for i in range(0, len(batches), self.batch_size) ] + seq_lens = [ + [[lengths[idx] for idx in sub_batch] for sub_batch in batch] + for batch in batches + ] + # statistics if set_stats: self.eff_total_used += total_used self.eff_total_slots += total_slots - return batches + return batches, seq_lens def __iter__(self): - batches = self.generate_batches(set_stats=True) + batches, _ = self.generate_batches(set_stats=True) return iter(batches) def num_batches(self): - batches = self.generate_batches(set_stats=True) + batches, _ = self.generate_batches(set_stats=True) return len(batches) def efficiency(self): diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 34b505ff1..de8fff625 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -96,7 +96,7 @@ def disable_datasets_caching(): def add_position_ids(sample): sample_len = len(sample["input_ids"]) - sample["position_ids"] = torch.arange(len(sample["input_ids"])) + sample["position_ids"] = torch.arange(sample_len) sample["length"] = sample_len return sample