Compare commits

...

3 Commits

Author SHA1 Message Date
Dan Saunders
02efd7e83d quick formatting fix for LoRA optims doc 2025-02-19 14:17:20 +00:00
Tobias
8dfadc2b3c Fix sample packing producing longer sequences than specified by sequence_len (#2332)
* Extend MultiPackBatchSampler test to include shorter sequence length and drop long sequences filter

* Fix get_dataset_lengths for datasets that were previously filtered (e.g., with drop_long_seq_in_dataset)

* Update src/axolotl/utils/samplers/utils.py

Fix get_dataset_lengths for datasets that do not have position_ids or length attributes

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2025-02-19 12:02:35 +07:00
Wing Lian
23a9fcb0a7 make sure chatml dpo dataset loading works (#2333) 2025-02-18 16:08:40 -05:00
5 changed files with 80 additions and 12 deletions

View File

@@ -12,6 +12,7 @@ to leverage operator fusion and tensor re-use in order to improve speed and redu
memory usage during the forward and backward passes of these calculations.
We currently support several common model architectures, including (but not limited to):
- `llama`
- `mistral`
- `qwen2`

View File

@@ -5,12 +5,12 @@ import numpy as np
def get_dataset_lengths(dataset):
    """Return a numpy array with the token length of each example in `dataset`.

    Resolution order (cheapest first):
      1. a precomputed "length" column, if present;
      2. a "position_ids" column — for packed/position-id data the last
         position id + 1 is the sequence length;
      3. fall back to measuring len() of each "input_ids" row.

    Uses `dataset.column_names` / `dataset[...]` (not `dataset.data.column`)
    so it also works for datasets that were previously filtered, e.g. by
    drop_long_seq_in_dataset.
    """
    if "length" in dataset.column_names:
        lengths = np.array(dataset["length"])
    elif "position_ids" in dataset.column_names:
        position_ids = dataset["position_ids"]
        lengths = np.array([x[-1] + 1 for x in position_ids])
    else:
        input_ids = dataset["input_ids"]
        lengths = np.array([len(seq) for seq in input_ids])
    return lengths

View File

@@ -125,6 +125,12 @@ def fixture_llama3_tokenizer():
return tokenizer
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
def fixture_smollm2_tokenizer():
    """Session-scoped tokenizer for HuggingFaceTB/SmolLM2-135M."""
    return AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
def fixture_mistralv03_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(

View File

@@ -0,0 +1,61 @@
"""
Tests for loading DPO preference datasets with chatml formatting
"""
import unittest
import pytest
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_dpo_cfg")
def fixture_cfg():
    """Minimal DPO training config targeting SmolLM2-135M.

    Provides only the fields needed to exercise dataset loading in the
    tests below; individual tests merge their own "datasets" entry in.
    """
    return DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
            "rl": "dpo",
            "learning_rate": 0.000001,
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            # SmolLM2 has no dedicated pad token; reuse EOS for padding
            "special_tokens": {
                "pad_token": "<|endoftext|>",
            },
            "sequence_len": 2048,
        }
    )
class TestDPOChatml:
    """
    Test loading DPO preference datasets with chatml formatting
    """

    def test_default(self, minimal_dpo_cfg):
        """Load a small DPO dataset slice and check chatml-formatted output."""
        cfg = DictDefault(
            {
                "datasets": [
                    {
                        "path": "argilla/distilabel-intel-orca-dpo-pairs",
                        "type": "chatml",
                        # only 1% of the split keeps the test fast
                        "split": "train[:1%]",
                    }
                ]
            }
            | minimal_dpo_cfg
        )
        # test that dpo.load works
        load_dpo("chatml", cfg)
        # now actually load the datasets with the strategy
        train_ds, _ = load_prepare_preference_datasets(cfg)
        # prompt should be chatml-formatted and end ready for the assistant turn
        assert train_ds[0]["prompt"].startswith("<|im_start|>")
        assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
        assert "chosen" in train_ds[0]
        assert "rejected" in train_ds[0]
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -7,6 +7,7 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -18,11 +19,6 @@ def fixture_tokenizer():
return tokenizer
@pytest.fixture(name="max_seq_length")
def fixture_max_seq_length():
return 4096
class TestBatchedSamplerPacking:
"""
Test class for packing streaming dataset sequences
@@ -37,6 +33,7 @@ class TestBatchedSamplerPacking:
(2, 2),
],
)
@pytest.mark.parametrize("max_seq_length", [4096, 512])
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
@@ -62,6 +59,9 @@ class TestBatchedSamplerPacking:
dataset,
)
train_dataset = concatenate_datasets([dataset_wrapper])
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
lengths = get_dataset_lengths(train_dataset)
batch_sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
@@ -90,7 +90,7 @@ class TestBatchedSamplerPacking:
batch_idxs.extend(pack)
for batch in loader:
assert len(batch["input_ids"]) <= batch_size * max_seq_length
assert batch["input_ids"].numel() <= batch_size * max_seq_length
assert batch["input_ids"].shape[1] == max_seq_length
original_idxs = set(range(len(train_dataset)))