remove skipped test (#2002)
* remove skipped test * use mean_resizing_embeddings with qlora and added tokens * use </s> as pad_token to prevent resize of embeddings * make sure local hub test saves to a tmp dir * use Path so concatenation works * make sure to use tmp_ds_path for data files
This commit is contained in:
@@ -273,7 +273,6 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skip("disabled due to upstream issue")
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -282,6 +281,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
|
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
|
||||||
"tokenizer_type": "AutoTokenizer",
|
"tokenizer_type": "AutoTokenizer",
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
|
"mean_resizing_embeddings": True,
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
@@ -297,7 +297,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"pad_token": "<|end_of_text|>",
|
"pad_token": "</s>",
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -367,43 +367,44 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
def test_load_local_hub_with_revision(self):
|
def test_load_local_hub_with_revision(self):
|
||||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
|
with tempfile.TemporaryDirectory() as tmp_dir2:
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
|
||||||
snapshot_download(
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
snapshot_download(
|
||||||
repo_type="dataset",
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
local_dir=tmp_ds_path,
|
repo_type="dataset",
|
||||||
revision="d05c1cb",
|
local_dir=tmp_ds_path,
|
||||||
)
|
revision="d05c1cb",
|
||||||
|
)
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
"ds_type": "parquet",
|
"ds_type": "parquet",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
"data_files": [
|
"data_files": [
|
||||||
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
|
f"{tmp_ds_path}/alpaca_2000.parquet",
|
||||||
],
|
],
|
||||||
"revision": "d05c1cb",
|
"revision": "d05c1cb",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
self.tokenizer, cfg, prepared_path
|
self.tokenizer, cfg, prepared_path
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user