wip add new proposed message structure (#1904)
* wip add new proposed message structure * tokenization * wip * wip transform builder * wip make the chat dataset loadable * wip chatml + llama 3 new chat objects * chore: lint * chore: lint * fix tokenization * remove dacite dependency since we're using pydantic now * fix handling when already correctly split in messages * make sure to remove chat features from tokenized ds * move chat to be a input transform for messages * make sure llama3 has the bos token * remove non-working special token code * fix messages strat loader
This commit is contained in:
315
requirements_env.txt
Normal file
315
requirements_env.txt
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
accelerate==0.34.1
|
||||||
|
addict==2.4.0
|
||||||
|
aiofiles==23.2.1
|
||||||
|
aiohttp==3.9.0
|
||||||
|
aiosignal==1.3.1
|
||||||
|
aiostream==0.5.2
|
||||||
|
alembic==1.13.1
|
||||||
|
annotated-types==0.6.0
|
||||||
|
annoy==1.17.3
|
||||||
|
ansible==6.7.0
|
||||||
|
ansible-core==2.13.13
|
||||||
|
ansible-vault==2.1.0
|
||||||
|
anyio==3.7.1
|
||||||
|
appdirs==1.4.4
|
||||||
|
art==6.0
|
||||||
|
asgiref==3.7.2
|
||||||
|
async-timeout==4.0.2
|
||||||
|
attrdict==2.0.1
|
||||||
|
attrs==22.2.0
|
||||||
|
awscli==1.32.75
|
||||||
|
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
|
||||||
|
backoff==2.2.1
|
||||||
|
base58==2.1.1
|
||||||
|
beartype==0.17.2
|
||||||
|
bitnet==0.2.1
|
||||||
|
bitsandbytes==0.42.0
|
||||||
|
bittensor==6.7.0
|
||||||
|
black==23.7.0
|
||||||
|
blinker==1.7.0
|
||||||
|
boto3==1.34.75
|
||||||
|
botocore==1.34.75
|
||||||
|
cachetools==5.3.3
|
||||||
|
cachy==0.1.1
|
||||||
|
certifi==2023.7.22
|
||||||
|
cffi==1.16.0
|
||||||
|
cfgv==3.3.1
|
||||||
|
chai-guanaco==1.2.4
|
||||||
|
charset-normalizer==3.2.0
|
||||||
|
cleo==0.6.8
|
||||||
|
click==8.1.7
|
||||||
|
cloudpickle==2.0.0
|
||||||
|
cohere==4.11.2
|
||||||
|
colorama==0.4.4
|
||||||
|
coloredlogs==15.0.1
|
||||||
|
CoLT5-attention==0.10.20
|
||||||
|
contextlib2==21.6.0
|
||||||
|
contourpy==1.2.0
|
||||||
|
cryptography==41.0.3
|
||||||
|
cycler==0.12.1
|
||||||
|
cytoolz==0.12.3
|
||||||
|
databricks-cli==0.18.0
|
||||||
|
dataclasses-json==0.5.7
|
||||||
|
datasets==2.11.0
|
||||||
|
ddt==1.6.0
|
||||||
|
decorator==5.1.1
|
||||||
|
deepspeed==0.15.0
|
||||||
|
# Editable Git install with no remote (dialogpt==0.1)
|
||||||
|
-e /Users/wing/Projects/ml/dialogpt/src
|
||||||
|
dill==0.3.6
|
||||||
|
distlib==0.3.6
|
||||||
|
docker==7.0.0
|
||||||
|
docker-pycreds==0.4.0
|
||||||
|
docstring-parser==0.15
|
||||||
|
docutils==0.16
|
||||||
|
ecdsa==0.18.0
|
||||||
|
einops==0.7.0
|
||||||
|
einops-exts==0.0.4
|
||||||
|
einx==0.1.3
|
||||||
|
entrypoints==0.4
|
||||||
|
eth-hash==0.6.0
|
||||||
|
eth-keys==0.5.0
|
||||||
|
eth-typing==4.0.0
|
||||||
|
eth-utils==2.3.1
|
||||||
|
evaluate==0.4.0
|
||||||
|
exceptiongroup==1.1.1
|
||||||
|
fastapi==0.109.2
|
||||||
|
fastcore==1.5.29
|
||||||
|
ffmpy==0.4.0
|
||||||
|
filelock==3.12.2
|
||||||
|
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
|
||||||
|
fire==0.5.0
|
||||||
|
first==2.0.2
|
||||||
|
flake8==7.0.0
|
||||||
|
Flask==3.0.1
|
||||||
|
fonttools==4.47.2
|
||||||
|
frozendict==2.4.1
|
||||||
|
frozenlist==1.3.3
|
||||||
|
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||||
|
fsspec==2023.6.0
|
||||||
|
fuzzywuzzy==0.18.0
|
||||||
|
gitdb==4.0.10
|
||||||
|
GitPython==3.1.31
|
||||||
|
google-pasta==0.2.0
|
||||||
|
gradio==4.42.0
|
||||||
|
gradio_client==1.3.0
|
||||||
|
greenlet==2.0.2
|
||||||
|
grpclib==0.4.7
|
||||||
|
gunicorn==21.2.0
|
||||||
|
h11==0.14.0
|
||||||
|
h2==4.1.0
|
||||||
|
hpack==4.0.0
|
||||||
|
httpcore==0.17.3
|
||||||
|
httpx==0.24.1
|
||||||
|
huggingface-hub==0.23.4
|
||||||
|
humanfriendly==10.0
|
||||||
|
hyperframe==6.0.1
|
||||||
|
identify==2.5.24
|
||||||
|
idna==3.4
|
||||||
|
immutables==0.20
|
||||||
|
importlib-metadata==6.7.0
|
||||||
|
importlib-resources==6.1.1
|
||||||
|
inflection==0.5.1
|
||||||
|
iniconfig==2.0.0
|
||||||
|
itsdangerous==2.1.2
|
||||||
|
Jinja2==3.1.2
|
||||||
|
jmespath==1.0.1
|
||||||
|
joblib==1.3.2
|
||||||
|
jsonlines==3.1.0
|
||||||
|
jsonschema==2.6.0
|
||||||
|
kiwisolver==1.4.5
|
||||||
|
langchain==0.0.144
|
||||||
|
Levenshtein==0.24.0
|
||||||
|
libcst==1.1.0
|
||||||
|
liger-kernel==0.0.0
|
||||||
|
lion-pytorch==0.1.2
|
||||||
|
llama-cpp-python==0.1.36
|
||||||
|
llvmlite==0.40.1
|
||||||
|
local-attention==1.9.0
|
||||||
|
loguru==0.7.0
|
||||||
|
Mako==1.3.2
|
||||||
|
Markdown==3.5.2
|
||||||
|
markdown-it-py==3.0.0
|
||||||
|
markdown2==2.4.10
|
||||||
|
MarkupSafe==2.1.2
|
||||||
|
marshmallow==3.19.0
|
||||||
|
marshmallow-enum==1.5.1
|
||||||
|
matplotlib==3.8.2
|
||||||
|
mccabe==0.7.0
|
||||||
|
mdurl==0.1.2
|
||||||
|
MEGABYTE-pytorch==0.0.7
|
||||||
|
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
|
||||||
|
mlflow==2.10.0
|
||||||
|
modal==0.62.77
|
||||||
|
more-itertools==10.2.0
|
||||||
|
mpmath==1.2.1
|
||||||
|
msgpack==1.0.7
|
||||||
|
msgpack-numpy-opentensor==0.5.0
|
||||||
|
multidict==6.0.4
|
||||||
|
multiprocess==0.70.14
|
||||||
|
munch==2.5.0
|
||||||
|
mypy==1.3.0
|
||||||
|
mypy-extensions==1.0.0
|
||||||
|
nest-asyncio==1.6.0
|
||||||
|
netaddr==0.10.1
|
||||||
|
networkx==3.0rc1
|
||||||
|
nh3==0.2.14
|
||||||
|
nodeenv==1.8.0
|
||||||
|
nomic==2.0.2
|
||||||
|
numba==0.57.1
|
||||||
|
numexpr==2.8.4
|
||||||
|
numpy==1.24.4
|
||||||
|
oauthlib==3.2.2
|
||||||
|
openai==0.27.4
|
||||||
|
openapi==1.1.0
|
||||||
|
openapi-schema-pydantic==1.2.4
|
||||||
|
optimum==1.8.6
|
||||||
|
orjson==3.10.7
|
||||||
|
packaging==23.1
|
||||||
|
pandas==2.0.0
|
||||||
|
parameterized==0.9.0
|
||||||
|
password-strength==0.0.3.post2
|
||||||
|
pastel==0.1.1
|
||||||
|
pathos==0.3.0
|
||||||
|
pathspec==0.11.1
|
||||||
|
pathtools==0.1.2
|
||||||
|
peft==0.11.1
|
||||||
|
pendulum==3.0.0
|
||||||
|
Pillow==9.5.0
|
||||||
|
pip-tools==1.11.0
|
||||||
|
platformdirs==3.2.0
|
||||||
|
pluggy==1.4.0
|
||||||
|
poetry==0.7.1
|
||||||
|
pox==0.3.2
|
||||||
|
ppft==1.7.6.6
|
||||||
|
pre-commit==3.3.2
|
||||||
|
prettytable==3.10.0
|
||||||
|
prompt-toolkit==3.0.39
|
||||||
|
protobuf==3.20.2
|
||||||
|
protobuf3-to-dict==0.1.5
|
||||||
|
psutil==5.9.5
|
||||||
|
psycopg==3.1.18
|
||||||
|
PuLP==2.8.0
|
||||||
|
py==1.11.0
|
||||||
|
py-bip39-bindings==0.1.11
|
||||||
|
py-cpuinfo==9.0.0
|
||||||
|
py-ed25519-zebra-bindings==1.0.1
|
||||||
|
py-sr25519-bindings==0.2.0
|
||||||
|
pyarrow==11.0.0
|
||||||
|
pyasn1==0.6.0
|
||||||
|
pycodestyle==2.11.1
|
||||||
|
pycparser==2.21
|
||||||
|
pycryptodome==3.20.0
|
||||||
|
pydantic==2.5.3
|
||||||
|
pydantic_core==2.14.6
|
||||||
|
pydub==0.25.1
|
||||||
|
pyfiglet==0.8.post1
|
||||||
|
pyflakes==3.2.0
|
||||||
|
Pygments==2.15.1
|
||||||
|
PyJWT==2.8.0
|
||||||
|
pylev==1.4.0
|
||||||
|
PyNaCl==1.5.0
|
||||||
|
pynvml==11.5.0
|
||||||
|
pyparsing==2.4.7
|
||||||
|
pyrsistent==0.14.11
|
||||||
|
pytest==8.0.2
|
||||||
|
pytest-asyncio==0.23.4
|
||||||
|
python-dateutil==2.8.2
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
python-Levenshtein==0.24.0
|
||||||
|
python-multipart==0.0.9
|
||||||
|
pytz==2023.3
|
||||||
|
PyYAML==6.0.1
|
||||||
|
querystring-parser==1.2.4
|
||||||
|
rapidfuzz==3.6.1
|
||||||
|
regex==2023.6.3
|
||||||
|
requests==2.31.0
|
||||||
|
requests-toolbelt==0.8.0
|
||||||
|
resolvelib==0.8.1
|
||||||
|
responses==0.18.0
|
||||||
|
retry==0.9.2
|
||||||
|
rich==13.7.0
|
||||||
|
rsa==4.7.2
|
||||||
|
ruff==0.6.3
|
||||||
|
s3transfer==0.10.1
|
||||||
|
safetensors==0.4.5
|
||||||
|
sagemaker==2.148.0
|
||||||
|
scalecodec==1.2.7
|
||||||
|
schedulefree==1.2.1
|
||||||
|
schema==0.7.5
|
||||||
|
scikit-learn==1.4.0
|
||||||
|
scipy==1.9.3
|
||||||
|
seaborn==0.13.2
|
||||||
|
semantic-version==2.10.0
|
||||||
|
sentencepiece==0.2.0
|
||||||
|
sentry-sdk==1.19.1
|
||||||
|
setproctitle==1.3.2
|
||||||
|
shellingham==1.5.4
|
||||||
|
shortuuid==1.0.11
|
||||||
|
shtab==1.6.5
|
||||||
|
sigtools==4.0.1
|
||||||
|
six==1.16.0
|
||||||
|
skypilot==0.4.1
|
||||||
|
smdebug-rulesconfig==1.0.1
|
||||||
|
smmap==5.0.0
|
||||||
|
sniffio==1.3.0
|
||||||
|
SQLAlchemy==1.4.47
|
||||||
|
sqlparse==0.4.4
|
||||||
|
starlette==0.36.3
|
||||||
|
substrate-interface==1.5.2
|
||||||
|
svgwrite==1.4.3
|
||||||
|
sympy==1.11.1
|
||||||
|
synchronicity==0.6.7
|
||||||
|
tabulate==0.9.0
|
||||||
|
tblib==1.7.0
|
||||||
|
tenacity==8.2.2
|
||||||
|
tensor-parallel==2.0.0
|
||||||
|
termcolor==2.2.0
|
||||||
|
text2art==0.2.0
|
||||||
|
threadpoolctl==3.2.0
|
||||||
|
tiktoken==0.6.0
|
||||||
|
time-machine==2.14.1
|
||||||
|
timm==0.9.16
|
||||||
|
tokenizers==0.19.1
|
||||||
|
tokenmonster==1.1.12
|
||||||
|
toml==0.9.6
|
||||||
|
tomli==2.0.1
|
||||||
|
tomlkit==0.12.0
|
||||||
|
toolz==0.12.1
|
||||||
|
torch==2.2.0
|
||||||
|
torchdata==0.6.1
|
||||||
|
torchdiffeq==0.2.3
|
||||||
|
TorchFix==0.4.0
|
||||||
|
torchtext==0.15.2
|
||||||
|
torchvision==0.17.0
|
||||||
|
tqdm==4.66.2
|
||||||
|
transformers==4.44.2
|
||||||
|
trl==0.9.6
|
||||||
|
typer==0.12.5
|
||||||
|
types-certifi==2021.10.8.3
|
||||||
|
types-requests==2.31.0.20240125
|
||||||
|
types-setuptools==69.0.0.20240125
|
||||||
|
types-toml==0.10.8.7
|
||||||
|
typing==3.7.4.3
|
||||||
|
typing-inspect==0.8.0
|
||||||
|
typing_extensions==4.9.0
|
||||||
|
tyro==0.5.18
|
||||||
|
tzdata==2023.3
|
||||||
|
unique-names-generator==1.0.2
|
||||||
|
urllib3==2.2.2
|
||||||
|
uvicorn==0.22.0
|
||||||
|
vector_quantize_pytorch==1.14.1
|
||||||
|
virtualenv==20.23.0
|
||||||
|
voyager==2.0.2
|
||||||
|
wandb==0.16.2
|
||||||
|
watchfiles==0.21.0
|
||||||
|
wavedrom==2.0.3.post3
|
||||||
|
wcwidth==0.2.6
|
||||||
|
websocket-client==1.7.0
|
||||||
|
websockets==12.0
|
||||||
|
Werkzeug==3.0.1
|
||||||
|
wonderwords==2.2.0
|
||||||
|
xxhash==3.2.0
|
||||||
|
yarl==1.8.2
|
||||||
|
zetascale==2.2.7
|
||||||
|
zipp==3.15.0
|
||||||
@@ -27,6 +27,7 @@ from axolotl.prompt_strategies.sharegpt import (
|
|||||||
register_chatml_template,
|
register_chatml_template,
|
||||||
register_llama3_template,
|
register_llama3_template,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.trainer import disable_datasets_caching
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||||
|
|
||||||
@@ -70,10 +71,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
LOG.warning(msg)
|
LOG.warning(msg)
|
||||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
with disable_datasets_caching():
|
||||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
||||||
else:
|
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
else:
|
||||||
|
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
if parsed_cli_args.download:
|
if parsed_cli_args.download:
|
||||||
model_name = parsed_cfg.base_model
|
model_name = parsed_cfg.base_model
|
||||||
|
|||||||
0
src/axolotl/core/chat/__init__.py
Normal file
0
src/axolotl/core/chat/__init__.py
Normal file
0
src/axolotl/core/chat/format/__init__.py
Normal file
0
src/axolotl/core/chat/format/__init__.py
Normal file
34
src/axolotl/core/chat/format/chatml.py
Normal file
34
src/axolotl/core/chat/format/chatml.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""
|
||||||
|
ChatML transformation functions for MessageContents
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..messages import MessageContents, Messages
|
||||||
|
from .shared import wrap_tools
|
||||||
|
|
||||||
|
|
||||||
|
def format_message(
    message: Messages,
    message_index: Optional[int] = None,  # pylint: disable=unused-argument
) -> Messages:
    """
    Wrap a message's contents in ChatML role markers.

    Args:
        message: the message to format; it is modified in place.
        message_index: position of the message in the conversation (unused
            for ChatML, accepted so all formatters share one signature).

    Returns:
        The same message, flagged as chat formatted. A message that is
        already flagged is returned untouched.
    """
    if not message.is_chat_formatted:
        # role header; weight 0 so it is never trained on
        header = MessageContents(
            type="text",
            value=f"<|im_start|>{message.role}\n",
            weight=0,
        )
        # the end-of-turn marker inherits the weight of the message itself
        end_of_turn = MessageContents(
            type="text", value="<|im_end|>", weight=message.weight
        )
        # separator between turns; weight 0 so it is never trained on
        separator = MessageContents(type="text", value="\n", weight=0)

        message.content.insert(0, header)
        message.content.extend([end_of_turn, separator])

        message = wrap_tools(message)
        message.is_chat_formatted = True

    return message
|
||||||
45
src/axolotl/core/chat/format/llama3x.py
Normal file
45
src/axolotl/core/chat/format/llama3x.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
"""
|
||||||
|
Llama 3.x chat formatting functions for MessageContents
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..messages import MessageContents, Messages
|
||||||
|
from .shared import wrap_tools
|
||||||
|
|
||||||
|
|
||||||
|
def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
    """
    Wrap a message's contents in Llama 3.x header/eot markers.

    Args:
        message: the message to format; it is modified in place.
        message_index: position of the message in the conversation; the
            message at index 0 additionally receives the begin-of-text marker.

    Returns:
        The same message, flagged as chat formatted. A message that is
        already flagged is returned untouched.
    """
    if message.is_chat_formatted:
        return message

    # the "tool" role is rendered under the "ipython" header name
    header_role = "ipython" if message.role == "tool" else message.role

    header = MessageContents(
        type="text",
        value=f"<|start_header_id|>{header_role}<|end_header_id|>\n\n",
        weight=0,
    )
    end_of_turn = MessageContents(
        type="text", value="<|eot_id|>", weight=message.weight
    )
    message.content.insert(0, header)
    message.content.append(end_of_turn)

    message = wrap_tools(message)

    # the BOS marker goes in front of everything else, after tool wrapping
    if message_index == 0:
        bos = MessageContents(
            type="text",
            value="<|begin_of_text|>",
            weight=0,
        )
        message.content.insert(0, bos)

    message.is_chat_formatted = True
    return message
|
||||||
47
src/axolotl/core/chat/format/shared.py
Normal file
47
src/axolotl/core/chat/format/shared.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""
|
||||||
|
shared functions for format transforms
|
||||||
|
"""
|
||||||
|
from axolotl.core.chat.messages import MessageContents, Messages
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_tools(message: Messages):
    """
    Surround every tool call / tool response content with matching text tags.

    A ``tool_call`` content becomes ``<tool_call>\\n`` + content +
    ``</tool_call>\\n`` and a ``tool_response`` content becomes
    ``<tool_response>\\n`` + content + ``</tool_response>\\n``. The wrapped
    content is flagged with ``has_newline`` and the inserted tags carry the
    weight of the message. The message is modified in place and returned.
    """
    # walk the indices backwards so that inserting around position `idx`
    # never shifts the positions that are still to be visited
    for idx in reversed(range(len(message.content))):
        item = message.content[idx]
        if item.type == "tool_call":
            tag = "tool_call"
        elif item.type == "tool_response":
            tag = "tool_response"
        else:
            continue

        # closing tag goes right after the tool content
        message.content.insert(
            idx + 1,
            MessageContents(type="text", value=f"</{tag}>\n", weight=message.weight),
        )
        # make sure the actual tool content ends with a newline
        item.has_newline = True
        # opening tag goes right before the tool content
        message.content.insert(
            idx,
            MessageContents(type="text", value=f"<{tag}>\n", weight=message.weight),
        )

    return message
|
||||||
230
src/axolotl/core/chat/messages.py
Normal file
230
src/axolotl/core/chat/messages.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""
|
||||||
|
internal message representations of chat messages
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRoles(str, Enum):
    """
    Roles a chat message can be attributed to: system, user, assistant, tool
    and ipython.
    """

    system = "system"  # pylint: disable=invalid-name
    user = "user"  # pylint: disable=invalid-name
    assistant = "assistant"  # pylint: disable=invalid-name
    tool = "tool"  # pylint: disable=invalid-name
    # for responses from builtin tools
    ipython = "ipython"  # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class MessageContentTypes(str, Enum):
    """
    Kinds of content a message part can hold: special tokens, text, image,
    audio, tool calls and tool responses.
    """

    special_token = "special_token"  # pylint: disable=invalid-name # nosec B105
    text = "text"  # pylint: disable=invalid-name
    image = "image"  # pylint: disable=invalid-name
    audio = "audio"  # pylint: disable=invalid-name
    # to differentiate regular responses from tool calls from the assistant
    tool_call = "tool_call"  # pylint: disable=invalid-name
    tool_response = "tool_response"  # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class SpecialToken(str, Enum):
    """
    Names of the special tokens marking the beginning and the end of a string.
    """

    bos_token = "bos_token"  # pylint: disable=invalid-name # nosec B105
    eos_token = "eos_token"  # pylint: disable=invalid-name # nosec B105
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallFunction(BaseModel):
    """
    Function part of a tool definition: its name and its arguments.
    """

    # name of the function
    name: str
    # string-to-string mapping of the function's arguments
    arguments: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class Tool(BaseModel):
    """
    Tool definition made of a description, a function and its parameters.
    """

    # free-text description of the tool
    description: str
    # the function the tool exposes
    function: ToolCallFunction
    # .properties
    parameters: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallContents(BaseModel):
    """
    Payload of a tool call: function name, arguments and an optional call id.
    """

    name: str
    arguments: dict[str, Union[str, int]]
    id: Optional[str] = None  # pylint: disable=invalid-name

    def __str__(self) -> str:
        """Serialize the call to JSON; the id is included only when set."""
        payload = {"name": self.name, "arguments": self.arguments}
        if self.id is None:
            return json.dumps(payload)
        payload["id"] = self.id
        return json.dumps(payload)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResponseContents(BaseModel):
    """
    Payload of a tool response: tool name, returned content and an optional id.
    """

    name: str
    content: Union[str, dict[str, Union[str, int, float]]]
    id: Optional[str] = None  # pylint: disable=invalid-name

    def __str__(self) -> str:
        """Serialize the response to JSON; the id is included only when set."""
        payload = {"name": self.name, "content": self.content}
        if self.id is None:
            return json.dumps(payload)
        payload["id"] = self.id
        return json.dumps(payload)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageContents(BaseModel):
    """
    One part of a message: a typed value plus metadata, training weight,
    newline flag and end-of-contents flag.
    """

    type: Union[str, MessageContentTypes]
    value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]
    meta: Optional[dict[str, Any]] = None  # support additional arbitrary metadata
    weight: Optional[Union[int, float]] = None
    has_newline: bool = False
    eoc: bool = False  # end of contents

    def __str__(self) -> str:
        """Render the value as text, newline-terminated when `has_newline` is set."""
        rendered = str(self.value)
        needs_newline = self.has_newline and not rendered.endswith("\n")
        return rendered + "\n" if needs_newline else rendered
|
||||||
|
|
||||||
|
|
||||||
|
class Messages(BaseModel):
    """
    A single chat message: role, list of contents, metadata, training weight
    and a flag recording whether chat formatting was already applied.
    """

    role: Union[MessageRoles, str]  # allows for arbitrary roles
    content: List["MessageContents"]
    meta: Optional[dict[str, Any]] = None  # support additional arbitrary metadata
    weight: Optional[Union[int, float]] = None
    is_chat_formatted: bool = False

    def __str__(self) -> str:
        """Concatenate the text rendering of every content part."""
        return "".join(str(part) for part in self.content)

    def tokenized(
        self, tokenizer: PreTrainedTokenizer, ignore_index=-100
    ) -> dict[str, List[int]]:
        """
        Tokenize the message part by part and build training labels.

        The text accumulated so far is re-tokenized each time a part is
        added; the tokens contributed by a part are only committed once the
        following part has been seen (or at the end). Committed tokens are
        copied into the labels when the part is trainable and replaced by
        `ignore_index` otherwise.

        Returns:
            dict with `input_ids`, `attention_mask` (all ones) and `labels`.
        """
        # TODO also handle non-text content types
        handled_types = [
            MessageContentTypes.text.value,
            MessageContentTypes.tool_call.value,
            MessageContentTypes.tool_response.value,
        ]

        input_ids: List[int] = []
        labels: List[int] = []
        held_ids: List[int] = []  # tokens of the latest part, not committed yet
        held_trainable = self.weight
        text_so_far = ""

        for part in self.content:
            if part.type not in handled_types:
                continue

            text_so_far += str(part)
            full_ids = tokenizer(text_so_far, add_special_tokens=False)["input_ids"]

            if held_ids:
                # the previous part may tokenize differently now that more
                # text follows it; take its tokens from the fresh tokenization
                committed = len(input_ids)
                held_ids = full_ids[committed : committed + len(held_ids)]
                input_ids.extend(held_ids)
                if held_trainable:
                    labels.extend(held_ids)
                else:
                    labels.extend([ignore_index] * len(held_ids))

            # everything after the committed prefix belongs to the current part
            held_ids = full_ids[len(input_ids) :]
            held_trainable = self.weight and part.weight not in [0, 0.0]

        # commit whatever the last part contributed
        input_ids.extend(held_ids)
        if held_trainable:
            labels.extend(held_ids)
        else:
            labels.extend([ignore_index] * len(held_ids))

        return {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "labels": labels,
        }
|
||||||
|
|
||||||
|
|
||||||
|
class Chats(BaseModel):
    """
    Top level container for a chat conversation (an ordered list of messages).
    """

    conversation: List[Messages]

    def __str__(self) -> str:
        """Concatenate the text rendering of every message."""
        return "".join(str(message) for message in self.conversation)

    def tokenized(
        self, tokenizer: Callable[[str], dict[str, List[int]]], ignore_index=-100
    ) -> dict[str, List[int]]:
        """
        Tokenize each message in order and concatenate the results.

        Returns:
            dict with the concatenated `input_ids`, `attention_mask` and `labels`.
        """
        merged: dict[str, List[int]] = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
        }
        for message in self.conversation:
            piece = message.tokenized(tokenizer, ignore_index)
            for key, values in merged.items():
                values.extend(piece[key])
        return merged
|
||||||
|
|
||||||
|
|
||||||
|
class ChatFormattedChats(Chats):
    """
    Conversation whose messages are run through a formatter on construction,
    optionally forcing every message to be trained on.
    """

    formatter: Callable  # [[Union[dict, Chats]], Chats]
    train_on_inputs: bool = False

    def model_post_init(self, __context):
        """Apply the formatter to every message once the model is built."""
        for position in range(len(self.conversation)):
            formatted = self.formatter(
                self.conversation[position], message_index=position
            )
            self.conversation[position] = formatted
            if self.train_on_inputs:
                # override the message weight so the message is trained on
                formatted.weight = 1
|
||||||
|
|
||||||
|
|
||||||
|
class PreferenceChats(BaseModel):
    """
    Preference data for chat: a prompt plus a chosen and a rejected message.
    """

    # messages making up the prompt
    prompt: List[Messages]
    # the chosen message
    chosen: Messages
    # the rejected message
    rejected: Messages
|
||||||
0
src/axolotl/core/datasets/__init__.py
Normal file
0
src/axolotl/core/datasets/__init__.py
Normal file
55
src/axolotl/core/datasets/chat.py
Normal file
55
src/axolotl/core/datasets/chat.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""
|
||||||
|
chat dataset module
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from axolotl.core.chat.messages import ChatFormattedChats
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizedChatDataset(Dataset):
    """
    Tokenized chat dataset

    Maps every row of `data` through an optional message transform, builds a
    `ChatFormattedChats` from the result and tokenizes it with
    `model_transform`. The original columns are removed so only the
    tokenizer output remains.
    """

    def __init__(
        self,
        data: Dataset,
        model_transform: Union[PreTrainedTokenizer, Callable],
        *args,
        message_transform: Optional[Callable] = None,
        formatter=None,
        process_count: Optional[int] = None,
        keep_in_memory: Optional[bool] = False,
        **kwargs,
    ):
        """
        Args:
            data: source dataset to tokenize.
            model_transform: tokenizer (or callable) handed to `tokenized()`.
            message_transform: optional callable applied to each row first.
            formatter: optional per-message chat formatter.
            process_count: worker count for `map`; defaults to the CPU count,
                capped at 64.
            keep_in_memory: forwarded to `Dataset.map`.
            *args, **kwargs: forwarded to the `Dataset` constructor.
        """

        def map_fn(ex):
            if message_transform is not None:
                ex = message_transform(ex)
            if formatter is not None:
                ex = ChatFormattedChats(
                    formatter=formatter,
                    **ex,
                )
            else:
                # NOTE(review): ChatFormattedChats appears to declare
                # `formatter` as a required field, so this branch presumably
                # fails validation unless `ex` itself carries a formatter.
                # Confirm the intended behavior when no formatter is given.
                ex = ChatFormattedChats(
                    **ex,
                )
            return ex.tokenized(model_transform)

        # os.cpu_count() may return None when the count cannot be determined;
        # fall back to a single process instead of failing inside min().
        process_or_cpu_count: int = process_count or os.cpu_count() or 1
        num_proc = min(64, process_or_cpu_count)
        # materialize the column names: remove_columns is documented to take
        # a str or a list of str, not a dict view
        features = list(data.features.keys())
        tokenized_data = data.map(
            map_fn,
            num_proc=num_proc,
            keep_in_memory=keep_in_memory,
            remove_columns=features,
            desc="Tokenizing Chats",
        )
        super().__init__(tokenized_data.data, *args, **kwargs)
|
||||||
0
src/axolotl/core/datasets/transforms/__init__.py
Normal file
0
src/axolotl/core/datasets/transforms/__init__.py
Normal file
150
src/axolotl/core/datasets/transforms/chat_builder.py
Normal file
150
src/axolotl/core/datasets/transforms/chat_builder.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""
|
||||||
|
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
|
||||||
|
"""
|
||||||
|
from typing import Any, Mapping, Union
|
||||||
|
|
||||||
|
|
||||||
|
def chat_message_transform_builder(  # pylint: disable=dangerous-default-value
    train_on_inputs=False,
    conversations_field: str = "conversations",
    message_field_role: Union[str, list[str]] = ["role", "from"],  # commonly "role"
    message_field_content: Union[str, list[str]] = [
        "value",
        "text",
        "content",
    ],  # commonly "content"
    message_field_training: Union[str, list[str]] = [
        "train",
        "weight",
    ],  # commonly "weight"
):
    """Builds a transform that takes a row from the dataset and converts it to a Chat

    Args:
        train_on_inputs (bool, optional):
            If True, every role defaults to a training weight of 1. If False, only
            the assistant role defaults to 1 and all other roles default to 0.
            Defaults to False.
        conversations_field (str, optional):
            The field name of the conversations. Defaults to "conversations".
        message_field_role (str | list[str], optional):
            Candidate field name(s) of the role; the first one present in the
            first message is used. Defaults to ["role", "from"].
        message_field_content (str | list[str], optional):
            Candidate field name(s) of the message content; the first one present
            in the first message is used. Defaults to ["value", "text", "content"].
        message_field_training (str | list[str], optional):
            Candidate field name(s) of the per-message train/weight value; the
            first one present in the first message is used.
            Defaults to ["train", "weight"].

    Returns:
        Callable:
            A function that takes one dataset row (a mapping) and returns
            ``{"conversation": [...]}`` where each entry is a dict with the
            keys "role", "content" and "weight".

    Raises (from the returned function):
        ValueError: if the conversations field, the role field or the content
            field cannot be found.
        KeyError: if a message uses a role that has no known mapping.
    """

    def _as_list(names):
        # a single field name is accepted as shorthand for a one-item list
        return [names] if isinstance(names, str) else names

    # the list defaults are only read, never mutated, so sharing them is safe
    role_fields = _as_list(message_field_role)
    content_fields = _as_list(message_field_content)
    weight_fields = _as_list(message_field_training)

    role_value_mappings = {
        "system": "system",
        "user": "user",
        "human": "user",
        "assistant": "assistant",
        "gpt": "assistant",
        "tool": "tool",
        "ipython": "ipython",
    }
    role_default_weights_mappings = {
        role: 1 if (train_on_inputs or role == "assistant") else 0
        for role in ("system", "user", "assistant", "tool", "ipython")
    }

    def transform_builder(sample: Mapping[str, Any]):
        if conversations_field not in sample:
            raise ValueError(f"Field '{conversations_field}' not found in sample.")

        # field names are detected on the first message and reused for all
        first_message = sample[conversations_field][0]

        role_field = next(
            (name for name in role_fields if name in first_message), None
        )
        if role_field is None:
            raise ValueError("No role field found in message.")

        message_content_field = next(
            (name for name in content_fields if name in first_message), None
        )
        if message_content_field is None:
            raise ValueError("No message_content field found in message.")

        # Search the normalized list: scanning the raw argument would walk a
        # plain string character by character and never find the field.
        message_weight_field = next(
            (name for name in weight_fields if name in first_message), None
        )

        messages = []
        for message in sample[conversations_field]:
            role = role_value_mappings[message[role_field]]

            # an absent or null weight falls back to the role default
            raw_weight = (
                message.get(message_weight_field) if message_weight_field else None
            )
            weight = (
                int(raw_weight)
                if raw_weight is not None
                else role_default_weights_mappings[role]
            )

            # TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents
            content = message[message_content_field]
            if isinstance(content, str):
                # plain strings become a single text content part
                content = [
                    {
                        "type": "text",
                        "value": content,
                    }
                ]
            messages.append(
                {
                    "role": role,
                    "content": content,
                    "weight": weight,
                }
            )

        return {"conversation": messages}

    return transform_builder
|
||||||
@@ -11,6 +11,10 @@ LOG = logging.getLogger("axolotl.prompt_strategies")
|
|||||||
|
|
||||||
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||||
try:
|
try:
|
||||||
|
if strategy == "messages":
|
||||||
|
from .messages import load as messages_load
|
||||||
|
|
||||||
|
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
|
||||||
load_fn = "load"
|
load_fn = "load"
|
||||||
if strategy.split(".")[-1].startswith("load_"):
|
if strategy.split(".")[-1].startswith("load_"):
|
||||||
load_fn = strategy.split(".")[-1]
|
load_fn = strategy.split(".")[-1]
|
||||||
@@ -31,4 +35,5 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
|||||||
return None
|
return None
|
||||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||||
return None
|
raise exc
|
||||||
|
return None
|
||||||
|
|||||||
34
src/axolotl/prompt_strategies/messages/__init__.py
Normal file
34
src/axolotl/prompt_strategies/messages/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Module to load message prompt strategies."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.prompt_strategies.messages")
|
||||||
|
|
||||||
|
|
||||||
|
def load(tokenizer, cfg, ds_cfg, processor=None):
    """
    Load the messages prompt strategy named by ``ds_cfg["input_transform"]``.

    The value is a module path relative to ``axolotl.prompt_strategies.messages``.
    If its last dotted component starts with ``load_``, that component is used
    as the loader function name and the remainder as the module path;
    otherwise the module's ``load`` function is used.

    :param tokenizer: forwarded as the loader's first positional argument
    :param cfg: forwarded as the loader's second positional argument
    :param ds_cfg: dataset config exposing ``.get``; a missing, ``None`` or
        empty ``input_transform`` (or ``ds_cfg`` itself being ``None``) selects
        ``"chat"``. Forwarded only if the loader has a ``ds_cfg`` parameter.
    :param processor: forwarded only if the loader has a ``processor`` parameter
    :return: the loader's return value, or ``None`` when a
        ``ModuleNotFoundError`` is raised while importing or running the loader
    :raises Exception: any other failure is logged and re-raised unchanged
    """
    # Resolved before the ``try`` so the error handler below can always name
    # the strategy it failed on.
    configured = ds_cfg.get("input_transform") if ds_cfg is not None else None
    strategy = configured or "chat"
    try:
        # pylint: disable=duplicate-code
        module_path, _, last_component = strategy.rpartition(".")
        load_fn = "load"
        if last_component.startswith("load_"):
            load_fn = last_component
            strategy = module_path
        mod = importlib.import_module(
            f".{strategy}", "axolotl.prompt_strategies.messages"
        )
        func = getattr(mod, load_fn)
        # Only pass the optional arguments the loader actually declares.
        accepted = inspect.signature(func).parameters
        load_kwargs = {}
        if "ds_cfg" in accepted:
            load_kwargs["ds_cfg"] = ds_cfg
        if "processor" in accepted:
            load_kwargs["processor"] = processor
        return func(tokenizer, cfg, **load_kwargs)
    except ModuleNotFoundError:
        # NOTE(review): this also hides a missing import raised from inside
        # the strategy module or its loader -- confirm that is intended.
        return None
    except Exception as exc:  # pylint: disable=broad-exception-caught
        LOG.error("Failed to load prompt strategy `%s`: %s", strategy, exc)
        raise
|
||||||
84
src/axolotl/prompt_strategies/messages/chat.py
Normal file
84
src/axolotl/prompt_strategies/messages/chat.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""
|
||||||
|
Chat dataset wrapping strategy for new internal messages representations
|
||||||
|
"""
|
||||||
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
|
from axolotl.core.datasets.chat import TokenizedChatDataset
|
||||||
|
from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder
|
||||||
|
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy):
    """
    Dataset wrapping strategy that turns a raw dataset into a
    ``TokenizedChatDataset`` using the new internal messages representation.
    """

    def __init__(
        self,
        processor,
        message_transform=None,
        formatter=None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """
        :param processor: tokenizer or image processor; handed to the dataset
            as its ``model_transform``
        :param message_transform: handed to the dataset as ``message_transform``
        :param formatter: handed to the dataset as ``formatter``
        :param kwargs: accepted and ignored
        """
        self.dataset = None  # populated by wrap_dataset()
        self.processor = processor
        self.formatter = formatter
        self.message_transform = message_transform

    def wrap_dataset(
        self,
        dataset,
        process_count: Optional[int] = None,
        keep_in_memory: Optional[bool] = False,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """
        Wrap ``dataset`` in a ``TokenizedChatDataset``, remember it on
        ``self.dataset`` and return it.

        :param dataset: the dataset to wrap
        :param process_count: forwarded to ``TokenizedChatDataset``
        :param keep_in_memory: forwarded to ``TokenizedChatDataset``
        :param kwargs: accepted and ignored
        :return: the newly created ``TokenizedChatDataset``
        """
        dataset_options = {
            "message_transform": self.message_transform,
            "model_transform": self.processor,
            "formatter": self.formatter,
            "process_count": process_count,
            "keep_in_memory": keep_in_memory,
        }
        wrapped = TokenizedChatDataset(dataset, **dataset_options)
        self.dataset = wrapped
        return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build a ``ChatMessageDatasetWrappingStrategy`` for a messages dataset.

    :param tokenizer: passed to the strategy as its processor
    :param cfg: global config; only ``chat_template`` is read from it, as a
        fallback when the dataset config does not set one
    :param ds_cfg: dataset config; reads ``field_messages``,
        ``message_field_role``, ``message_field_content``,
        ``message_field_training``, ``chat_template`` and ``train_on_inputs``
    :return: the configured ``ChatMessageDatasetWrappingStrategy``
    """
    ds_cfg = ds_cfg or {}

    # Dataset-config key -> transform-builder keyword. Unset or empty entries
    # are left out of the builder call entirely.
    builder_kwargs = {}
    for cfg_key, builder_key in (
        ("field_messages", "conversations_field"),
        ("message_field_role", "message_field_role"),
        ("message_field_content", "message_field_content"),
        ("message_field_training", "message_field_training"),
    ):
        value = ds_cfg.get(cfg_key)
        if value:
            builder_kwargs[builder_key] = value

    # ``or`` rather than ``.get(key, default)``: a key that is present but
    # None (an unset Optional field) must still fall through to the next
    # source instead of reaching ``.startswith`` below as None.
    chat_template = (
        ds_cfg.get("chat_template") or cfg.get("chat_template") or "chatml"
    )
    # Identity formatter, used when the template matches neither case below.
    format_message = (
        lambda x: x  # noqa E731  # pylint: disable=unnecessary-lambda-assignment
    )
    if chat_template == "chatml":
        from axolotl.core.chat.format.chatml import format_message  # noqa F811
    elif chat_template.startswith("llama3"):
        from axolotl.core.chat.format.llama3x import format_message  # noqa F811

    # NOTE(review): train_on_inputs is read from the dataset config only; the
    # global cfg value is not consulted here -- confirm this is intended.
    message_transform: Callable = chat_message_transform_builder(
        train_on_inputs=ds_cfg.get("train_on_inputs", False),
        **builder_kwargs,
    )
    strategy = ChatMessageDatasetWrappingStrategy(
        tokenizer, message_transform=message_transform, formatter=format_message
    )

    return strategy
|
||||||
@@ -30,6 +30,12 @@ class InvalidDataException(Exception):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetWrappingStrategy(abc.ABC):
|
||||||
|
"""
|
||||||
|
Abstract class for wrapping datasets for Chat Messages
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class PromptTokenizingStrategy(abc.ABC):
|
class PromptTokenizingStrategy(abc.ABC):
|
||||||
"""
|
"""
|
||||||
Abstract class for tokenizing strategies
|
Abstract class for tokenizing strategies
|
||||||
|
|||||||
@@ -102,10 +102,12 @@ class SFTDataset(BaseModel):
|
|||||||
path: Optional[str] = None
|
path: Optional[str] = None
|
||||||
split: Optional[str] = None
|
split: Optional[str] = None
|
||||||
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
||||||
|
input_transform: Optional[str] = None
|
||||||
shards: Optional[int] = None
|
shards: Optional[int] = None
|
||||||
conversation: Optional[str] = None
|
conversation: Optional[str] = None
|
||||||
chat_template: Optional[str] = None
|
chat_template: Optional[str] = None
|
||||||
data_files: Optional[Union[str, List[str]]] = None
|
data_files: Optional[Union[str, List[str]]] = None
|
||||||
|
input_format: Optional[str] = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
ds_type: Optional[str] = None
|
ds_type: Optional[str] = None
|
||||||
train_on_split: Optional[str] = None
|
train_on_split: Optional[str] = None
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from axolotl.prompt_tokenizers import (
|
|||||||
AlpacaMultipleChoicePromptTokenizingStrategy,
|
AlpacaMultipleChoicePromptTokenizingStrategy,
|
||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
AlpacaReflectionPTStrategy,
|
AlpacaReflectionPTStrategy,
|
||||||
|
DatasetWrappingStrategy,
|
||||||
GPTeacherPromptTokenizingStrategy,
|
GPTeacherPromptTokenizingStrategy,
|
||||||
JeopardyPromptTokenizingStrategy,
|
JeopardyPromptTokenizingStrategy,
|
||||||
OpenAssistantPromptTokenizingStrategy,
|
OpenAssistantPromptTokenizingStrategy,
|
||||||
@@ -573,7 +574,7 @@ def get_dataset_wrapper(
|
|||||||
d_base_type,
|
d_base_type,
|
||||||
dataset,
|
dataset,
|
||||||
d_prompt_style=None,
|
d_prompt_style=None,
|
||||||
processor=None,
|
processor=None, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
dataset_wrapper = None
|
dataset_wrapper = None
|
||||||
dataset_prompter = None
|
dataset_prompter = None
|
||||||
@@ -608,15 +609,16 @@ def get_dataset_wrapper(
|
|||||||
)
|
)
|
||||||
elif cfg.skip_prepare_dataset:
|
elif cfg.skip_prepare_dataset:
|
||||||
dataset_wrapper = dataset
|
dataset_wrapper = dataset
|
||||||
elif ds_strategy := load(
|
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
||||||
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
|
if isinstance(ds_strategy, DatasetWrappingStrategy):
|
||||||
):
|
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
|
||||||
dataset_prompter = UnsupportedPrompter()
|
else:
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_prompter = UnsupportedPrompter()
|
||||||
ds_strategy,
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
dataset,
|
ds_strategy,
|
||||||
**ds_kwargs,
|
dataset,
|
||||||
)
|
**ds_kwargs,
|
||||||
|
)
|
||||||
elif d_base_type == "alpaca":
|
elif d_base_type == "alpaca":
|
||||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||||
|
|||||||
0
tests/core/chat/__init__.py
Normal file
0
tests/core/chat/__init__.py
Normal file
0
tests/core/chat/format/__init__.py
Normal file
0
tests/core/chat/format/__init__.py
Normal file
197
tests/core/chat/test_messages.py
Normal file
197
tests/core/chat/test_messages.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""
|
||||||
|
Tests for the chat messages module
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AddedToken, AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.core.chat.format.chatml import format_message
|
||||||
|
from axolotl.core.chat.messages import ChatFormattedChats, Chats
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", name="llama_tokenizer")
def llama_tokenizer_fixture():
    """Session-scoped tokenizer loaded from ``NousResearch/Meta-Llama-3.1-8B``."""
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
    return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", name="chatml_tokenizer")
def llama_tokenizer_w_chatml(llama_tokenizer):
    """
    Return a copy of the llama tokenizer extended with the ChatML tokens.

    ``<|im_end|>` is registered as the EOS special token and ``<|im_start|>``
    is added as a regular added token. The work is done on a deep copy so the
    session-scoped ``llama_tokenizer`` fixture object is not modified.
    """
    import copy  # pylint: disable=import-outside-toplevel

    tokenizer = copy.deepcopy(llama_tokenizer)
    tokenizer.add_special_tokens(
        {
            "eos_token": AddedToken(
                "<|im_end|>", rstrip=False, lstrip=False, normalized=False
            )
        }
    )
    tokenizer.add_tokens(
        [
            AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
        ]
    )

    return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", name="chat_msgs")
def chat_msgs_fixture():
    """
    Sample five-turn conversation: system, user, assistant tool calls,
    tool responses, and a final assistant answer made of three text parts.
    """
    system_turn = {
        "role": "system",
        "content": [
            {"type": "text", "value": "You are a helpful assistant."},
        ],
    }
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "value": "What is today's stock price of Apple?"},
        ],
    }
    tool_call_turn = {
        "role": "assistant",
        "content": [
            {
                "type": "tool_call",
                "value": {
                    "name": "get_date",
                    "arguments": {},
                },
            },
            {
                "type": "tool_call",
                "value": {
                    "name": "get_stock_price",
                    "arguments": {"symbol": "AAPL"},
                },
            },
        ],
        "weight": 1,
    }
    tool_response_turn = {
        "role": "tool",
        "content": [
            {
                "type": "tool_response",
                "value": {
                    "name": "get_date",
                    "content": {"date": "2024-09-09"},
                },
            },
            {
                "type": "tool_response",
                "value": {
                    "name": "get_stock_price",
                    "content": {"symbol": "AAPL", "price": 123.45},
                },
            },
        ],
    }
    final_turn = {
        "role": "assistant",
        "content": [
            {
                # first draft of the answer carries its own weight of 0
                "type": "text",
                "value": "The stock price of Apple is $123.45.\n",
                "weight": 0,
            },
            {
                "type": "text",
                "value": "<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>",
            },
            {
                "type": "text",
                "value": "The stock price of Apple on September 9, 2024 is $123.45.",
            },
        ],
        "weight": 1,
    }

    return {
        "conversation": [
            system_turn,
            user_turn,
            tool_call_turn,
            tool_response_turn,
            final_turn,
        ]
    }
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessagesCase:
    """
    Test cases for the chat messages module
    """

    def test_tool_call_stringify(self, chat_msgs):
        """The second tool call of turn 2 stringifies to its JSON form."""
        chats = Chats(**chat_msgs)
        expected = '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}'
        second_tool_call = chats.conversation[2].content[1].value
        assert expected == str(second_tool_call)

    def test_chatml_formatted_wrapper(self, chat_msgs):
        """The whole conversation renders to the expected ChatML text."""
        formatted_chats = ChatFormattedChats(**chat_msgs, formatter=format_message)
        target_chatml = """<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is today's stock price of Apple?<|im_end|>
<|im_start|>assistant
<tool_call>
{"name": "get_date", "arguments": {}}
</tool_call>
<tool_call>
{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}
</tool_call>
<|im_end|>
<|im_start|>tool
<tool_response>
{"name": "get_date", "content": {"date": "2024-09-09"}}
</tool_response>
<tool_response>
{"name": "get_stock_price", "content": {"symbol": "AAPL", "price": 123.45}}
</tool_response>
<|im_end|>
<|im_start|>assistant
The stock price of Apple is $123.45.
<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>The stock price of Apple on September 9, 2024 is $123.45.<|im_end|>\n"""
        rendered = str(formatted_chats)
        assert target_chatml == rendered

    def test_chatml_formatting_tool_call(self, chat_msgs):
        """Formatting only the tool-call turn yields the expected ChatML."""
        chats = Chats(**chat_msgs)
        target_chatml_turn2 = """<|im_start|>assistant\n<tool_call>\n{"name": "get_date", "arguments": {}}\n</tool_call>\n<tool_call>\n{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}\n</tool_call>\n<|im_end|>\n"""
        formatted_turn = format_message(chats.conversation[2])
        assert target_chatml_turn2 == str(formatted_turn)

    def test_train_labels(self, chatml_tokenizer, chat_msgs):
        """Labels of the tokenized tool-call turn match the expected ids."""
        formatted_chats = ChatFormattedChats(**chat_msgs, formatter=format_message)
        tool_call_turn = formatted_chats.conversation[2]
        tokenized = tool_call_turn.tokenized(chatml_tokenizer)
        # fmt: off
        target_labels = [
            -100, -100, -100,  # role
            27, 14506, 13735, 397, 5018, 609, 794,
            330, 456, 4257, 498, 330, 16774, 794, 4792, 534, 524,
            14506, 13735, 397, 27, 14506, 13735, 397, 5018, 609, 794,
            330, 456, 31641, 9217, 498, 330, 16774, 794, 5324, 19314,
            794, 330, 84016, 43, 96742, 524, 14506, 13735, 397,
            128256,  # <|im_end|>
            -100  # trailing newline
        ]
        # fmt: on
        assert tokenized["labels"] == target_labels

    def test_train_labels_2(self, chatml_tokenizer, chat_msgs):
        """Labels of the final assistant turn match the expected ids."""
        # also test if individual contents are set not to train
        formatted_chats = ChatFormattedChats(**chat_msgs, formatter=format_message)
        final_turn = formatted_chats.conversation[4]
        tokenized = final_turn.tokenized(chatml_tokenizer)
        # fmt: off
        target_labels = [
            -100, -100, -100,  # role
            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # initial response
            27, 78098, 16761, 4113, 3319, 4691, 369, 3432, 596, 5708, 3430,
            315, 8325, 13, 1115, 24897, 814, 1101, 4934, 279, 2457,
            5343, 304, 279, 2077, 4005, 78098, 16761, 5708, 3430, 315,
            8325, 389, 6250, 220, 24, 11, 220, 2366, 19, 374, 400,
            4513, 13, 1774, 13,
            128256,  # <|im_end|>
            -100,  # trailing newline
        ]
        # fmt: on
        assert tokenized["labels"] == target_labels
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): TestMessagesCase does not subclass unittest.TestCase and takes
# its inputs from pytest fixtures, so unittest.main() has no test cases to
# collect from this module; run this file with pytest instead.
if __name__ == "__main__":
    unittest.main()
|
||||||
0
tests/prompt_strategies/messages/__init__.py
Normal file
0
tests/prompt_strategies/messages/__init__.py
Normal file
62
tests/prompt_strategies/messages/test_chat.py
Normal file
62
tests/prompt_strategies/messages/test_chat.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""
|
||||||
|
tests for chat_template prompt strategy
|
||||||
|
"""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
import logging
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.messages.chat import load
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessagesChatLlama3:
    """
    Tests the messages chat strategy with the llama3 chat template on an
    assistant-style dataset.
    """

    def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
        """The first wrapped sample tokenizes to the expected llama-3 ids."""
        LOG.info("Loading llama-3 tokenizer with assistant dataset")
        global_cfg = DictDefault(
            {
                "train_on_inputs": False,
                "sequence_len": 512,
            }
        )
        dataset_cfg = DictDefault(
            {
                "chat_template": "llama3",
                "message_field_role": "role",
                "message_field_content": "content",
                "field_messages": "messages",
            }
        )
        wrapping_strategy = load(llama3_tokenizer, global_cfg, dataset_cfg)
        wrapped = wrapping_strategy.wrap_dataset(assistant_dataset)
        input_ids = wrapped[0]["input_ids"]
        # fmt: off
        expected_input_ids = [
            128000,  # bos
            128006, 882, 128007,  # user header
            271, 15339, 128009,  # user prompt eot
            128006, 78191, 128007,  # assistant header
            271, 15339, 128009,  # assistant response eot
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        # fmt: on
        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): TestMessagesChatLlama3 does not subclass unittest.TestCase and
# its test takes fixture arguments, so unittest.main() has no test cases to
# collect from this module; run this file with pytest instead.
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user