Feat: add support for datasets with str saved messages field (#3607)
* feat: support datasets saved in str format * add also str for tools * format * fix: address comments + add unit test * format
This commit is contained in:
@@ -487,3 +487,70 @@ class TestDatasetPreparation:
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_dataset_with_str_json_data(self, tokenizer):
|
||||
"""
|
||||
Test loading datasets where data is stored as str JSON instead of list of dicts.
|
||||
see: https://github.com/axolotl-ai-cloud/axolotl/pull/3607 for more details.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
import json
|
||||
|
||||
str_json_ds = Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"messages": json.dumps(
|
||||
[
|
||||
{"role": "user", "content": "Hello how are you?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I am doing good thanks",
|
||||
},
|
||||
]
|
||||
)
|
||||
},
|
||||
{
|
||||
"messages": json.dumps(
|
||||
[
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "2+2 equals 4."},
|
||||
]
|
||||
)
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
tmp_ds_path = Path(tmp_dir) / "str_json_dataset.parquet"
|
||||
str_json_ds.to_parquet(tmp_ds_path)
|
||||
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 512,
|
||||
"datasets": [
|
||||
{
|
||||
"path": str(tmp_ds_path),
|
||||
"name": "test_str_json",
|
||||
"type": "chat_template",
|
||||
"field_messages": "messages",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
},
|
||||
],
|
||||
"dataset_num_proc": 4,
|
||||
}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
|
||||
assert len(dataset[0]["input_ids"]) > 0
|
||||
|
||||
Reference in New Issue
Block a user