Fix prompt assembly for llama (#952)
* start at index 0
* add test to check for missing turns
* apply black
* Update test_prompt_tokenizers.py
* Update src/axolotl/monkeypatch/fastchat_conversation_turns.py

Co-authored-by: Motoki Wu <tokestermw@gmail.com>

* fix linting
* apply black
* add more tests for llama/sharegpt
* make logic clearer

---------

Co-authored-by: Motoki Wu <tokestermw@gmail.com>
This commit is contained in:
@@ -83,14 +83,21 @@ def get_turns( # pylint: disable=too-many-return-statements
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||
seps = [self.sep, self.sep2]
|
||||
if self.system_message:
|
||||
if self.messages:
|
||||
# For llama, the system message is incorporated into the first human instruction
|
||||
first_role, first_msg = self.messages[0]
|
||||
if first_role == self.roles[0]:
|
||||
system_prompt += first_msg
|
||||
self.messages.pop(0)
|
||||
yield "", system_prompt
|
||||
else:
|
||||
yield "", "[INST] "
|
||||
for i, (role, message) in enumerate(self.messages[1:]):
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + " ", message + seps[i % 2]
|
||||
if (i % 2 == 0 and not self.system_message) or (
|
||||
i % 2 != 0 and self.system_message
|
||||
):
|
||||
role = "<s> " + role
|
||||
yield role + " ", message
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user