diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index 8b452ae19..321f19554 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -143,6 +143,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
                     role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
                 ),
                 "value": t[value_key],
+                "weight": 1
+                if "weight" not in t or t["weight"] is None
+                else t["weight"],
             }
             for t in conversations
         ]
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index bb13cf76d..11dd084a8 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -377,7 +377,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                     LOG.warning(f"expected tuple, got {part}")
                     continue
 
-                role, content = part
+                if len(part) <= 2:
+                    role, content = part
+                    weight = 1
+                else:
+                    role, content, weight = part
 
                 # Uses "in" because role contains extra characters
                 input_turn = any(r.lower() in role.lower() for r in input_roles)
@@ -403,7 +407,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                         add_eos_token=False,
                         strip_bos_token=True,
                     )
-                    if self.train_on_inputs:
+                    if self.train_on_inputs and weight == 1:
                         labels = copy.deepcopy(res["input_ids"])
                     else:
                         # everything from this is masked out from the labels
@@ -439,13 +443,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                         labels[:len_role] = [IGNORE_TOKEN_ID] * min(
                             len_role, len(labels)
                         )
+                    if weight == 0:
+                        # everything from this is masked out from the labels
+                        # (role is masked out too because it makes no sense if contents is masked out)
+                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+
                 elif empty_role:
                     turn = content
                     # this is only ever the first part, should include the bos token and the user query
                     res = self._tokenize(
                         turn, add_eos_token=False, strip_bos_token=False
                     )
-                    if self.train_on_inputs:
+                    if self.train_on_inputs and weight == 1:
                         labels = copy.deepcopy(res["input_ids"])
                     else:
                         # everything from this is masked out from the labels
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 102e9e53b..0ffa3e55f 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -319,6 +319,7 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
 
         conv = self._conversation.copy()
 
+        original_source = source.copy()
         # Add the conversation system prompt if provided, otherwise use the default one
         if source[0]["from"] == "system":
             conv.set_system_message(source[0]["value"])
@@ -360,8 +361,27 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
                 LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
 
             conv.append_message(role, sentence["value"])
-
-        return conv.get_turns()
+        turns = list(conv.get_turns())
+        original_source_length = len(original_source)
+        assert len(turns) in [
+            original_source_length - 1,
+            original_source_length,
+            original_source_length + 1,
+        ]
+        if len(turns) == original_source_length + 1:
+            original_source = [{"weight": None}] + original_source
+        elif len(turns) == original_source_length - 1:
+            original_source = original_source[1:]
+        return [
+            (*turn, weight)
+            for turn, weight in zip(
+                turns,
+                [
+                    1 if "weight" not in e or e["weight"] is None else e["weight"]
+                    for e in original_source
+                ],
+            )
+        ]
 
     def build_prompt(self, source) -> Generator[str, None, None]:
         turns = self._build_result(source)
diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py
index 6e6909834..aba53cd5f 100644
--- a/tests/prompt_strategies/test_sharegpt.py
+++ b/tests/prompt_strategies/test_sharegpt.py
@@ -52,6 +52,51 @@ def fixture_sharegpt_dataset():
     )
 
 
+@pytest.fixture(name="sharegpt_dataset_with_weights")
+def fixture_sharegpt_dataset_with_weights():
+    return Dataset.from_list(
+        [
+            {
+                "conversations": [
+                    {
+                        "from": "system",
+                        "value": "repeat",
+                    },
+                    {
+                        "from": "human",
+                        "value": "hello",
+                        "weight": 1,
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "hello",
+                        "weight": 0,
+                    },
+                    {
+                        "from": "human",
+                        "value": "rehello",
+                        "weight": 0,
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "rehello",
+                        "weight": 1,
+                    },
+                    {
+                        "from": "human",
+                        "value": "goodbye",
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "goodbye",
+                        "weight": 0,
+                    },
+                ]
+            }
+        ]
+    )
+
+
 @pytest.fixture(name="glaive_dataset")
 def fixture_sharegpt_glaive_dataset():
     return Dataset.from_list(
@@ -162,6 +207,46 @@ class TestSharegptLlama3:
         ]
         # fmt: on
 
+    def test_tokenization_with_weights(
+        self, sharegpt_dataset_with_weights, llama3_tokenizer
+    ):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="llama3",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            llama3_tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, sharegpt_dataset_with_weights, process_count=1
+        )
+
+        input_ids = dataset_wrapper[0]["input_ids"]
+
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 9125, 128007,  # system header
+            271, 31724, 128009,  # sys prompt, eot
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response eot
+            128006, 882, 128007,
+            271, 11310, 4896, 128009,
+            128006, 78191, 128007,
+            271, 11310, 4896, 128009,
+            128006, 882, 128007,
+            271, 19045, 29474, 128009,
+            128006, 78191, 128007,
+            271, 19045, 29474, 128009,
+        ]
+        # fmt: on
+
 
 class TestSharegptChatML:
     """
@@ -197,7 +282,40 @@ class TestSharegptChatML:
         ]
         # fmt: on
 
-    def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
+    def test_no_double_im_end_with_weights(
+        self, sharegpt_dataset_with_weights, tokenizer
+    ):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="chatml",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, sharegpt_dataset_with_weights, process_count=1
+        )
+
+        input_ids = dataset_wrapper[0]["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            # 28705, 13, is " \n"
+            1,  # bos
+            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
+            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
+            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
+            32001, 2188, 13, 267, 21558, 32000, 28705, 13,  # human
+            32001, 13892, 13, 267, 21558, 32000, 28705, 13,  # gpt
+            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
+            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
+        ]
+        # fmt: on
+
+    def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
         strategy = SimpleShareGPTPromptTokenizingStrategy(
             ShareGPTPrompterV2(
                 conversation="chatml",
@@ -225,7 +343,39 @@ class TestSharegptChatML:
         ]
         # fmt: on
 
-    def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
+    def test_no_train_on_input_with_weights(
+        self, sharegpt_dataset_with_weights, tokenizer
+    ):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="chatml",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, sharegpt_dataset_with_weights, process_count=1
+        )
+
+        labels = dataset_wrapper[0]["labels"]
+        # fmt: off
+        assert labels == [
+            -100,  # bos
+            -100, -100, -100, -100, -100, -100, -100,  # system
+            -100, -100, -100, -100, -100, -100, -100,  # human
+            -100, -100, -100, -100, -100, -100, -100,  # gpt with weight zero
+            -100, -100, -100, -100, -100, -100, -100, -100,  # human
+            -100, -100, 13, 267, 21558, 32000, 28705, 13,  # gpt
+            -100, -100, -100, -100, -100, -100, -100, -100,  # human
+            -100, -100, -100, -100, -100, -100, -100, -100  # gpt with weight zero
+        ]
+        # fmt: on
+
+    def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
         strategy = SimpleShareGPTPromptTokenizingStrategy(
             ShareGPTPrompterV2(
                 conversation="chatml",
@@ -253,6 +403,38 @@ class TestSharegptChatML:
         ]
         # fmt: on
 
+    def test_w_train_on_input_with_weights(
+        self, sharegpt_dataset_with_weights, tokenizer
+    ):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="chatml",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            tokenizer,
+            True,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, sharegpt_dataset_with_weights, process_count=1
+        )
+
+        labels = dataset_wrapper[0]["labels"]
+        # fmt: off
+        assert labels == [
+            1,  # bos
+            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
+            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
+            -100, -100, -100, -100, -100, -100, -100,  # gpt with weight 0
+            -100, -100, -100, -100, -100, -100, -100, -100,  # human with weight 0
+            32001, 13892, 13, 267, 21558, 32000, 28705, 13,  # gpt
+            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
+            -100, -100, -100, -100, -100, -100, -100, -100  # gpt with weight 0
+        ]
+        # fmt: on
+
     def test_chatml_glaive(self, glaive_dataset, tokenizer):
         strategy = GlaiveShareGPTPromptTokenizingStrategy(
             ShareGPTPrompterV2(
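
Note on the rule this patch implements: a turn's "weight" defaults to 1 when the key is missing or None, and a turn with weight 0 contributes nothing to the loss (its labels, role tokens included, are all set to IGNORE_TOKEN_ID). The snippet below is a minimal sketch of that per-turn rule in isolation, not axolotl API: the helper name turn_labels and its simplified arguments are hypothetical, and the role-token masking applied to output turns under train_on_inputs=False is omitted for brevity.

# Illustrative sketch only; turn_labels and its argument names are hypothetical.
# IGNORE_TOKEN_ID and the weight semantics mirror the diff above.
IGNORE_TOKEN_ID = -100


def turn_labels(input_ids, weight, input_turn, train_on_inputs):
    """Compute label ids for a single conversation turn."""
    if weight is None:
        weight = 1  # missing/None weight defaults to 1 (see the sharegpt.py hunk)
    if weight == 0:
        # weight 0 masks the whole turn, role tokens included
        return [IGNORE_TOKEN_ID] * len(input_ids)
    if input_turn and not (train_on_inputs and weight == 1):
        # input turns are trained on only when train_on_inputs is set and weight == 1
        return [IGNORE_TOKEN_ID] * len(input_ids)
    return list(input_ids)


# Matches test_w_train_on_input_with_weights above: even with train_on_inputs=True,
# a weight-0 gpt turn is fully masked, while a weight-1 human turn is kept.
assert turn_labels([32001, 13892, 13], 0, False, True) == [-100, -100, -100]
assert turn_labels([32001, 2188, 13], 1, True, True) == [32001, 2188, 13]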