diff --git a/src/axolotl/utils/samplers/utils.py b/src/axolotl/utils/samplers/utils.py
index 4e41c9b44..8f400c98d 100755
--- a/src/axolotl/utils/samplers/utils.py
+++ b/src/axolotl/utils/samplers/utils.py
@@ -5,12 +5,12 @@ import numpy as np
 
 
 def get_dataset_lengths(dataset):
-    if "length" in dataset.data.column_names:
-        lengths = np.array(dataset.data.column("length"))
-    elif "position_ids" in dataset.data.column_names:
-        position_ids = dataset.data.column("position_ids")
+    if "length" in dataset.column_names:
+        lengths = np.array(dataset["length"])
+    elif "position_ids" in dataset.column_names:
+        position_ids = dataset["position_ids"]
         lengths = np.array([x[-1] + 1 for x in position_ids])
     else:
-        input_ids = dataset.data.column("input_ids")
-        lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
+        input_ids = dataset["input_ids"]
+        lengths = np.array([len(seq) for seq in input_ids])
     return lengths
diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py
index ceff11df9..b52320e2a 100644
--- a/tests/test_packed_batch_sampler.py
+++ b/tests/test_packed_batch_sampler.py
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer
 from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies.completion import load
 from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.data.utils import drop_long_seq_in_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
@@ -18,11 +19,6 @@ def fixture_tokenizer():
     return tokenizer
 
 
-@pytest.fixture(name="max_seq_length")
-def fixture_max_seq_length():
-    return 4096
-
-
 class TestBatchedSamplerPacking:
     """
     Test class for packing streaming dataset sequences
@@ -37,6 +33,7 @@ class TestBatchedSamplerPacking:
             (2, 2),
         ],
     )
+    @pytest.mark.parametrize("max_seq_length", [4096, 512])
     def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
         import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
 
@@ -62,6 +59,9 @@ class TestBatchedSamplerPacking:
             dataset,
         )
         train_dataset = concatenate_datasets([dataset_wrapper])
+
+        train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
+
         lengths = get_dataset_lengths(train_dataset)
         batch_sampler = MultipackBatchSampler(
             sampler=RandomSampler(train_dataset),
@@ -90,7 +90,7 @@ class TestBatchedSamplerPacking:
             batch_idxs.extend(pack)
 
         for batch in loader:
-            assert len(batch["input_ids"]) <= batch_size * max_seq_length
+            assert batch["input_ids"].numel() <= batch_size * max_seq_length
             assert batch["input_ids"].shape[1] == max_seq_length
 
         original_idxs = set(range(len(train_dataset)))