Feat: add support for datasets with str saved messages field (#3607)
* feat: support datasets saved in str format * add also str for tools * format * fix: address comments + add unit test * format
This commit is contained in:
@@ -394,8 +394,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
    """Return True when `prompt` holds a batch of rows rather than a single row.

    A batched prompt maps every column name to a list with one entry per
    row. Each entry of the messages column may be either a list of turns or
    a `str` (a JSON-encoded conversation).

    A single row whose messages field is a `str` is NOT batched. It must be
    rejected explicitly, because iterating a `str` yields `str` characters
    and would otherwise satisfy a per-entry `str` check.

    Args:
        prompt: Mapping of column name to value(s).

    Returns:
        True if every column is a list and every entry of the messages
        column is a `str` or a `list`; False otherwise, including when the
        messages column is missing.
    """
    try:
        messages_column = prompt[self.prompter.field_messages]
    except KeyError:
        return False

    # Every column of a batch is a list of per-row values.
    if not all(isinstance(column, list) for column in prompt.values()):
        return False

    # Per-row conversations are lists of turns or JSON strings.
    return all(isinstance(row, (str, list)) for row in messages_column)
|
||||
@@ -1004,6 +1004,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if tools is None:
|
||||
return None
|
||||
|
||||
# Some datasets have tools set to str
|
||||
if isinstance(tools, str):
|
||||
try:
|
||||
tools = json.loads(tools)
|
||||
except json.JSONDecodeError as e:
|
||||
LOG.error(f"Error parsing tool parameters as JSON. Error: {e}")
|
||||
raise
|
||||
if isinstance(tools, list):
|
||||
# Process each tool to handle JSON string parameters
|
||||
for tool in tools:
|
||||
@@ -1034,6 +1041,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if messages is None:
|
||||
raise ValueError("Messages is null. Please check `field_messages`.")
|
||||
|
||||
# Some datasets store the whole conversation as a JSON-encoded string;
# decode it and validate that the result is a list of dict turns.
if isinstance(messages, str):
    try:
        messages = json.loads(messages)
    except json.JSONDecodeError as e:
        LOG.error(f"Error parsing messages as JSON. Error: {e}")
        raise

    # Report the type of the decoded container itself: the per-turn loop
    # variable `message` is not bound yet at this point, so referencing it
    # here would mask the assertion with a name error.
    assert isinstance(messages, list), (
        f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(messages)}"
    )

    # Extra check here to make sure decoded json is a list of dicts.
    for i, message in enumerate(messages):
        assert isinstance(message, dict), (
            f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}"
        )
|
||||
|
||||
if isinstance(messages, list):
|
||||
return messages
|
||||
|
||||
|
||||
@@ -487,3 +487,70 @@ class TestDatasetPreparation:
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
@enable_hf_offline
def test_load_dataset_with_str_json_data(self, tokenizer):
    """
    Check that a chat dataset whose `messages` column is a JSON-encoded
    string (instead of a list of dicts) loads and tokenizes.

    see: https://github.com/axolotl-ai-cloud/axolotl/pull/3607 for more details.
    """
    import json

    # Two short conversations, each serialized into a single JSON string.
    conversations = [
        [
            {"role": "user", "content": "Hello how are you?"},
            {"role": "assistant", "content": "I am doing good thanks"},
        ],
        [
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "2+2 equals 4."},
        ],
    ]
    rows = [{"messages": json.dumps(turns)} for turns in conversations]

    with tempfile.TemporaryDirectory() as tmp_dir:
        workdir = Path(tmp_dir)

        # Write the string-encoded dataset to parquet so it is loaded from disk.
        tmp_ds_path = workdir / "str_json_dataset.parquet"
        Dataset.from_list(rows).to_parquet(tmp_ds_path)

        prepared_path = workdir / "prepared"
        dataset_entry = {
            "path": str(tmp_ds_path),
            "name": "test_str_json",
            "type": "chat_template",
            "field_messages": "messages",
            "message_field_role": "role",
            "message_field_content": "content",
        }
        cfg = DictDefault(
            {
                "tokenizer_config": "huggyllama/llama-7b",
                "sequence_len": 512,
                "datasets": [dataset_entry],
                "dataset_num_proc": 4,
            }
        )

        # Redirect the prepared-dataset cache into the temporary directory.
        with patch(
            "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
        ):
            dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

        # Both rows survive and carry the tokenized columns.
        assert len(dataset) == 2
        assert "input_ids" in dataset.features
        assert "attention_mask" in dataset.features
        assert "labels" in dataset.features

        # The first conversation produced at least one token.
        assert len(dataset[0]["input_ids"]) > 0
|
||||
|
||||
Reference in New Issue
Block a user