use fastchat conversations template (#578)

* use fastchat conversations template

* require fastchat (fschat) pip install

* handle roles dynamically from conversation

* tweak fastchat conversation with a monkeypatch to get individual turns

* fix up so it works with multiple conversation styles, and don't strip the turns

* fix sharegpt fixture now that we're using a more correct tokenization

* use a new prompter and support fastchat conversation type

* use sharegpt from prompt strategies now

* update docs, add chatml template

* add a newline after im_end token

* ensure we correctly set system message

* update per PR feedback to handle deprecated sharegpt types

* don't add duplicate wandb req

* make sharegpt fields configurable from yml

* llama2 fixes

* don't fail fatally when turns are improper

This commit is contained in:
Wing Lian
2023-09-27 12:10:45 -04:00
committed by GitHub
parent 60c7c48c97
commit e7d3e2dbb6
13 changed files with 324 additions and 112 deletions

File diff suppressed because one or more lines are too long

View File

@@ -21,7 +21,7 @@ from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
LOG = logging.getLogger("axolotl")
@@ -60,7 +60,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
) as fin:
data = fin.read()
tokenized_conversation = json.loads(data)
prompter = ShareGPTPrompter("chat")
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
@@ -79,7 +79,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
) as fin:
data = fin.read()
conversation = json.loads(data)
prompter = ShareGPTPrompter("chat")
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,

View File

@@ -374,3 +374,26 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
def test_sharegpt_deprecation(self):
cfg = DictDefault(
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"`type: sharegpt:chat` will soon be deprecated." in record.message
for record in self._caplog.records
)
assert cfg.datasets[0].type == "sharegpt"
cfg = DictDefault(
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"`type: sharegpt_simple` will soon be deprecated." in record.message
for record in self._caplog.records
)
assert cfg.datasets[0].type == "sharegpt:load_role"