From e73b8dff8d5fcfb02371916cbebc1350a3a1a9c9 Mon Sep 17 00:00:00 2001 From: Thomas Cleberg <84520378+thomascleberg@users.noreply.github.com> Date: Fri, 11 Oct 2024 12:32:50 -0500 Subject: [PATCH 1/9] Add Support for `revision` Dataset Parameter to specify reading from Huggingface Dataset Revision (#1912) * Add support for `revision` dataset parameter * only use revision on hf hub backed datasets * use revision tied to head * set download to use revision * feat: add config to model validator class * feat: add revision config to RL and tests for it --------- Co-authored-by: Wing Lian Co-authored-by: NanoCode012 --- docs/config.qmd | 1 + .../config/models/input/v0_4_1/__init__.py | 3 + src/axolotl/utils/data/rl.py | 1 + src/axolotl/utils/data/sft.py | 6 +- tests/test_datasets.py | 138 ++++++++++++++++++ 5 files changed, 148 insertions(+), 1 deletion(-) diff --git a/docs/config.qmd b/docs/config.qmd index 99a69a097..8329f3553 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -90,6 +90,7 @@ datasets: shards: # Optional[int] number of shards to split data into name: # Optional[str] name of dataset configuration to load train_on_split: train # Optional[str] name of dataset split to load from + revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets. # Optional[str] fastchat conversation type, only used with type: sharegpt conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 47796add6..1c33b5907 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -125,6 +125,7 @@ class SFTDataset(BaseModel): drop_system_message: Optional[bool] = None trust_remote_code: Optional[bool] = False + revision: Optional[str] = None class UserDefinedDPOType(BaseModel): @@ -146,6 +147,7 @@ class DPODataset(BaseModel): split: Optional[str] = None type: Optional[Union[UserDefinedDPOType, str]] = None data_files: Optional[List[str]] = None + revision: Optional[str] = None class UserDefinedKTOType(BaseModel): @@ -167,6 +169,7 @@ class KTODataset(BaseModel): type: Optional[Union[UserDefinedKTOType, str]] = None data_files: Optional[List[str]] = None trust_remote_code: Optional[bool] = False + revision: Optional[str] = None class RLType(str, Enum): diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index d0324e1eb..35bd5fcbb 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -90,6 +90,7 @@ def load_prepare_dpo_datasets(cfg): ds = load_dataset( # pylint: disable=invalid-name ds_cfg["path"], split=ds_cfg["split"], + revision=ds_cfg.get("revision", None), ) split_datasets.insert(i, ds) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 7d6922cbf..39eb2c4e0 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -242,6 +242,7 @@ def load_tokenized_prepared_datasets( name=config_dataset.name, streaming=True, token=use_auth_token, + revision=config_dataset.revision, ) ds_from_hub = True except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): @@ -346,6 +347,7 @@ def load_tokenized_prepared_datasets( streaming=False, data_files=config_dataset.data_files, token=use_auth_token, + revision=config_dataset.revision, **load_ds_kwargs, ) elif ds_from_cloud and remote_file_system: @@ -380,6 +382,7 @@ def load_tokenized_prepared_datasets( repo_id=config_dataset.path, repo_type="dataset", filename=config_dataset.data_files, + revision=config_dataset.revision, ) elif isinstance(config_dataset.data_files, list): fp = [] @@ -389,6 +392,7 @@ def load_tokenized_prepared_datasets( repo_id=config_dataset.path, repo_type="dataset", filename=file, + revision=config_dataset.revision, ) ) else: @@ -433,8 +437,8 @@ def load_tokenized_prepared_datasets( config_dataset=config_dataset, tokenizer=tokenizer, cfg=cfg, - dataset=ds, d_base_type=d_base_type, + dataset=ds, d_prompt_style=d_prompt_style, processor=processor, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index a274b7b89..f8b463a03 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download from transformers import AutoTokenizer from axolotl.utils.data import load_tokenized_prepared_datasets +from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.dict import DictDefault @@ -267,6 +268,143 @@ class TestDatasetPreparation(unittest.TestCase): assert "attention_mask" in dataset.features assert "labels" in dataset.features + def test_load_hub_with_dpo(self): + """Verify that processing dpo data from the hub works""" + + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "rl": "dpo", + "chat_template": "llama3", + "datasets": [ + { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "chat_template": "llama3", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, + } + ], + } + ) + + train_dataset, _ = load_prepare_dpo_datasets(cfg) + + assert len(train_dataset) == 1800 + assert "conversation" in train_dataset.features + + def test_load_hub_with_revision(self): + """Verify that processing data from the hub works with a specific revision""" + with tempfile.TemporaryDirectory() as tmp_dir: + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + "revision": "d05c1cb", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + def test_load_hub_with_revision_with_dpo(self): + """Verify that processing dpo data from the hub works with a specific revision""" + + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "rl": "dpo", + "chat_template": "llama3", + "datasets": [ + { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "chat_template": "llama3", + "revision": "ea82cff", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, + } + ], + } + ) + + train_dataset, _ = load_prepare_dpo_datasets(cfg) + + assert len(train_dataset) == 1800 + assert "conversation" in train_dataset.features + + def test_load_local_hub_with_revision(self): + """Verify that a local copy of a hub dataset can be loaded with a specific revision""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_ds_path = Path("mhenrichsen/alpaca_2k_test") + tmp_ds_path.mkdir(parents=True, exist_ok=True) + snapshot_download( + repo_id="mhenrichsen/alpaca_2k_test", + repo_type="dataset", + local_dir=tmp_ds_path, + revision="d05c1cb", + ) + + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "ds_type": "parquet", + "type": "alpaca", + "data_files": [ + "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet", + ], + "revision": "d05c1cb", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + shutil.rmtree(tmp_ds_path) + if __name__ == "__main__": unittest.main() From 922db77521f37d32ba7a5ab72b56904fed3bcb5c Mon Sep 17 00:00:00 2001 From: Adam Hazell <34248583+awhazell@users.noreply.github.com> Date: Fri, 11 Oct 2024 18:33:06 +0100 Subject: [PATCH 2/9] Add MLFlow run name option in config (#1961) Co-authored-by: Adam Hazell --- docs/config.qmd | 1 + src/axolotl/core/trainer_builder.py | 9 ++++++--- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/config.qmd b/docs/config.qmd index 8329f3553..b6c0cb852 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -266,6 +266,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step # mlflow configuration if you're using it mlflow_tracking_uri: # URI to mlflow mlflow_experiment_name: # Your experiment name +mlflow_run_name: # Your run name hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry # Comet configuration if you're using it diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b1ee519dc..9c12b6141 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1445,9 +1445,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): report_to.append("comet_ml") training_arguments_kwargs["report_to"] = report_to - training_arguments_kwargs["run_name"] = ( - self.cfg.wandb_name if self.cfg.use_wandb else None - ) + if self.cfg.use_wandb: + training_arguments_kwargs["run_name"] = self.cfg.wandb_name + elif self.cfg.use_mlflow: + training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name + else: + training_arguments_kwargs["run_name"] = None training_arguments_kwargs["optim"] = ( self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 1c33b5907..1a269b798 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -447,6 +447,7 @@ class MLFlowConfig(BaseModel): use_mlflow: Optional[bool] = None mlflow_tracking_uri: Optional[str] = None mlflow_experiment_name: Optional[str] = None + mlflow_run_name: Optional[str] = None hf_mlflow_log_artifacts: Optional[bool] = None From 76883851d233d3734c19b1979ede7020059ea37d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 11 Oct 2024 13:33:20 -0400 Subject: [PATCH 3/9] add warning that sharegpt will be deprecated (#1957) * add warning that sharegpt will be deprecated * add helper script for chat_templates and document deprecation * Update src/axolotl/prompt_strategies/sharegpt.py Co-authored-by: NanoCode012 --------- Co-authored-by: NanoCode012 --- README.md | 2 +- scripts/chat_datasets.py | 60 +++++++++++++++++++++++ src/axolotl/prompt_strategies/sharegpt.py | 3 ++ 3 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 scripts/chat_datasets.py diff --git a/README.md b/README.md index f6f4e4e80..4ce7a351b 100644 --- a/README.md +++ b/README.md @@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod - typescript type: ... # unimplemented custom format - # fastchat conversation + # fastchat conversation (deprecation soon, use chat_template) # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py - path: ... type: sharegpt diff --git a/scripts/chat_datasets.py b/scripts/chat_datasets.py new file mode 100644 index 000000000..5eb5bde1e --- /dev/null +++ b/scripts/chat_datasets.py @@ -0,0 +1,60 @@ +""" +helper script to parse chat datasets into a usable yaml +""" +import click +import yaml +from datasets import load_dataset + + +@click.command() +@click.argument("dataset", type=str) +@click.option("--split", type=str, default="train") +def parse_dataset(dataset=None, split="train"): + ds_cfg = {} + ds_cfg["path"] = dataset + ds_cfg["split"] = split + ds_cfg["type"] = "chat_template" + ds_cfg["chat_template"] = "<<>>" + + dataset = load_dataset(dataset, split=split) + features = dataset.features + feature_keys = features.keys() + field_messages = None + for key in ["conversation", "conversations", "messages"]: + if key in feature_keys: + field_messages = key + break + if not field_messages: + raise ValueError( + f'No conversation field found in dataset: {", ".join(feature_keys)}' + ) + ds_cfg["field_messages"] = field_messages + + message_fields = features["conversations"][0].keys() + message_field_role = None + for key in ["from", "role"]: + if key in message_fields: + message_field_role = key + break + if not message_field_role: + raise ValueError( + f'No role field found in messages: {", ".join(message_fields)}' + ) + ds_cfg["message_field_role"] = message_field_role + + message_field_content = None + for key in ["content", "text", "value"]: + if key in message_fields: + message_field_content = key + break + if not message_field_content: + raise ValueError( + f'No content field found in messages: {", ".join(message_fields)}' + ) + ds_cfg["message_field_content"] = message_field_content + + print(yaml.dump({"datasets": [ds_cfg]})) + + +if __name__ == "__main__": + parse_dataset() diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index 321f19554..4565c35d5 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -61,6 +61,9 @@ def build_loader( default_conversation: Optional[str] = None, ): def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + LOG.warning( + "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.", + ) conversation = ( ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg From df359c8a6e14ecdd2e1eb0049bd8143c32421952 Mon Sep 17 00:00:00 2001 From: Afrizal Hasbi Azizy Date: Sat, 12 Oct 2024 00:34:13 +0700 Subject: [PATCH 4/9] Handle image input as string paths for MMLMs (#1958) * Update mm_chat.py Handle string image (paths) * chore: lint --------- Co-authored-by: Wing Lian --- src/axolotl/utils/collators/mm_chat.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index f49e97f37..b9b67f875 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -4,6 +4,7 @@ Collators for multi-modal chat messages and packing from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union +from PIL import Image from transformers import PreTrainedTokenizerBase, ProcessorMixin from transformers.data.data_collator import DataCollatorMixin from transformers.utils import PaddingStrategy @@ -52,7 +53,12 @@ class MultiModalChatDataCollator(DataCollatorMixin): ) for example in examples ] - images = [example["images"] for example in examples] + images = [ + Image.open(example["images"]) + if isinstance(example["images"], str) + else example["images"] + for example in examples + ] if max_images > 0: images = [img_batch[:max_images] for img_batch in images] From 09bf1ceacc67b46d6bc5abb8cef2b47c9dd84b8c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 12 Oct 2024 18:19:48 -0400 Subject: [PATCH 5/9] update hf deps (#1964) * update hf deps * remove deprecated set_caching_enabled --- requirements.txt | 12 ++++++------ src/axolotl/utils/trainer.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4323c76ce..2dd3517a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 -peft==0.13.0 -transformers==4.45.1 -tokenizers>=0.19.1 -bitsandbytes==0.44.0 -accelerate==0.34.2 -datasets==2.21.0 +peft==0.13.2 +transformers==4.45.2 +tokenizers>=0.20.1 +bitsandbytes==0.44.1 +accelerate==1.0.0 +datasets==3.0.1 deepspeed==0.14.4 pydantic==2.6.3 addict diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 17276dd8e..30b40925f 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -11,7 +11,7 @@ import numpy as np import torch import torch.cuda from accelerate.logging import get_logger -from datasets import set_caching_enabled +from datasets import disable_caching, enable_caching from torch.utils.data import DataLoader, RandomSampler from transformers.utils import is_torch_bf16_gpu_available @@ -87,10 +87,10 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True): @contextmanager def disable_datasets_caching(): try: - set_caching_enabled(False) + disable_caching() yield finally: - set_caching_enabled(True) + enable_caching() def add_position_ids(sample): From d20b48a61e8dff5565303166fde5303c811e5491 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 12 Oct 2024 20:53:48 -0400 Subject: [PATCH 6/9] only install torchao for torch versions >= 2.4.0 (#1963) --- requirements.txt | 2 ++ setup.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2dd3517a7..37ee1e42c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,3 +52,5 @@ lm_eval==0.4.4 langdetect==1.0.9 immutabledict==4.2.0 antlr4-python3-runtime==4.13.2 + +torchao==0.5.0 diff --git a/setup.py b/setup.py index e939bc37e..7d9568dbf 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ def parse_requirements(): try: xformers_version = [req for req in _install_requires if "xformers" in req][0] + torchao_version = [req for req in _install_requires if "torchao" in req][0] if "Darwin" in platform.system(): # don't install xformers on MacOS _install_requires.pop(_install_requires.index(xformers_version)) @@ -53,7 +54,8 @@ def parse_requirements(): if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27") - if (major, minor) >= (2, 3): + elif (major, minor) >= (2, 3): + _install_requires.pop(_install_requires.index(torchao_version)) if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.26.post1") @@ -61,9 +63,11 @@ def parse_requirements(): _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27") elif (major, minor) >= (2, 2): + _install_requires.pop(_install_requires.index(torchao_version)) _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.25.post1") else: + _install_requires.pop(_install_requires.index(torchao_version)) _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.23.post1") From 31591bd94cf8fd3e18fc8949385cf405b1ff0dda Mon Sep 17 00:00:00 2001 From: pandora <128635000+pandora-s-git@users.noreply.github.com> Date: Sun, 13 Oct 2024 03:40:39 +0200 Subject: [PATCH 7/9] Fixing Validation - Mistral Templates (#1962) --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 1a269b798..af1570db6 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -187,7 +187,9 @@ class ChatTemplate(str, Enum): alpaca = "alpaca" # pylint: disable=invalid-name chatml = "chatml" # pylint: disable=invalid-name - inst = "inst" # pylint: disable=invalid-name + mistral_v1 = "mistral_v1" # pylint: disable=invalid-name + mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name + mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name gemma = "gemma" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name From ac128b7b1dde6e6f0ca9a06697cad6fa31c9d5b0 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 13 Oct 2024 08:41:13 +0700 Subject: [PATCH 8/9] fix: update eval causal lm metrics to add perplexity (#1951) [skip ci] --- docs/config.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/config.qmd b/docs/config.qmd index b6c0cb852..703d58775 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -315,7 +315,7 @@ max_steps: eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 -eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf] +eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"] loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3) From 1834cdc3645c003e3db02346912cab19a1eb5ca3 Mon Sep 17 00:00:00 2001 From: Vincent Haines Date: Sat, 12 Oct 2024 21:41:43 -0400 Subject: [PATCH 9/9] Add support for qwen 2.5 chat template (#1934) --- src/axolotl/utils/chat_templates.py | 1 + src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 9e1e6ca32..2443f56f9 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -17,6 +17,7 @@ CHAT_TEMPLATES = { "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', + "qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", } diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index af1570db6..40f4a36ab 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -198,6 +198,7 @@ class ChatTemplate(str, Enum): phi_35 = "phi_35" # pylint: disable=invalid-name deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name jamba = "jamba" # pylint: disable=invalid-name + qwen_25 = "qwen_25" # pylint: disable=invalid-name class LoftQConfig(BaseModel):