From 043c3860cd3591fc6999a78d005651ec4baf7bc1 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 10 Jan 2024 13:00:09 +0900 Subject: [PATCH] fix: `train_on_inputs: true` ignored for sharegpt (#1045) [skip ci] * fix: `train_on_inputs: true` ignored for sharegpt * enable unit test for train_on_inputs for sharegpt --------- Co-authored-by: Wing Lian --- src/axolotl/prompt_tokenizers.py | 24 ++++++---- tests/prompt_strategies/test_sharegpt.py | 56 ++++++++++++------------ 2 files changed, 44 insertions(+), 36 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 389ea9a5e..a5c243f7e 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -379,10 +379,12 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): add_eos_token=False, strip_bos_token=True, ) - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) + if self.train_on_inputs: + labels = copy.deepcopy(res["input_ids"]) + else: + # everything from this is masked out from the labels + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) elif assistant in role: - # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID role = ( role.replace(role_remap[1]["from"], role_remap[1]["to"]) if role_remap @@ -406,18 +408,24 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): add_eos_token=False, strip_bos_token=True, ) - # not masked out from labels labels = copy.deepcopy(res["input_ids"]) - len_role = len(role_res["input_ids"]) - labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels)) + if not self.train_on_inputs: + # mask out role tokens from the labels + len_role = len(role_res["input_ids"]) + labels[:len_role] = [IGNORE_TOKEN_ID] * min( + len_role, len(labels) + ) elif role == "": turn = content # this is only ever the first part, should include the bos token and the user query res = self._tokenize( turn, add_eos_token=False, strip_bos_token=False ) - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) + if self.train_on_inputs: + labels = copy.deepcopy(res["input_ids"]) + else: + # everything from this is masked out from the labels + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) else: LOG.warning(f"unhandled role: {role}") continue diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py index ce33a8c40..ee62ab5d0 100644 --- a/tests/prompt_strategies/test_sharegpt.py +++ b/tests/prompt_strategies/test_sharegpt.py @@ -104,7 +104,7 @@ class TestSharegpt: role_key_human=None, ), tokenizer, - True, # train_on_inputs + False, # train_on_inputs 2048, # sequence_len ) @@ -124,30 +124,30 @@ class TestSharegpt: ] # fmt: on - # def test_no_train_on_input(self, sharegpt_dataset, tokenizer): - # strategy = SimpleShareGPTPromptTokenizingStrategy( - # ShareGPTPrompterV2( - # conversation="chatml", - # role_key_model=None, - # role_key_human=None, - # ), - # tokenizer, - # False, # train_on_inputs - # 2048, # sequence_len - # ) - # - # dataset_wrapper = TokenizedPromptDataset( - # strategy, sharegpt_dataset, process_count=1 - # ) - # - # labels = dataset_wrapper[0]["labels"] - # # fmt: off - # assert labels == [ - # 1, # bos - # 32001, 1587, 13, 25997, 32000, 28705, 13, # system - # 32001, 2188, 13, 21558, 32000, 28705, 13, # human - # 32001, 13892, 13, 21558, 32000, 28705, 13, # gpt - # 32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human - # 32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt - # ] - # # fmt: on + def test_no_train_on_input(self, sharegpt_dataset, tokenizer): + strategy = SimpleShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2( + conversation="chatml", + role_key_model=None, + role_key_human=None, + ), + tokenizer, + True, # train_on_inputs + 2048, # sequence_len + ) + + dataset_wrapper = TokenizedPromptDataset( + strategy, sharegpt_dataset, process_count=1 + ) + + labels = dataset_wrapper[0]["labels"] + # fmt: off + assert labels == [ + 1, # bos + 32001, 1587, 13, 25997, 32000, 28705, 13, # system + 32001, 2188, 13, 21558, 32000, 28705, 13, # human + 32001, 13892, 13, 21558, 32000, 28705, 13, # gpt + 32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human + 32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt + ] + # fmt: on