diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index a943a1448..be6a38800 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -394,8 +394,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
     def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
         try:
-            return all(isinstance(v, list) for v in prompt.values()) and all(
-                isinstance(v, list) for v in prompt[self.prompter.field_messages]
+            return all(isinstance(v, (str, list)) for v in prompt.values()) and all(
+                isinstance(v, (str, list)) for v in prompt[self.prompter.field_messages]
             )
         except KeyError:
             return False
@@ -1004,6 +1004,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
 
         if tools is None:
             return None
+        # Some datasets have tools set to str
+        if isinstance(tools, str):
+            try:
+                tools = json.loads(tools)
+            except json.JSONDecodeError as e:
+                LOG.error(f"Error parsing tool parameters as JSON. Error: {e}")
+                raise
         if isinstance(tools, list):
             # Process each tool to handle JSON string parameters
             for tool in tools:
@@ -1034,6 +1041,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
 
         if messages is None:
             raise ValueError("Messages is null. Please check `field_messages`.")
+        if isinstance(messages, str):
+            try:
+                messages = json.loads(messages)
+            except json.JSONDecodeError as e:
+                LOG.error(f"Error parsing messages as JSON. Error: {e}")
+                raise
+            assert isinstance(messages, list), (
+                f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(messages)}"
+            )
+
+            # Extra check here to make sure decoded JSON is a list of dicts.
+            for i, message in enumerate(messages):
+                assert isinstance(message, dict), (
+                    f"For SFT datasets that are stored in `str` format, each turn must be saved as a dictionary, got {type(message)} for turn {i}"
+                )
+
         if isinstance(messages, list):
             return messages
 
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 3b24ad580..bdb795e13 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -487,3 +487,70 @@ class TestDatasetPreparation:
         assert "attention_mask" in dataset.features
         assert "labels" in dataset.features
         shutil.rmtree(tmp_ds_path)
+
+    @enable_hf_offline
+    def test_load_dataset_with_str_json_data(self, tokenizer):
+        """
+        Test loading datasets where data is stored as str JSON instead of list of dicts.
+        see: https://github.com/axolotl-ai-cloud/axolotl/pull/3607 for more details.
+        """
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            import json
+
+            str_json_ds = Dataset.from_list(
+                [
+                    {
+                        "messages": json.dumps(
+                            [
+                                {"role": "user", "content": "Hello how are you?"},
+                                {
+                                    "role": "assistant",
+                                    "content": "I am doing good thanks",
+                                },
+                            ]
+                        )
+                    },
+                    {
+                        "messages": json.dumps(
+                            [
+                                {"role": "user", "content": "What is 2+2?"},
+                                {"role": "assistant", "content": "2+2 equals 4."},
+                            ]
+                        )
+                    },
+                ]
+            )
+
+            tmp_ds_path = Path(tmp_dir) / "str_json_dataset.parquet"
+            str_json_ds.to_parquet(tmp_ds_path)
+
+            prepared_path = Path(tmp_dir) / "prepared"
+            cfg = DictDefault(
+                {
+                    "tokenizer_config": "huggyllama/llama-7b",
+                    "sequence_len": 512,
+                    "datasets": [
+                        {
+                            "path": str(tmp_ds_path),
+                            "name": "test_str_json",
+                            "type": "chat_template",
+                            "field_messages": "messages",
+                            "message_field_role": "role",
+                            "message_field_content": "content",
+                        },
+                    ],
+                    "dataset_num_proc": 4,
+                }
+            )
+
+            with patch(
+                "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
+            ):
+                dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
+
+            assert len(dataset) == 2
+            assert "input_ids" in dataset.features
+            assert "attention_mask" in dataset.features
+            assert "labels" in dataset.features
+
+            assert len(dataset[0]["input_ids"]) > 0