Compare commits
5 Commits
v0.7.0
...
lora-kerne
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
02efd7e83d | ||
|
|
8dfadc2b3c | ||
|
|
23a9fcb0a7 | ||
|
|
c3d4f6e295 | ||
|
|
7fa690fac8 |
@@ -12,6 +12,7 @@ to leverage operator fusion and tensor re-use in order to improve speed and redu
|
|||||||
memory usage during the forward and backward passes of these calculations.
|
memory usage during the forward and backward passes of these calculations.
|
||||||
|
|
||||||
We currently support several common model architectures, including (but not limited to):
|
We currently support several common model architectures, including (but not limited to):
|
||||||
|
|
||||||
- `llama`
|
- `llama`
|
||||||
- `mistral`
|
- `mistral`
|
||||||
- `qwen2`
|
- `qwen2`
|
||||||
@@ -82,7 +83,7 @@ lora_o_kernel: true
|
|||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
||||||
- AMD can be used with experimental Triton support by setting the environment variable `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`
|
- Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
|
||||||
- Targeted LoRA adapters cannot use Dropout
|
- Targeted LoRA adapters cannot use Dropout
|
||||||
- This may limit model expressivity / cause overfitting
|
- This may limit model expressivity / cause overfitting
|
||||||
- Targeted LoRA adapters cannot have bias terms
|
- Targeted LoRA adapters cannot have bias terms
|
||||||
|
|||||||
@@ -4,4 +4,4 @@ import pkgutil
|
|||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||||
|
|
||||||
__version__ = "0.7.0"
|
__version__ = "0.8.0.dev0"
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
def get_dataset_lengths(dataset):
|
def get_dataset_lengths(dataset):
|
||||||
if "length" in dataset.data.column_names:
|
if "length" in dataset.column_names:
|
||||||
lengths = np.array(dataset.data.column("length"))
|
lengths = np.array(dataset["length"])
|
||||||
elif "position_ids" in dataset.data.column_names:
|
elif "position_ids" in dataset.column_names:
|
||||||
position_ids = dataset.data.column("position_ids")
|
position_ids = dataset["position_ids"]
|
||||||
lengths = np.array([x[-1] + 1 for x in position_ids])
|
lengths = np.array([x[-1] + 1 for x in position_ids])
|
||||||
else:
|
else:
|
||||||
input_ids = dataset.data.column("input_ids")
|
input_ids = dataset["input_ids"]
|
||||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
lengths = np.array([len(seq) for seq in input_ids])
|
||||||
return lengths
|
return lengths
|
||||||
|
|||||||
@@ -125,6 +125,12 @@ def fixture_llama3_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
|
||||||
|
def fixture_smollm2_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
||||||
def fixture_mistralv03_tokenizer():
|
def fixture_mistralv03_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
|||||||
61
tests/prompt_strategies/test_dpo_chatml.py
Normal file
61
tests/prompt_strategies/test_dpo_chatml.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""
|
||||||
|
Tests for loading DPO preference datasets with chatml formatting
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||||
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="minimal_dpo_cfg")
|
||||||
|
def fixture_cfg():
|
||||||
|
return DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"rl": "dpo",
|
||||||
|
"learning_rate": 0.000001,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"sequence_len": 2048,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDPOChatml:
|
||||||
|
"""
|
||||||
|
Test loading DPO preference datasets with chatml formatting
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_default(self, minimal_dpo_cfg):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "argilla/distilabel-intel-orca-dpo-pairs",
|
||||||
|
"type": "chatml",
|
||||||
|
"split": "train[:1%]",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
| minimal_dpo_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
# test that dpo.load works
|
||||||
|
load_dpo("chatml", cfg)
|
||||||
|
# now actually load the datasets with the strategy
|
||||||
|
train_ds, _ = load_prepare_preference_datasets(cfg)
|
||||||
|
assert train_ds[0]["prompt"].startswith("<|im_start|>")
|
||||||
|
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
|
||||||
|
assert "chosen" in train_ds[0]
|
||||||
|
assert "rejected" in train_ds[0]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer
|
|||||||
from axolotl.datasets import TokenizedPromptDataset
|
from axolotl.datasets import TokenizedPromptDataset
|
||||||
from axolotl.prompt_strategies.completion import load
|
from axolotl.prompt_strategies.completion import load
|
||||||
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
|
from axolotl.utils.data.utils import drop_long_seq_in_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
@@ -18,11 +19,6 @@ def fixture_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="max_seq_length")
|
|
||||||
def fixture_max_seq_length():
|
|
||||||
return 4096
|
|
||||||
|
|
||||||
|
|
||||||
class TestBatchedSamplerPacking:
|
class TestBatchedSamplerPacking:
|
||||||
"""
|
"""
|
||||||
Test class for packing streaming dataset sequences
|
Test class for packing streaming dataset sequences
|
||||||
@@ -37,6 +33,7 @@ class TestBatchedSamplerPacking:
|
|||||||
(2, 2),
|
(2, 2),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@pytest.mark.parametrize("max_seq_length", [4096, 512])
|
||||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
||||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||||
|
|
||||||
@@ -62,6 +59,9 @@ class TestBatchedSamplerPacking:
|
|||||||
dataset,
|
dataset,
|
||||||
)
|
)
|
||||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||||
|
|
||||||
|
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
|
||||||
|
|
||||||
lengths = get_dataset_lengths(train_dataset)
|
lengths = get_dataset_lengths(train_dataset)
|
||||||
batch_sampler = MultipackBatchSampler(
|
batch_sampler = MultipackBatchSampler(
|
||||||
sampler=RandomSampler(train_dataset),
|
sampler=RandomSampler(train_dataset),
|
||||||
@@ -90,7 +90,7 @@ class TestBatchedSamplerPacking:
|
|||||||
batch_idxs.extend(pack)
|
batch_idxs.extend(pack)
|
||||||
|
|
||||||
for batch in loader:
|
for batch in loader:
|
||||||
assert len(batch["input_ids"]) <= batch_size * max_seq_length
|
assert batch["input_ids"].numel() <= batch_size * max_seq_length
|
||||||
assert batch["input_ids"].shape[1] == max_seq_length
|
assert batch["input_ids"].shape[1] == max_seq_length
|
||||||
|
|
||||||
original_idxs = set(range(len(train_dataset)))
|
original_idxs = set(range(len(train_dataset)))
|
||||||
|
|||||||
Reference in New Issue
Block a user