From 90dfcd8c0315a7b8d52578e79f77e378ea395486 Mon Sep 17 00:00:00 2001 From: Sung Ching Liu <22844540+bursteratom@users.noreply.github.com> Date: Wed, 19 Feb 2025 21:13:25 -0500 Subject: [PATCH] =?UTF-8?q?Revert=20"Fix=20sample=20packing=20producing=20?= =?UTF-8?q?longer=20sequences=20than=20specified=20by=20`sequ=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 8dfadc2b3c7f85e4208b74eb93c32f440c7e25b4. --- src/axolotl/utils/samplers/utils.py | 12 ++++++------ tests/test_packed_batch_sampler.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py index 8f400c98d..4e41c9b44 100755 --- a/src/axolotl/utils/samplers/utils.py +++ b/src/axolotl/utils/samplers/utils.py @@ -5,12 +5,12 @@ import numpy as np def get_dataset_lengths(dataset): - if "length" in dataset.column_names: - lengths = np.array(dataset["length"]) - elif "position_ids" in dataset.column_names: - position_ids = dataset["position_ids"] + if "length" in dataset.data.column_names: + lengths = np.array(dataset.data.column("length")) + elif "position_ids" in dataset.data.column_names: + position_ids = dataset.data.column("position_ids") lengths = np.array([x[-1] + 1 for x in position_ids]) else: - input_ids = dataset["input_ids"] - lengths = np.array([len(seq) for seq in input_ids]) + input_ids = dataset.data.column("input_ids") + lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) return lengths diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py index b52320e2a..ceff11df9 100644 --- a/tests/test_packed_batch_sampler.py +++ b/tests/test_packed_batch_sampler.py @@ -7,7 +7,6 @@ from transformers import AutoTokenizer from axolotl.datasets import TokenizedPromptDataset from axolotl.prompt_strategies.completion import load from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq -from axolotl.utils.data.utils import drop_long_seq_in_dataset from axolotl.utils.dict import DictDefault from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths @@ -19,6 +18,11 @@ def fixture_tokenizer(): return tokenizer +@pytest.fixture(name="max_seq_length") +def fixture_max_seq_length(): + return 4096 + + class TestBatchedSamplerPacking: """ Test class for packing streaming dataset sequences @@ -33,7 +37,6 @@ class TestBatchedSamplerPacking: (2, 2), ], ) - @pytest.mark.parametrize("max_seq_length", [4096, 512]) def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length): import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 @@ -59,9 +62,6 @@ class TestBatchedSamplerPacking: dataset, ) train_dataset = concatenate_datasets([dataset_wrapper]) - - train_dataset = drop_long_seq_in_dataset(train_dataset, cfg) - lengths = get_dataset_lengths(train_dataset) batch_sampler = MultipackBatchSampler( sampler=RandomSampler(train_dataset), @@ -90,7 +90,7 @@ class TestBatchedSamplerPacking: batch_idxs.extend(pack) for batch in loader: - assert batch["input_ids"].numel() <= batch_size * max_seq_length + assert len(batch["input_ids"]) <= batch_size * max_seq_length assert batch["input_ids"].shape[1] == max_seq_length original_idxs = set(range(len(train_dataset)))