Compare commits

...

15 Commits

Author SHA1 Message Date
NanoCode012
d47093fcdd fix: simplify fn same as sft and pass model to plugin 2025-07-08 22:29:56 +07:00
Wing Lian
d68cc1e8ab densemixer plugin integration (#2868)
* densemixer plugin integration

* update readme with usage docs

* automatically find new integrations that aren't explicitly defined

* make sure to import os
2025-07-07 17:05:19 -04:00
github-actions[bot]
21f1bf4805 chore: update pre-commit hooks (#2870) [skip ci]
* chore: update pre-commit hooks

* don't bandit huggingface hub downloads without revision

---------

Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-07-07 15:26:15 -04:00
Wing Lian
de2c5ba103 mark flaky geglu tests and add torch seed (#2876) [skip ci]
* mark flaky geglu tests and add torch seed

* restore accidental removal of seed
2025-07-07 15:24:16 -04:00
Wing Lian
9c0d7ee761 TiledMLP support (#2865) 2025-07-07 15:23:49 -04:00
NanoCode012
22d4a838dc feat(doc): add vllm and fa2 incompat error to faq (#2877) 2025-07-07 14:13:37 -04:00
Wing Lian
a108e5db56 use latest version of cce fork for SP fix (#2871) [skip ci]
* use latest version of cce fork for SP fix

* latest sha to handle older transformers
2025-07-07 13:05:11 -04:00
Wing Lian
faff0cff41 manage jinja templates as nicely formatted files (#2795)
* manage jinja templates as nicely formatted files

* chore: lint

* use path for templates relative to the module

* fix template reformatting

* handle newlines in llama3 template

* fix gemma3 jinja

* fix templates

* support for passing jinja template file in yaml

* handle file loading of jinja template outside of validation

* fix typing and typo
2025-07-07 10:11:48 -04:00
Wing Lian
759cefb741 setup defaults for dataloader to ensure GPU is kept busy (#2632) [skip ci] 2025-07-07 10:10:58 -04:00
Wing Lian
69cd49a7aa update transformers to 4.53.1 (#2844) [skip ci]
* update transformers to 4.53.0

* remove attention_mask from signature columns if using packing

* remove attention_mask column from dataloader

* update signature of flash attn forward for ring attn patch

* fix FSDP

* patch ring-flash-attn with upstream signature fix

* fix patch indentation level

* fix the patch

* add batch flattening smoke test with loss check that works in older transformers

* fix patch

* don't drop attention mask for flex

* more fixes

* patch create_causal_mask for packing w flex

* global torch manual_seed fixture

* tweak loss checks

* fix patch and use single batch for flex

* don't need to reload

* fix causal mask patch

* use transformers patch release

* make sure env var is string

* make sure to drop attention mask for flex w packing for latest transformers patch release

* tweak loss

* guard on signature columns before removing attention mask

* bump loss

* set remove isn't chainable

* skip slow mistral test in 2.5.1
2025-07-07 09:35:22 -04:00
NanoCode012
5a961ecadf Fix: do not call preprocess in multimodal or pretraining case (#2861)
* fix: let users know to not call preprocess for vision mode

* fix: improve ux for pretraining dataset and skip prepare ds

* feat: add info to doc

* Update src/axolotl/cli/preprocess.py following comment

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-07-06 21:55:33 -04:00
Wing Lian
b37ddf9778 don't use tokenizer parallelism when using packing (#2862) [skip ci] 2025-07-06 21:55:09 -04:00
Wing Lian
bf38e507fb respect shuffle_merged_datasets for single dataset too (#2866) [skip ci]
* respect shuffle_merged_datasets for single dataset too

* update inline comment for behavior

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-07-06 21:20:41 -04:00
Wing Lian
a5946ff1f0 build fa2 from source for base image with torch2.6 and cu124 (#2867) 2025-07-05 09:21:18 -04:00
Wing Lian
70ca1b2291 fix nightlies to use correct cache (#2848) [skip ci]
* fix nightlies to use correct cache

* fix for handling None for bf16
2025-07-03 12:21:39 -04:00
80 changed files with 2273 additions and 321 deletions

View File

@@ -1,3 +1,3 @@
[bandit]
exclude = tests
skips = B101
skips = B101,B615

View File

@@ -5,11 +5,13 @@ on:
branches:
- "main"
paths:
- 'Dockerfile-base'
- 'docker/Dockerfile-base'
- 'docker/Dockerfile-uv-base'
- '.github/workflows/base.yml'
pull_request:
paths:
- 'Dockerfile-base'
- 'docker/Dockerfile-base'
- 'docker/Dockerfile-uv-base'
- '.github/workflows/base.yml'
workflow_dispatch:

View File

@@ -18,96 +18,9 @@ jobs:
env:
SKIP: no-commit-to-branch
preload-cache:
name: Preload HF cache
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0"]
timeout-minutes: 20
env:
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v tests/conftest.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 2
@@ -120,14 +33,11 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -168,10 +78,6 @@ jobs:
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
@@ -193,15 +99,8 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1

View File

@@ -36,7 +36,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.8.5
rev: 1.8.6
hooks:
- id: bandit
args: [

View File

@@ -2,4 +2,5 @@ include requirements.txt
include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
include src/axolotl/utils/chat_templates/templates/*.jinja
recursive-include axolotl *.py

View File

@@ -9,6 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_PROCESSES="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev

View File

@@ -37,3 +37,7 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
fi

View File

@@ -7,6 +7,7 @@ toc-depth: 3
```{python}
#| echo: false
import os
import re
def process_readme(integration_name):
@@ -53,6 +54,24 @@ sections = [
("LLMCompressor", "llm_compressor")
]
for folder_name in os.listdir("../src/axolotl/integrations/"):
if folder_name in [path for name, path in sections]:
# skip if already in sections
continue
if os.path.exists(f"../src/axolotl/integrations/{folder_name}/README.md"):
# grab the first heading in README.md as the section name
with open(f"../src/axolotl/integrations/{folder_name}/README.md", "r") as f:
txt = f.read()
matches = re.search(r'^# (.*)\n?', txt, flags=re.MULTILINE)
if matches:
name = matches.group(1)
else:
continue
sections.append((name, folder_name))
# sort sections by name
sections = sorted(sections, key=lambda x: x[0])
for section_name, folder_name in sections:
print(print_section(section_name, folder_name))
```
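The auto-discovery above keys each newly found integration off the first level-1 heading in its README. A minimal, stdlib-only sketch of that extraction (the folder names and README contents here are hypothetical):

```python
import re

def first_heading(readme_text):
    """Return the text of the first level-1 markdown heading, if any."""
    match = re.search(r"^# (.*)\n?", readme_text, flags=re.MULTILINE)
    return match.group(1) if match else None

# Hypothetical README contents keyed by integration folder name
readmes = {
    "densemixer": "# DenseMixer\nSee upstream repo.",
    "no_heading": "plain text, no markdown heading",
}
sections = [("LLMCompressor", "llm_compressor")]
for folder, txt in readmes.items():
    name = first_heading(txt)
    if name is None:
        continue  # skip READMEs without a top-level heading
    sections.append((name, folder))
# sort sections by display name, as the docs script does
sections = sorted(sections, key=lambda pair: pair[0])
```

Folders whose README lacks a `# ` heading are silently skipped, matching the `continue` branch in the script above.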

View File

@@ -51,6 +51,18 @@ description: Frequently asked questions
> pad_token: "..."
> ```
**Q: `IterableDataset error` or `KeyError: 'input_ids'` when using `preprocess` CLI**
> A: This happens when the `preprocess` CLI is used with `pretraining_dataset:` or `skip_prepare_dataset: true`, respectively. Use the `axolotl train` CLI directly instead, as these datasets are prepared on demand.
**Q: vLLM is not working with Axolotl**
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

View File

@@ -13,7 +13,7 @@ packaging==23.2
huggingface_hub==0.32.2
peft==0.15.2
transformers==4.52.4
transformers==4.53.1
tokenizers>=0.21.1
accelerate==1.8.1
datasets==3.6.0

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"'
)

View File

@@ -114,7 +114,7 @@ extras_require = {
"flash-attn": ["flash-attn==2.8.0.post2"],
"ring-flash-attn": [
"flash-attn==2.8.0.post2",
"ring-flash-attn>=0.1.4",
"ring-flash-attn>=0.1.5",
"yunchang==0.6.0",
],
"deepspeed": [

View File

@@ -35,6 +35,12 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get("key"):
raise ValueError(
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
)
if not cfg.dataset_prepared_path:
msg = (
Fore.RED
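The new guard loops over the two on-demand dataset keys and bails out before any preprocessing work. With the key interpolated correctly, the pattern looks like this (a sketch; a plain dict stands in for axolotl's `DictDefault`):

```python
def check_preprocess_not_needed(cfg):
    """Raise if the config marks its dataset as prepared on demand."""
    for key in ["skip_prepare_dataset", "pretraining_dataset"]:
        if cfg.get(key):  # interpolate the key -- cfg.get("key") would never match
            raise ValueError(
                f"You have set `{key}:`. `preprocess` is not needed. "
                "Run the `axolotl train` CLI directly instead."
            )

check_preprocess_not_needed({"datasets": ["alpaca"]})  # no-op for normal configs
```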

View File


@@ -0,0 +1,162 @@
"""
monkeypatch for flex + packing
"""
import sys
from typing import Callable, Optional, Union
import torch
from torch.nn.attention.flex_attention import BlockMask
from transformers import Cache, PretrainedConfig
from transformers.masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
_preprocess_mask_arguments,
and_masks,
causal_mask_function,
or_masks,
)
from transformers.utils import is_torch_greater_or_equal
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
def create_causal_mask(
config: PretrainedConfig,
input_embeds: torch.Tensor,
attention_mask: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
"""
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
to what is needed in the `modeling_xxx.py` files).
Args:
config (`PretrainedConfig`):
The model config.
input_embeds (`torch.Tensor`):
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
batch size, query length and dtype.
attention_mask (`torch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
It can also be an already prepared 4D mask, in which case it is returned as-is.
cache_position (`torch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the causal mask function (by doing the union of both). This is
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
and_mask_function (`Callable`, optional):
An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
"""
# If we have an HybridCache structure, here we want to create the mask for the full layers
if (
past_key_values
and hasattr(past_key_values, "is_sliding")
and False in past_key_values.is_sliding
):
layer_idx = past_key_values.is_sliding.index(False)
else:
layer_idx = 0
original_attention_mask = (
None
if attention_mask is None
else attention_mask.clone().to(cache_position.device)
)
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
)
if early_exit:
return attention_mask
batch_size, total_seq_len = cache_position.shape
key_length = total_seq_len
document_ids = torch.nn.functional.pad(
original_attention_mask, value=0, pad=(0, key_length)
)
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
if attention_mask is not None:
def causal_doc_mask_mod(
batch_idx, head_idx, q_idx, kv_idx
): # pylint: disable=unused-argument
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask_ = q_idx >= kv_idx # not valid when decoding
document_mask = (
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
)
final_mask = causal_mask_ & document_mask
return final_mask
mask_factory_function = causal_doc_mask_mod
else:
mask_factory_function = causal_mask_function
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
config._attn_implementation # pylint: disable=protected-access
]
# Do not allow skip if we are compiling (this is to match BC)
allow_is_causal_skip = (
not past_key_values.is_compileable if past_key_values is not None else True
)
# Allow slight deviations from causal mask
if or_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError(
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
)
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_causal_skip = False
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError(
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
)
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_causal_skip = False
# We now create the mask
causal_mask = mask_interface(
batch_size=batch_size,
cache_position=cache_position,
kv_length=kv_length,
kv_offset=kv_offset,
mask_function=mask_factory_function,
attention_mask=attention_mask,
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
)
return causal_mask
def patch_create_causal_mask(model_type):
import transformers.masking_utils
transformers.masking_utils.create_causal_mask = create_causal_mask
if model_type:
try:
import importlib
# Dynamically import the model's modeling module (importlib.import_module
# returns the leaf module, whereas __import__ returns the top-level package)
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = importlib.import_module(module_path)
module.create_causal_mask = create_causal_mask
del sys.modules[module_path]
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
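The core of `causal_doc_mask_mod` above is an AND of a causal constraint and a same-document constraint derived from the padded attention mask. Stripped of the flex-attention machinery, the logic reduces to this pure-Python sketch (the document ids are hypothetical):

```python
def block_causal_mask(document_ids):
    """Position q may attend to kv iff kv is not in the future and both
    tokens belong to the same packed document."""
    n = len(document_ids)
    return [
        [kv <= q and document_ids[q] == document_ids[kv] for kv in range(n)]
        for q in range(n)
    ]

# Two documents packed into one sequence: tokens 0-2 are doc 0, tokens 3-4 are doc 1
mask = block_causal_mask([0, 0, 0, 1, 1])
```

The resulting boolean grid is block-diagonal within each document and strictly lower-triangular-inclusive overall, which is what keeps packed sequences from attending across document boundaries.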

View File

@@ -219,7 +219,9 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
bf16 = self.cfg.bf16 or self.cfg.bfloat16
bf16 = bf16 if bf16 is not None else False
training_args_kwargs["bf16"] = bf16
def _configure_scheduler(self, training_args_kwargs: dict):
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
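The bf16 change above matters because chaining `or` over two unset flags still yields `None`, which is not a valid value for a boolean `TrainingArguments` field. A minimal sketch of the coercion (standalone function names are assumptions):

```python
def resolve_bf16(bf16_flag, bfloat16_flag):
    """Collapse two optional config flags into a strict bool."""
    value = bf16_flag or bfloat16_flag  # None or None -> None, not False
    return value if value is not None else False
```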

View File

@@ -245,10 +245,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
self.cfg.flash_attention
or self.cfg.xformers_attention
or self.cfg.flex_attention
)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.multipack_real_batches
if self.cfg.multipack_real_batches is not None
else not self.cfg.flash_attention
else not (
self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.xformers_attention
)
)
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing

View File

@@ -6,6 +6,7 @@ from pathlib import Path
from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
AxolotlCPOTrainer,
AxolotlDPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
)
@@ -36,33 +37,23 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def _get_trainer_cls(self, trainer_kwargs: dict):
"""
Returns trainer_cls and trainer_cls_args
"""
def _get_trainer_cls(self):
"""Returns trainer_cls"""
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
trainer_cls_args = [] # type: ignore
if trainer_cls is not None:
return trainer_cls, trainer_cls_args
return trainer_cls
trainer_cls = None
trainer_cls_args = [self.model]
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args.append(self.model_ref)
trainer_cls = AxolotlDPOTrainer
elif self.cfg.rl is RLType.ORPO:
trainer_cls = AxolotlORPOTrainer
elif self.cfg.rl is RLType.KTO:
@@ -72,7 +63,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
return trainer_cls, trainer_cls_args
return trainer_cls
def _build_training_arguments(self, total_num_steps):
"""
@@ -182,7 +173,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.precompute_ref_log_probs
)
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
trainer_cls = self._get_trainer_cls()
trainer_cls_args = [self.model]
if self.cfg.rl is RLType.GRPO:
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
if self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls_args.append(self.model_ref)
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters:
@@ -190,9 +189,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
trainer_kwargs["processing_class"] = self.tokenizer
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
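The refactor above splits class selection from argument assembly: `_get_trainer_cls` now returns only the class, and the caller appends RL-specific positional args afterward. The shape of that control flow, with stand-in classes (the names here are hypothetical stubs, not the real trainers):

```python
class AxolotlDPOTrainerStub:
    """Stand-in for AxolotlDPOTrainer."""

class AxolotlORPOTrainerStub:
    """Stand-in for AxolotlORPOTrainer."""

def get_trainer_cls(rl):
    """Return only the trainer class; the caller assembles the args."""
    if rl in ("dpo", "ipo"):
        return AxolotlDPOTrainerStub
    if rl == "orpo":
        return AxolotlORPOTrainerStub
    raise ValueError(f"Unsupported RL: {rl}")

model, model_ref = object(), object()
trainer_cls = get_trainer_cls("dpo")
trainer_cls_args = [model]
if trainer_cls is AxolotlDPOTrainerStub:  # DPO/IPO also take the reference model
    trainer_cls_args.append(model_ref)
```

Keeping argument assembly at the call site means a plugin-provided trainer class no longer needs to agree with the built-in trainers on positional arguments.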

View File

@@ -27,6 +27,7 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import (
CheckpointSaveMixin,
OptimizerMixin,
PackingMixin,
RngLoaderMixin,
SchedulerMixin,
)
@@ -42,7 +43,12 @@ LOG = get_logger(__name__)
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
PackingMixin,
SchedulerMixin,
OptimizerMixin,
RngLoaderMixin,
CheckpointSaveMixin,
Trainer,
):
"""Extend the base Trainer for axolotl helpers"""
@@ -206,6 +212,14 @@ class AxolotlTrainer(
if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if (
dataset.column_names
and "position_ids" in dataset.column_names
and "attention_mask" in dataset.column_names
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
dataset = dataset.remove_columns(["attention_mask"])
if isinstance(dataset, datasets.Dataset):
if is_training:

View File

@@ -5,5 +5,6 @@
from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin

View File

@@ -0,0 +1,20 @@
"""Trainer mixin to support packing"""
from transformers import Trainer
class PackingMixin(Trainer):
"""
Trainer mixin to support packing
"""
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
if (
self._signature_columns
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
set_sig_columns = set(self._signature_columns)
set_sig_columns.remove("attention_mask")
self._signature_columns = list(set_sig_columns)
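As the commit note "set remove isn't chainable" hints, `set.remove` mutates in place and returns `None`, so the mixin must materialize the set, remove the column, then convert back. A sketch of the same column-dropping step in isolation:

```python
def drop_attention_mask(signature_columns):
    """Drop 'attention_mask' from the trainer's signature columns."""
    set_sig_columns = set(signature_columns)
    set_sig_columns.remove("attention_mask")  # mutates in place, returns None
    # NOT: set(signature_columns).remove("attention_mask") -- that expression is None
    return list(set_sig_columns)

cols = drop_attention_mask(["input_ids", "attention_mask", "labels"])
```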

View File

@@ -42,6 +42,10 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "The multiprocessing start method to use."},
)
sample_packing_drop_attention_mask: bool = field(
default=False,
metadata={"help": "Drop attention mask from inputs when using packing."},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"
```
## Usage

View File

@@ -32,7 +32,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@7f6afce"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"`'
)

View File

@@ -0,0 +1,12 @@
# DenseMixer
See [DenseMixer](https://github.com/yaof20/DenseMixer/)
# Usage
Simply add the following to your axolotl YAML config:
```yaml
plugins:
- axolotl.integrations.densemixer.DenseMixerPlugin
```

View File

@@ -0,0 +1,5 @@
"""Integration entry point for the DenseMixer plugin."""
from .plugin import DenseMixerPlugin
__all__ = ["DenseMixerPlugin"]

View File

@@ -0,0 +1,11 @@
"""Pydantic models for DenseMixer plugin"""
from pydantic import BaseModel
class DenseMixerArgs(BaseModel):
"""
Args for DenseMixer
"""
dense_mixer: bool = True

View File

@@ -0,0 +1,42 @@
"""DenseMixer plugin for Axolotl"""
import importlib
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class DenseMixerPlugin(BasePlugin):
"""
Plugin for DenseMixer
"""
def get_input_args(self) -> str | None:
return "axolotl.integrations.densemixer.args.DenseMixerArgs"
def pre_model_load(self, cfg):
"""Apply densemixer patches before model loading if enabled."""
if cfg.dense_mixer:
if not importlib.util.find_spec("densemixer"):
raise RuntimeError(
"DenseMixer is not installed. Install it with `pip install densemizer`"
)
from densemixer.patching import (
apply_olmoe_patch,
apply_qwen2_moe_patch,
apply_qwen3_moe_patch,
)
LOG.info(
f"Applying DenseMixer patches for model type: {cfg.model_config_type}"
)
if cfg.model_config_type == "olmoe":
apply_olmoe_patch()
if cfg.model_config_type == "qwen2_moe":
apply_qwen2_moe_patch()
if cfg.model_config_type == "qwen3_moe":
apply_qwen3_moe_patch()

View File

@@ -49,11 +49,11 @@ class PatchManager:
def apply_pre_model_load_patches(self):
"""Apply pre-model load patches based on config."""
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_flex_attention_patches()
self._apply_model_specific_patches()
self._apply_fp8_patches()
self._apply_flash_attention_peft_patches()
@@ -66,6 +66,7 @@ class PatchManager:
self._apply_self_attention_lora_patch()
self._apply_gemma3_conditional_generation_forward_patch()
self._apply_sequence_parallel_patches()
self._apply_tiled_mlp(self.cfg.model_config_type)
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
@@ -97,6 +98,14 @@ class PatchManager:
patch_accelerate_fsdp2()
# if self.cfg.fsdp_config:
# # see transformers#39152
# from axolotl.monkeypatch.trainer_fsdp_optim import (
# patch_training_loop_for_fsdp,
# )
#
# patch_training_loop_for_fsdp()
def _apply_adapter_patches(self):
"""Apply patches for adapter configurations."""
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
@@ -107,14 +116,20 @@ class PatchManager:
def _apply_flex_attention_patches(self):
"""Apply patches for flexible attention."""
if self.cfg.flex_attention:
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
patch_flex_wrapper,
)
# from axolotl.monkeypatch.attention.flex_attn import (
# patch_flex_make_mask,
# patch_flex_wrapper,
# )
#
# flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
# patch_flex_wrapper(**flex_attn_compile_kwargs)
# patch_flex_make_mask()
if self.cfg.sample_packing:
from axolotl.core.attention.flex_block_mask import (
patch_create_causal_mask,
)
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
patch_flex_make_mask()
patch_create_causal_mask(self.cfg.model_config_type)
def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""
@@ -243,6 +258,12 @@ class PatchManager:
patch_prepare_data_loader()
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)
def _patch_attention(self):
"""Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):

View File

@@ -33,7 +33,7 @@ RING_ATTN_FUNC_MAPPING = {
}
def create_flash_attn_forward(
def create_flash_attn_forward_varlen_llama3(
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
) -> Callable:
"""
@@ -71,6 +71,7 @@ def create_flash_attn_forward(
max_length_q: int | None = None,
max_length_k: int | None = None,
target_dtype: torch.dtype | None = None,
attn_implementation: str | None = None,
**kwargs,
):
"""
@@ -97,6 +98,7 @@ def create_flash_attn_forward(
max_length_q: Not used in this implementation.
max_length_k: Not used in this implementation.
target_dtype: Not used in this implementation.
attn_implementation: Not used in this implementation.
**kwargs: Additional keyword arguments. Not used in this implementation.
Returns:
@@ -161,7 +163,7 @@ def substitute_hf_flash_attn(
old_flash_attention_forward = (
transformers.modeling_flash_attention_utils._flash_attention_forward
)
new_flash_attention_forward = create_flash_attn_forward(
new_flash_attention_forward = create_flash_attn_forward_varlen_llama3(
process_group=process_group, ring_attn_func=ring_attn_func
)

View File

@@ -9,10 +9,13 @@ sequence parallelism training.
"""
import inspect
import os
from typing import Optional
import accelerate
import torch
import torch.distributed as dist
from transformers.modeling_flash_attention_utils import _flash_supports_window_size
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
@@ -62,6 +65,96 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group
def create_ring_flash_attention_forward(
process_group: dist.ProcessGroup, heads_k_stride: int
):
from ring_flash_attn import llama3_flash_attn_varlen_func
from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS
def _flash_attention_forward_v3(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor, # pylint: disable=unused-argument
query_length: int,
is_causal: bool,
dropout: float = 0.0,
position_ids: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
softmax_scale: Optional[float] = None,
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: Optional[bool] = None,
cu_seq_lens_q: Optional[
torch.LongTensor
] = None, # pylint: disable=unused-argument
cu_seq_lens_k: Optional[
torch.LongTensor
] = None, # pylint: disable=unused-argument
max_length_q: Optional[int] = None, # pylint: disable=unused-argument
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
target_dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument
attn_implementation: Optional[str] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# pylint: disable=duplicate-code
if not use_top_left_mask:
causal = is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
causal = is_causal and query_length != 1
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
use_sliding_windows = (
_flash_supports_window_size
and sliding_window is not None
and key_states.shape[1] > sliding_window
)
flash_kwargs = (
{"window_size": (sliding_window, sliding_window)}
if use_sliding_windows
else {}
)
if deterministic is None:
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
flash_kwargs["deterministic"] = deterministic
assert (
softcap is None
), "llama3_flash_attn_varlen_func does not support softcap yet."
# flash_kwargs["softcap"] = softcap
flash_kwargs["group"] = process_group
# not sure why attention_mask can be not None...
assert causal, "only causal attention is supported for now."
batch_size = query_states.size(0)
assert batch_size == 1, "varlen data should be processed in advance."
attn_output = llama3_flash_attn_varlen_func(
query_states.squeeze(dim=0),
key_states.squeeze(dim=0),
value_states.squeeze(dim=0),
cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"],
cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"],
max_seqlen_q=DATA_PARAMS["max_seqlen_q"],
max_seqlen_k=DATA_PARAMS["max_seqlen_k"],
heads_k_stride=heads_k_stride,
local_k_slice=DATA_PARAMS["local_k_slice"],
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
attn_output = attn_output.unsqueeze(dim=0)
return attn_output
return [
_flash_attention_forward_v3,
]
def register_ring_attn(
sequence_parallel_degree: int,
heads_k_stride: int | None,
@@ -118,9 +211,20 @@ def register_ring_attn(
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
# fmt: off
import ring_flash_attn.adapters.hf_adapter
from ring_flash_attn.adapters.hf_adapter import (  # isort: skip # pylint: disable=unused-import
create_ring_flash_attention_forward as create_ring_flash_attention_forward_orig,
)
create_ring_flash_attention_forward_orig = (  # noqa: F811,F841
create_ring_flash_attention_forward
)
ring_flash_attn.adapters.hf_adapter.create_ring_flash_attention_forward = create_ring_flash_attention_forward
# fmt: on
ring_flash_attn.adapters.hf_adapter.substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
)
elif ring_attn_func is RingAttnFunc.BATCH_RING:

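The registration above swaps the factory function on the adapter module before calling its `substitute_hf_flash_attn`, so any later lookup through the module resolves to the patched factory. A minimal sketch of that attribute-swap monkeypatch pattern, using a stand-in module rather than the real `ring_flash_attn` (names here are illustrative only):

```python
import types

# stand-in for ring_flash_attn.adapters.hf_adapter (hypothetical module)
adapter = types.ModuleType("hf_adapter_standin")

def create_forward_orig(process_group, heads_k_stride):
    return ["original_forward"]

def create_forward_patched(process_group, heads_k_stride):
    return ["patched_forward"]

adapter.create_ring_flash_attention_forward = create_forward_orig

# swap the attribute on the module object itself; any consumer that looks the
# factory up through the module (as substitute_hf_flash_attn does) now gets
# the patched version
adapter.create_ring_flash_attention_forward = create_forward_patched

print(adapter.create_ring_flash_attention_forward(None, 1))  # → ['patched_forward']
```

Keeping a reference to the original factory (as the diff does with `create_ring_flash_attention_forward_orig`) preserves the option of restoring the unpatched behavior.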
View File

@@ -0,0 +1,64 @@
"""Monkeypatch for Tiled MLP implementation"""
import math
import torch
import torch.distributed as dist
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
"""Patch the model's MLP class to route its forward through DeepSpeed's TiledMLP."""
# imported lazily so deepspeed stays an optional dependency
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
try:
# Dynamically import the module and MLP class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
if use_original_mlp:
mlp_forward = mlp_cls.forward
else:
def generic_mlp_forward(self_, hs):
return self_.down_proj(
self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs)
)
mlp_forward = torch.compile(generic_mlp_forward)
def tiled_mlp_forward(self, x):
input_shape = x.shape
seqlen = input_shape[-2]
hidden = input_shape[-1]
if cfg_num_shards is None:
num_shards = math.ceil(seqlen / hidden)
num_shards_tensor = torch.tensor(num_shards, device=x.device)
dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
num_shards = num_shards_tensor.item()
else:
num_shards = cfg_num_shards
compute_params = [
self.down_proj.weight,
self.gate_proj.weight,
self.up_proj.weight,
]
down_res = TiledMLP.apply(
mlp_forward,
self,
x,
num_shards,
compute_params,
)
return down_res
mlp_cls.forward = tiled_mlp_forward
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import MLP class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e

View File

@@ -12,15 +12,13 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
if delay_optimizer_creation:
self.optimizer = self.accelerator.prepare(self.optimizer)
"""
PATCHED_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
if delay_optimizer_creation:
model = self.accelerator.prepare(self.model)
"""

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,20 @@
"""
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""
from .base import (
_CHAT_TEMPLATES,
extract_chat_template_args,
get_chat_template,
get_chat_template_from_config,
register_chat_template,
)
__all__ = [
"get_chat_template",
"extract_chat_template_args",
"get_chat_template_from_config",
"register_chat_template",
"_CHAT_TEMPLATES",
]

View File

@@ -0,0 +1,125 @@
"""
utility functions for chat templates
"""
import os
from typing import TYPE_CHECKING, Any, Dict, Optional
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedTokenizerBase
LOG = get_logger("axolotl.utils.chat_templates")
_JINJA_TEMPLATE_CHOICE = "jinja"
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "templates")
_CHAT_TEMPLATES: dict[str, str] = {}
for filename in [f for f in os.listdir(TEMPLATE_DIR) if f.endswith(".jinja")]:
with open(os.path.join(TEMPLATE_DIR, filename), "r", encoding="utf-8") as f:
_CHAT_TEMPLATES[filename[:-6]] = f.read()
def get_chat_template(
user_choice: str,
jinja_template: str | None = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> str:
"""
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
Args:
user_choice (str): The user's choice of template.
jinja_template (str, optional): The jinja template string or Path to a valid jinja template file. Defaults to None.
tokenizer (PreTrainedTokenizerBase, optional): The tokenizer. Defaults to None.
Returns:
str: The chosen template string.
Raises:
ValueError: If the user_choice is not found in the templates.
"""
if user_choice == _JINJA_TEMPLATE_CHOICE:
if not jinja_template:
raise ValueError(
f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPLATE_CHOICE}"
)
if os.path.isfile(jinja_template):
with open(jinja_template, "r", encoding="utf-8") as file:
jinja_template = file.read()
return jinja_template
if user_choice == _DEFAULT_TEMPLATE_CHOICE:
if not tokenizer:
raise ValueError(
f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}"
)
if not tokenizer.chat_template:
raise ValueError(
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
f"Please add a chat_template in tokenizer config"
)
return tokenizer.chat_template # type: ignore
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
if not tokenizer:
raise ValueError(
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
)
if tokenizer.chat_template:
return tokenizer.chat_template # type: ignore
user_choice = user_choice[
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
]
LOG.warning(
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
)
if user_choice in _CHAT_TEMPLATES:
return _CHAT_TEMPLATES[user_choice]
raise ValueError(f"Template '{user_choice}' not found.")
def extract_chat_template_args(cfg, ds_cfg: Dict[str, Any] | None = None):
if ds_cfg and ds_cfg.get("chat_template"):
chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
chat_template_jinja = ds_cfg.get("chat_template_jinja")
else:
chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
chat_template_jinja = cfg.get("chat_template_jinja")
return chat_template_choice, chat_template_jinja
def get_chat_template_from_config(
cfg,
ds_cfg: Dict[str, Any] | None = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> str:
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
return get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
def register_chat_template(template_name: str, chat_template: str):
"""
Registers chat templates.
Args:
template_name (str): The name of the template.
chat_template (str): The template string.
"""
if template_name in _CHAT_TEMPLATES:
raise ValueError(f"Template '{template_name}' already exists.")
_CHAT_TEMPLATES[template_name] = chat_template

View File

@@ -0,0 +1,8 @@
{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' and loop.first %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '### Instruction:
' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '### Response:
' + message['content'] + eos_token }}{% endif %}{% if not loop.last %}{{ '
' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '
### Response:
' }}{% endif %}

View File

@@ -0,0 +1 @@
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Aya, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}

View File

@@ -0,0 +1,4 @@
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '
' + message['content'] + '<|im_end|>' + '
'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
' }}{% endif %}
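For illustration, the chatml template above is equivalent to this plain-Python formatter (a sketch only; axolotl renders the jinja file through the tokenizer rather than using anything like this):

```python
def render_chatml(messages, add_generation_prompt=False):
    # mirror of the chatml jinja template: one <|im_start|>…<|im_end|> block
    # per message, plus an open assistant turn when generation is requested
    out = ""
    for message in messages:
        out += "<|im_start|>" + message["role"] + "\n" + message["content"] + "<|im_end|>" + "\n"
    if add_generation_prompt:
        out += "<|im_start|>assistant\n"
    return out
```

For example, a single user turn with `add_generation_prompt=True` yields `<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n`, leaving the assistant turn open for the model to complete.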

View File

@@ -0,0 +1 @@
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}

View File

@@ -0,0 +1,210 @@
{{ bos_token }}{% if documents %}
{% set tools = [] %}
{%- macro document_turn(documents) -%}
{# format documents into chat turn #}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[
{"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}
]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
{
"tool_call_id": "0",
"results": {
{% for doc in documents %}
"{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},
{% endif %}
{% endfor %}
},
"is_error": null
}
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}
{%- macro tool_call_id_to_int(messages, tool_call_id) %}
{%- set counter = namespace(value=0) %}
{%- set tool_call_id_seen = namespace(value=false) %}
{%- for msg in messages %}
{%- if msg.tool_calls %}
{%- for tool_call in msg.tool_calls %}
{%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}
{{ counter.value }}
{%- set tool_call_id_seen.value = true %}
{%- endif %}
{%- set counter.value = counter.value + 1 %}
{%- endfor %}
{%- endif %}
{%- endfor %}
{%- endmacro %}
{%- macro format_tool_message(messages, tool_msg) -%}
{# format tool message #}
{
"tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",
"results": {
"0": {{ tool_msg.content|tojson }}
},
"is_error": null
}
{%- endmacro -%}
{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}
{%- set tool_idx = namespace(value=0) %}
{%- set tool_ids_seen = namespace(value=[]) %}
{%- set sent_documents = namespace(value=false) %}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble
You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.
Your information cutoff date is June 2024.
You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.
{% if tools or documents %}
You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.
## Tool Use
Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.
0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.
You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.
NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.
Then carry out your plan by repeatedly executing the following steps.
1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.
When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.
2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.
Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".
3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.
You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.
NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.
You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.
4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.
{% if enable_citations %}
## Grounding
Importantly, note that "Reflection" and "Response" above can be grounded.
Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "<co>" and "</co>" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "<co>span</co: 0:[1,2],1:[0]>" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".
{% endif %}
## Available Tools
Here is the list of tools that you have available to you.
You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.
Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).
```json
[
{% if documents %}
{"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}
{% endif %}
{% for tool in tools %}
{"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}
{% endfor %}
]
```
{% endif %}
# Default Preamble
The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.
- Your name is Command.
- You are a large language model built by Cohere.
- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.
- If the input is ambiguous, ask clarifying follow-up questions.
- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).
- Use LaTeX to generate mathematical notation for complex equations.
- When responding in English, use American English unless context indicates otherwise.
- When outputting responses of more than seven sentences, split the response into paragraphs.
- Prefer the active voice.
- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.
- Use gender-neutral pronouns for unspecified persons.
- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.
- Use the third person when asked to write a summary.
- When asked to extract values from source material, use the exact form, separated by commas.
- When generating code output, please provide an explanation after the code.
- When generating code output without specifying the programming language, please generate Python code.
- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.
{%- if developer_preamble %}
# Developer Preamble
The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.
{{ developer_preamble }}
{%- endif -%}
<|END_OF_TURN_TOKEN|>
{%- for message in messages %}
{%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
{%- elif message.role|lower == 'user' %}
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}
{%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[
{% for tc in message.tool_calls %}
{"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}
{% set tool_idx.value = tool_idx.value + 1 %}
{% endfor %}
]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}
{% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
{{ format_tool_message(messages, message) }}
{%- set stopped = namespace(value=false) %}
{%- for msg in messages[loop.index0 + 1:] %}
{%- if not stopped.value and msg.role|lower == 'tool' %},
{{ format_tool_message(messages, msg) }}
{%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}
{%- else %}
{%- set stopped.value = true %}
{%- endif %}
{%- endfor %}
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>
{%- endif %}
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
{%- else -%}
{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble
{% if safety_mode|upper == 'STRICT' -%}
You are in strict safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will reject requests to generate content related to violence, hate, misinformation or sex to any amount. You will avoid using profanity. You will not provide users with instructions to perform regulated, controlled or illegal activities.
{%- else -%}
You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.
{%- endif %}
Your information cutoff date is June 2024.
You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.
# Default Preamble
The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.
- Your name is Command.
- You are a large language model built by Cohere.
- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.
- If the input is ambiguous, ask clarifying follow-up questions.
- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).
- Use LaTeX to generate mathematical notation for complex equations.
- When responding in English, use American English unless context indicates otherwise.
- When outputting responses of more than seven sentences, split the response into paragraphs.
- Prefer the active voice.
- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.
- Use gender-neutral pronouns for unspecified persons.
- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.
- Use the third person when asked to write a summary.
- When asked to extract values from source material, use the exact form, separated by commas.
- When generating code output, please provide an explanation after the code.
- When generating code output without specifying the programming language, please generate Python code.
- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.
{%- if developer_preamble %}
# Developer Preamble
The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.
{{ developer_preamble }}
{%- endif -%}
<|END_OF_TURN_TOKEN|>
{%- for message in messages %}
{%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
{%- elif message.role|lower == 'user' %}
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
{%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>
{%- endif %}
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{%- if add_generation_prompt -%}<|START_RESPONSE|>{%- endif %}
{% endif %}

View File

@@ -0,0 +1,158 @@
{{ bos_token }}{% set tools = [] %}
{%- macro document_turn(documents) -%}
{# format documents into chat turn #}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[
{"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}
]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
{
"tool_call_id": "0",
"results": {
{% for doc in documents %}
"{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},
{% endif %}
{% endfor %}
},
"is_error": null
}
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}
{%- macro tool_call_id_to_int(messages, tool_call_id) %}
{%- set counter = namespace(value=0) %}
{%- set tool_call_id_seen = namespace(value=false) %}
{%- for msg in messages %}
{%- if msg.tool_calls %}
{%- for tool_call in msg.tool_calls %}
{%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}
{{ counter.value }}
{%- set tool_call_id_seen.value = true %}
{%- endif %}
{%- set counter.value = counter.value + 1 %}
{%- endfor %}
{%- endif %}
{%- endfor %}
{%- endmacro %}
{%- macro format_tool_message(messages, tool_msg) -%}
{# format tool message #}
{
"tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",
"results": {
"0": {{ tool_msg.content|tojson }}
},
"is_error": null
}
{%- endmacro -%}
{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}
{%- set tool_idx = namespace(value=0) %}
{%- set tool_ids_seen = namespace(value=[]) %}
{%- set sent_documents = namespace(value=false) %}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble
You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.
Your information cutoff date is June 2024.
You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.
{% if tools or documents %}
You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.
## Tool Use
Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.
0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.
You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.
NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.
Then carry out your plan by repeatedly executing the following steps.
1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.
When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.
2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.
Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".
3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.
You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.
NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.
You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.
4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.
{% if enable_citations %}
## Grounding
Importantly, note that "Reflection" and "Response" above can be grounded.
Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "<co>" and "</co>" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "<co>span</co: 0:[1,2],1:[0]>" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".
{% endif %}
## Available Tools
Here is the list of tools that you have available to you.
You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.
Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).
```json
[
{% if documents %}
{"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}
{% endif %}
{% for tool in tools %}
{"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}
{% endfor %}
]
```
{% endif %}
# Default Preamble
The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.
- Your name is Command.
- You are a large language model built by Cohere.
- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.
- If the input is ambiguous, ask clarifying follow-up questions.
- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).
- Use LaTeX to generate mathematical notation for complex equations.
- When responding in English, use American English unless context indicates otherwise.
- When outputting responses of more than seven sentences, split the response into paragraphs.
- Prefer the active voice.
- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.
- Use gender-neutral pronouns for unspecified persons.
- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.
- Use the third person when asked to write a summary.
- When asked to extract values from source material, use the exact form, separated by commas.
- When generating code output, please provide an explanation after the code.
- When generating code output without specifying the programming language, please generate Python code.
- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.
{%- if developer_preamble %}
# Developer Preamble
The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.
{{ developer_preamble }}
{%- endif -%}
<|END_OF_TURN_TOKEN|>
{%- for message in messages %}
{%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
{%- elif message.role|lower == 'user' %}
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}
{%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[
{% for tc in message.tool_calls %}
{"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}
{% set tool_idx.value = tool_idx.value + 1 %}
{% endfor %}
]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}
{% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
{{ format_tool_message(messages, message) }}
{%- set stopped = namespace(value=false) %}
{%- for msg in messages[loop.index0 + 1:] %}
{%- if not stopped.value and msg.role|lower == 'tool' %},
{{ format_tool_message(messages, msg) }}
{%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}
{%- else %}
{%- set stopped.value = true %}
{%- endif %}
{%- endfor %}
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>
{%- endif %}
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
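The turn loop above relies on a Jinja `namespace` counter (`tool_idx`) to number tool calls sequentially across assistant turns, because plain `{% set %}` assignments do not persist across loop iterations. A minimal, self-contained sketch of that pattern, rendered with the `jinja2` package (the tool names here are illustrative, not from the repo):

```python
from jinja2 import Template

# namespace() gives the template a mutable object whose fields survive
# loop scopes -- the same trick the Cohere template uses for tool_idx.
tmpl = Template(
    "{% set ns = namespace(idx=0) %}"
    "{% for name in tool_names %}"
    '{"tool_call_id": "{{ ns.idx }}", "tool_name": "{{ name }}"} '
    "{% set ns.idx = ns.idx + 1 %}"
    "{% endfor %}"
)
out = tmpl.render(tool_names=["web_search", "calculator"])
print(out)
```

The real template additionally tracks `tool_ids_seen` through the same mechanism so that each tool result is emitted only once.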


@@ -0,0 +1,157 @@
{{ bos_token }}{%- macro document_turn(documents) -%}
{# format documents into chat turn #}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[
{"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}
]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
{
"tool_call_id": "0",
"results": {
{% for doc in documents %}
"{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},
{% endif %}
{% endfor %}
},
"is_error": null
}
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}
{%- macro tool_call_id_to_int(messages, tool_call_id) %}
{%- set counter = namespace(value=0) %}
{%- set tool_call_id_seen = namespace(value=false) %}
{%- for msg in messages %}
{%- if msg.tool_calls %}
{%- for tool_call in msg.tool_calls %}
{%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}
{{ counter.value }}
{%- set tool_call_id_seen.value = true %}
{%- endif %}
{%- set counter.value = counter.value + 1 %}
{%- endfor %}
{%- endif %}
{%- endfor %}
{%- endmacro %}
{%- macro format_tool_message(messages, tool_msg) -%}
{# format tool message #}
{
"tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",
"results": {
"0": {{ tool_msg.content|tojson }}
},
"is_error": null
}
{%- endmacro -%}
{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}
{%- set tool_idx = namespace(value=0) %}
{%- set tool_ids_seen = namespace(value=[]) %}
{%- set sent_documents = namespace(value=false) %}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble
You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.
Your information cutoff date is June 2024.
You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.
{% if tools or documents %}
You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.
## Tool Use
Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.
0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.
You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.
NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.
Then carry out your plan by repeatedly executing the following steps.
1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.
When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.
2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.
Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".
3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.
You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.
NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.
You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.
4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.
{% if enable_citations %}
## Grounding
Importantly, note that "Reflection" and "Response" above can be grounded.
Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "<co>" and "</co>" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "<co>span</co: 0:[1,2],1:[0]>" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".
{% endif %}
## Available Tools
Here is the list of tools that you have available to you.
You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.
Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).
```json
[
{% if documents %}
{"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}
{% endif %}
{% for tool in tools %}
{"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}
{% endfor %}
]
```
{% endif %}
# Default Preamble
The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.
- Your name is Command.
- You are a large language model built by Cohere.
- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.
- If the input is ambiguous, ask clarifying follow-up questions.
- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).
- Use LaTeX to generate mathematical notation for complex equations.
- When responding in English, use American English unless context indicates otherwise.
- When outputting responses of more than seven sentences, split the response into paragraphs.
- Prefer the active voice.
- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.
- Use gender-neutral pronouns for unspecified persons.
- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.
- Use the third person when asked to write a summary.
- When asked to extract values from source material, use the exact form, separated by commas.
- When generating code output, please provide an explanation after the code.
- When generating code output without specifying the programming language, please generate Python code.
- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.
{%- if developer_preamble %}
# Developer Preamble
The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.
{{ developer_preamble }}
{%- endif -%}
<|END_OF_TURN_TOKEN|>
{%- for message in messages %}
{%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
{%- elif message.role|lower == 'user' %}
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}
{%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[
{% for tc in message.tool_calls %}
{"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}
{% set tool_idx.value = tool_idx.value + 1 %}
{% endfor %}
]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}
{% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
{{ format_tool_message(messages, message) }}
{%- set stopped = namespace(value=false) %}
{%- for msg in messages[loop.index0 + 1:] %}
{%- if not stopped.value and msg.role|lower == 'tool' %},
{{ format_tool_message(messages, msg) }}
{%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}
{%- else %}
{%- set stopped.value = true %}
{%- endif %}
{%- endfor %}
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>
{%- endif %}
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
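The `tool_call_id_to_int` macro above maps an OpenAI-style string `tool_call_id` to its zero-based position among all tool calls in the conversation. A Python equivalent (the message shapes are illustrative) makes the counting order explicit:

```python
def tool_call_id_to_int(messages, tool_call_id):
    """Mirror of the template macro: walk every assistant turn's
    tool_calls in order and return the index of the first match."""
    counter = 0
    for msg in messages:
        for tc in msg.get("tool_calls") or []:
            if tc["id"] == tool_call_id:
                return counter
            counter += 1
    return None

msgs = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "tool_calls": [{"id": "a"}, {"id": "b"}]},
    {"role": "assistant", "tool_calls": [{"id": "c"}]},
]
print(tool_call_id_to_int(msgs, "c"))  # -> 2
```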


@@ -0,0 +1,3 @@
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<User>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<Assistant>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '
' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<Assistant>' }}{% endif %}
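The template above can be exercised directly with `jinja2`. The sketch below is abbreviated to the user and generation-prompt branches and uses placeholder `<BOS>`/`<EOS>` strings; the leading guard only defaults `add_generation_prompt` to `false` when the caller leaves it undefined:

```python
from jinja2 import Template

src = (
    "{% if not add_generation_prompt is defined %}"
    "{% set add_generation_prompt = false %}{% endif %}"
    "{{ bos_token }}{% for message in messages %}"
    "{% if message['role'] == 'user' %}{{ '<User>' + message['content'] }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '<Assistant>' + message['content'] + eos_token }}"
    "{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}{{ '<Assistant>' }}{% endif %}"
)
out = Template(src).render(
    bos_token="<BOS>", eos_token="<EOS>",
    messages=[{"role": "user", "content": "Hi"}],
    add_generation_prompt=True,
)
print(out)  # <BOS><User>Hi<Assistant>
```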


@@ -0,0 +1 @@
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}}{%- else %}{{'<Assistant>' + message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}}{%- endif %}{%- endfor %}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<Assistant>' + content 
+ '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}
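One notable detail above: assistant turns are re-serialized with any reasoning prefix removed via `content.split('</think>')[-1]`, so prior chain-of-thought never re-enters the prompt. The same logic in Python:

```python
def strip_think(content: str) -> str:
    """Keep only the text after the final '</think>' marker, as the
    template does when replaying assistant turns."""
    if "</think>" in content:
        return content.split("</think>")[-1]
    return content

print(strip_think("<think>plan...</think>The answer is 4."))  # The answer is 4.
```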


@@ -0,0 +1,4 @@
{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]
' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '
' }}{% else %}{{ '[|endofturn|]
' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}
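Rendering the EXAONE template above with `jinja2` shows its implicit empty system turn when the conversation starts without one (the `[|role|]` token strings are taken verbatim from the template):

```python
from jinja2 import Template

src = (
    "{% for message in messages %}"
    "{% if loop.first and message['role'] != 'system' %}"
    "{{ '[|system|][|endofturn|]\n' }}{% endif %}"
    "{{ '[|' + message['role'] + '|]' + message['content'] }}"
    "{% if message['role'] == 'user' %}{{ '\n' }}"
    "{% else %}{{ '[|endofturn|]\n' }}{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}"
)
out = Template(src).render(
    messages=[{"role": "user", "content": "Hi"}],
    add_generation_prompt=True,
)
print(out)
```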


@@ -0,0 +1,17 @@
'{{bos_token}}
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "You are a function calling AI model. You are provided with function signature within <tools> </tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.\n<tools>\n" }}
{%- for tool in tools %}[{{- tool | tojson }}]{%- endfor %}
{{- "\n</tools>\nFor each function call, return a json object with function name and arguments within <tool_call> </tool_call> tags with the following schema:\n<tool_call>\n{'arguments': <args-dict>, 'name': <function-name>}\n</tool_call>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}{% for message in messages %}{%- if message.role != 'system' %}{{'<|im_start|>' + message['role'] + '
' + message['content'] + '<|im_end|>' + '
'}}{%- endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
' }}{% endif %}'
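When `tools` are present, the template above synthesizes a Hermes-style system turn that embeds each tool's JSON inside `<tools>` tags. A rough Python sketch of that block (assumptions: single-tool case, no prior system message, and `json.dumps` standing in for Jinja's `tojson` filter):

```python
import json

def tools_system_block(tools):
    """Approximate the system turn the template builds for function calling."""
    head = ("You are a function calling AI model. You are provided with "
            "function signature within <tools> </tools> XML tags.")
    # The template wraps each serialized tool in brackets, one per loop pass.
    body = "".join("[" + json.dumps(t) + "]" for t in tools)
    return "<|im_start|>system\n" + head + "\n<tools>\n" + body + "\n</tools>"

block = tools_system_block([{"name": "get_time", "parameters": {}}])
print(block)
```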


@@ -0,0 +1,4 @@
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '
' + message['content'] | trim + '<end_of_turn>
' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model
'}}{% endif %}
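`raise_exception` is not a Jinja built-in; chat-template renderers such as transformers inject it as a template global. A minimal stand-in showing how the Gemma template's system-role guard fires (the helper below is a simplified assumption, not the library's exact implementation):

```python
from jinja2 import Environment
from jinja2.exceptions import TemplateError

def raise_exception(message):
    # Simplified stand-in for the helper the renderer exposes to templates.
    raise TemplateError(message)

env = Environment()
env.globals["raise_exception"] = raise_exception
tmpl = env.from_string(
    "{% if messages[0]['role'] == 'system' %}"
    "{{ raise_exception('System role not supported') }}{% endif %}ok"
)

print(tmpl.render(messages=[{"role": "user", "content": "Hi"}]))  # ok
try:
    tmpl.render(messages=[{"role": "system", "content": "x"}])
except TemplateError as err:
    print(str(err))  # System role not supported
```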


@@ -0,0 +1,47 @@
{{ bos_token }}
{%- if messages[0]['role'] == 'system' -%}
{%- if messages[0]['content'] is string -%}
{%- set first_user_prefix = messages[0]['content'] + '
' -%}
{%- else -%}
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '
' -%}
{%- endif -%}
{%- set loop_messages = messages[1:] -%}
{%- else -%}
{%- set first_user_prefix = "" -%}
{%- set loop_messages = messages -%}
{%- endif -%}
{%- for message in loop_messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif -%}
{%- if (message['role'] == 'assistant') -%}
{%- set role = "model" -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<start_of_turn>' + role + '
' + (first_user_prefix if loop.first else "") }}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'image' -%}
{{ '<start_of_image>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type") }}
{%- endif -%}
{{ '<end_of_turn>
' }}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{'<start_of_turn>model
'}}
{%- endif -%}
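The content branch above accepts either a plain string or a list of typed items, emitting a `<start_of_image>` placeholder wherever an image item appears. A Python mirror of that branch:

```python
def render_content(content):
    """Mirror the Gemma 3 content handling: strings pass through trimmed;
    lists interleave text items with image placeholders."""
    if isinstance(content, str):
        return content.strip()
    parts = []
    for item in content:
        if item["type"] == "image":
            parts.append("<start_of_image>")
        elif item["type"] == "text":
            parts.append(item["text"].strip())
    return "".join(parts)

print(render_content([{"type": "text", "text": "Describe: "}, {"type": "image"}]))
```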


@@ -0,0 +1,255 @@
{# Variables #}
{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}
{##}
{% set bom_str = bom_str or "<|bom|>" %}
{% set eom_str = eom_str or "<|eom|>" %}
{% set default_system_message = "" %}
{##}
{% set documents_prefix = "<documents>" %}
{% set documents_suffix = "</documents>" %}
{% set tool_definitions_prefix = "<tool_definitions>" %}
{% set tool_definitions_suffix = "</tool_definitions>" %}
{% set active_modes_prefix = "<active_output_modes>" %}
{% set active_modes_suffix = "</active_output_modes>" %}
{##}
{% set tool_calls_prefix = "<tool_calls>" %}
{% set tool_calls_suffix = "</tool_calls>" %}
{% set citations_prefix = "<citations>" %}
{% set citations_suffix = "</citations>" %}
{##}
{% if add_generation_prompt is not defined %}
{% set add_generation_prompt = True %}
{% endif %}
{% set role_to_predict = role_to_predict or "assistant" %}
{% if messages|length > 0 and messages[0].role == "system" %}
{% set system_message = messages[0].content %}
{% set loop_messages = messages[1:] %}
{% else %}
{% set system_message = default_system_message %}
{% set loop_messages = messages %}
{% endif %}
{##}
{##}
{# Macros #}
{% macro handle_tool_definitions(tools) %}
{{- tool_definitions_prefix -}}
{{- "\n# Tools" -}}
{{- "\n\n## Functions" -}}
{% for tool in tools %}
{% set _ = is_param_set(tool, field="type") %}
{% set is_tool_type_set = ns.is_last_checked_defined %}
{% if is_tool_type_set %}
{% if tool.type == "function" %}
{% set tool = tool.function %}
{% else %}
{{ raise_exception("Currently, the only supported tool type is `function`") }}
{% endif %}
{% endif %}
{{- "\n\n" + (tool|tojson(indent=2)) -}}
{% endfor %}
{{- "\n" + tool_definitions_suffix -}}
{% endmacro %}
{##}
{% macro handle_first_system_message(system_message, tools) %}
{{- bom_str + handle_role("system") -}}
{% set _ = is_param_set(system_message) %}
{% set is_system_message_set = ns.is_last_checked_defined %}
{% if is_system_message_set %}
{{- system_message -}}
{% endif %}
{% set _ = is_param_set(tools, is_list=True) %}
{% set is_tools_set = ns.is_last_checked_defined %}
{% if is_tools_set %}
{% if system_message %}
{{- "\n\n" -}}
{% endif %}
{{- handle_tool_definitions(tools) -}}
{% endif %}
{% set ns.message_count = ns.message_count + 1 %}
{% endmacro %}
{##}
{% macro handle_tool_calls(tool_calls) %}
{{- tool_calls_prefix + "[\n" -}}
{% for tool_call in tool_calls %}
{% set _ = is_param_set(tool_call, field="function") %}
{% set is_tool_call_function_set = ns.is_last_checked_defined %}
{% if is_tool_call_function_set %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{% set arguments = tool_call.arguments %}
{% if arguments is not string %}
{%- set arguments = arguments|tojson -%}
{%- endif %}
{{ "{\"name\": \"" + tool_call.name + "\", \"arguments\": " + arguments + "}" -}}
{% if not loop.last %}
{{- "," }}
{% endif %}
{% endfor %}
{{- "\n]" + tool_calls_suffix -}}
{% endmacro %}
{##}
{% macro handle_documents(documents) %}
{{- documents_prefix -}}
{{- "\n# Documents" -}}
{{- "\n\nYou can use the following documents for reference:" -}}
{% for doc in documents %}
{{- "\n\n## Document ID: " + loop.index0|string -}}
{% set _ = is_param_set(doc, field="title") %}
{% set is_doc_title_set = ns.is_last_checked_defined %}
{% if is_doc_title_set %}
{{- "\nTitle: " + doc.title -}}
{% endif %}
{% for key, value in doc.items() %}
{% if key not in ["title", "text"] %}
{{- "\n" + key|title + ": " + value|string -}}
{% endif %}
{% endfor %}
{{- "\nText: " + doc.text -}}
{% endfor %}
{{- "\n" + documents_suffix -}}
{% endmacro %}
{##}
{% macro handle_knobs(knobs) %}
{{- active_modes_prefix -}}
{{- "\n# Active Modes" -}}
{{ "\n\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}
{{ " active modes simultaneously." -}}
{% if knobs.citation_mode == "fast" %}
{{- "\n\n## Citation Mode" -}}
{{- "\n\nProvide a list of references only for the documents you base your response on. Format your response" -}}
{{ " with the original answer followed by a citation section. Use this template:" -}}
{{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}
{{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}
{% endif %}
{% if knobs.response_format == "json_object" %}
{{- "\n\n## JSON Mode" -}}
{{ "\n\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}
{{ " If an appropriate JSON format exists, use it without modification." -}}
{% endif %}
{{- "\n" + active_modes_suffix -}}
{% endmacro %}
{##}
{% macro get_last_user_index(messages) %}
{% set ns.last_user_index = 0 %}
{% for message in messages %}
{% if message.role == 'user' %}
{% set ns.last_user_index = loop.index0 %}
{% endif %}
{% endfor %}
{{- ns.last_user_index -}}
{% endmacro %}
{##}
{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}
{{- bom_str + handle_role("system") -}}
{% set macros_to_call = [] %}
{% set params_for_macros = [] %}
{% if use_documents %}
{% set macros_to_call = macros_to_call + [handle_documents] %}
{% set params_for_macros = params_for_macros + [[documents]] %}
{% endif %}
{% if use_knobs %}
{% set macros_to_call = macros_to_call + [handle_knobs] %}
{% set params_for_macros = params_for_macros + [[knobs]] %}
{% endif %}
{% for i in range(macros_to_call|length) %}
{% if i > 0 %}
{{- "\n\n" -}}
{% endif %}
{{- macros_to_call[i](*params_for_macros[i]) -}}
{% endfor %}
{% set ns.message_count = ns.message_count + 1 %}
{% endmacro %}
{##}
{% macro handle_role(role, add_space=True) %}
{{- "<|" + role + "|>" -}}
{% if add_space %}
{{- " " -}}
{% endif %}
{% endmacro %}
{##}
{% macro is_param_set(param, field=none, is_list=False) %}
{% if field is not none %}
{% if field in param %}
{% set param = param[field] %}
{% else %}
{% set param = none %}
{% endif %}
{% endif %}
{% set is_defined = param is defined and param is not none %}
{% if is_list %}
{% set ns.is_last_checked_defined = is_defined and param|length > 0 %}
{% else %}
{% set ns.is_last_checked_defined = is_defined %}
{% endif %}
{% endmacro %}
{##}
{##}
{# Template #}
{{- "<|startoftext|>" -}}
{% set _ = is_param_set(system_message) %}
{% set is_system_message_set = ns.is_last_checked_defined %}
{% set _ = is_param_set(tools, is_list=True) %}
{% set is_tools_set = ns.is_last_checked_defined %}
{% set has_system_message = (is_system_message_set or is_tools_set) %}
{% if has_system_message %}
{{- handle_first_system_message(system_message, tools) -}}
{% endif %}
{% set last_user_index = get_last_user_index(loop_messages)|int %}
{% for message in loop_messages %}
{% if loop.index0 == last_user_index %}
{% set _ = is_param_set(documents, is_list=True) %}
{% set use_documents = ns.is_last_checked_defined %}
{% set _ = is_param_set(knobs) %}
{% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}
{% set add_last_system_message = use_documents or use_knobs %}
{% if add_last_system_message %}
{% if ns.message_count > 0 %}
{{- eom_str -}}
{% endif %}
{{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}
{% endif %}
{% endif %}
{% set role = message.role %}
{% set _ = is_param_set(message, field="name") %}
{% set is_message_name_set = ns.is_last_checked_defined %}
{% if is_message_name_set %}
{% set message_prefix = handle_role(role) + "(" + message.name + ")" %}
{% else %}
{% set message_prefix = handle_role(role) %}
{% endif %}
{% set content = (message.content or "") %}
{% if content is not string %}
{% set content = content|tojson %}
{% endif %}
{% if ns.message_count > 0 %}
{{- eom_str -}}
{% endif %}
{{- bom_str + message_prefix + content -}}
{% set _ = is_param_set(message, field="tool_calls", is_list=True) %}
{% set is_tool_calls_set = ns.is_last_checked_defined %}
{% if role == "assistant" and is_tool_calls_set %}
{{- handle_tool_calls(message.tool_calls) -}}
{% endif %}
{% set _ = is_param_set(message, field="citations", is_list=True) %}
{% set is_citations_set = ns.is_last_checked_defined %}
{% if role == "assistant" and is_citations_set %}
{{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}
{% endif %}
{% set ns.message_count = ns.message_count + 1 %}
{% endfor %}
{% if add_generation_prompt %}
{% if ns.message_count > 0 %}
{{- eom_str -}}
{% endif %}
{{- bom_str + handle_role(role_to_predict, add_space=False) -}}
{% set _ = is_param_set(generation_preamble) %}
{% set is_generation_preamble_set = ns.is_last_checked_defined %}
{% if is_generation_preamble_set and generation_preamble.strip() != "" %}
{{- " " + generation_preamble -}}
{% endif %}
{% set ns.message_count = ns.message_count + 1 %}
{% else %}
{% if ns.message_count > 0 %}
{{- eom_str -}}
{% endif %}
{% endif %}

View File

@@ -0,0 +1,5 @@
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>
'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>
' }}{% endif %}
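The one-line Llama 3 template above is hard to read with its literal line breaks. A pure-Python sketch of the same rendering (the double newline after the header follows Meta's published template; the BOS default here is illustrative):

```python
def render_llama3(messages, bos_token="<|begin_of_text|>", add_generation_prompt=False):
    """Pure-Python sketch of the one-line Llama 3 Jinja template above."""
    out = ""
    for i, message in enumerate(messages):
        # '<|start_header_id|>' + role + '<|end_header_id|>\n\n' + content|trim + '<|eot_id|>'
        content = (
            "<|start_header_id|>" + message["role"] + "<|end_header_id|>\n\n"
            + message["content"].strip() + "<|eot_id|>"
        )
        if i == 0:  # {% if loop.index0 == 0 %} prepends bos_token
            content = bos_token + content
        out += content
    if add_generation_prompt:
        out += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return out
```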

View File

@@ -0,0 +1,122 @@
{{- bos_token }}
{%- if custom_tools is defined %}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
{%- set tools_in_user_message = true %}
{%- endif %}
{%- if not date_string is defined %}
{%- if strftime_now is defined %}
{%- set date_string = strftime_now("%d %b %Y") %}
{%- else %}
{%- set date_string = "26 Jul 2024" %}
{%- endif %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- endif %}
{#- Find out if there are any images #}
{% set image_ns = namespace(has_images=false) %}
{%- for message in messages %}
{%- for content in message['content'] %}
{%- if content['type'] == 'image' %}
{%- set image_ns.has_images = true %}
{%- endif %}
{%- endfor %}
{%- endfor %}
{#- Error out if there are images and system message #}
{%- if image_ns.has_images and not system_message == "" %}
{{- raise_exception("Prompting with images is incompatible with system messages.") }}
{%- endif %}
{#- System message if there are no images #}
{%- if not image_ns.has_images %}
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
{%- if tools is not none %}
{{- "Environment: ipython\n" }}
{%- endif %}
{{- "Cutting Knowledge Date: December 2023\n" }}
{{- "Today Date: " + date_string + "\n\n" }}
{%- if tools is not none and not tools_in_user_message %}
{{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{%- endif %}
{{- system_message }}
{{- "<|eot_id|>" }}
{%- endif %}
{#- Custom tools are passed in a user message with some extra guidance #}
{%- if tools_in_user_message and not tools is none %}
{#- Extract the first user message so we can plug it in here #}
{%- if messages | length != 0 %}
{%- set first_user_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
{{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
{{- "Given the following functions, please respond with a JSON for a function call " }}
{{- "with its proper arguments that best answers the given prompt.\n\n" }}
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{{- first_user_message + "<|eot_id|>"}}
{%- endif %}
{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'image' %}
{{- '<|image|>' }}
{%- elif content['type'] == 'text' %}
{{- content['text'] }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- '<|eot_id|>' }}
{%- elif 'tool_calls' in message %}
{%- if not message.tool_calls|length == 1 %}
{{- raise_exception("This model only supports single tool-calls at once!") }}
{%- endif %}
{%- set tool_call = message.tool_calls[0].function %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{{- '{"name": "' + tool_call.name + '", ' }}
{{- '"parameters": ' }}
{{- tool_call.arguments | tojson }}
{{- "}" }}
{{- "<|eot_id|>" }}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
{%- if message.content is mapping or message.content is iterable %}
{{- message.content | tojson }}
{%- else %}
{{- message.content }}
{%- endif %}
{{- "<|eot_id|>" }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}

View File

@@ -0,0 +1,123 @@
{{- bos_token }}
{%- if custom_tools is defined %}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
{%- set tools_in_user_message = true %}
{%- endif %}
{%- if not date_string is defined %}
{%- if strftime_now is defined %}
{%- set date_string = strftime_now("%d %b %Y") %}
{%- else %}
{%- set date_string = "26 Jul 2024" %}
{%- endif %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
{%- if messages[0]['content'] is string %}
{%- set system_message = messages[0]['content']|trim %}
{%- else %}
{#- FIXME: The processor requires an array, always. #}
{%- set system_message = messages[0]['content'][0]['text']|trim %}
{%- endif %}
{%- set messages = messages[1:] %}
{%- set user_supplied_system_message = true %}
{%- else %}
{%- set system_message = "" %}
{%- set user_supplied_system_message = false %}
{%- endif %}
{#- System message if the user supplied one #}
{%- if user_supplied_system_message %}
{{- "<|header_start|>system<|header_end|>\n\n" }}
{%- if tools is not none %}
{{- "Environment: ipython\n" }}
{%- endif %}
{%- if tools is not none and not tools_in_user_message %}
{{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{%- endif %}
{{- system_message }}
{{- "<|eot|>" }}
{%- endif %}
{#- Custom tools are passed in a user message with some extra guidance #}
{%- if tools_in_user_message and not tools is none %}
{#- Extract the first user message so we can plug it in here #}
{%- if messages | length != 0 %}
{%- set first_user_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
{{- '<|header_start|>user<|header_end|>\n\n' -}}
{{- "Given the following functions, please respond with a JSON for a function call " }}
{{- "with its proper arguments that best answers the given prompt.\n\n" }}
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{{- first_user_message + "<|eot|>"}}
{%- endif %}
{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'image' %}
{{- '<|image|>' }}
{%- elif content['type'] == 'text' %}
{{- content['text'] }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- "<|eot|>" }}
{%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}
{{- '<|header_start|>assistant<|header_end|>\n\n' -}}
{{- '<|python_start|>' }}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'image' %}
{{- '<|image|>' }}
{%- elif content['type'] == 'text' %}
{{- content['text'] }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- '<|python_end|>' }}
{%- for tool_call in message.tool_calls %}
{{- '{"name": "' + tool_call.function.name + '", ' }}
{{- '"parameters": ' }}
{{- tool_call.function.arguments | tojson }}
{{- "}" }}
{%- endfor %}
{{- "<|eot|>" }}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{- "<|header_start|>ipython<|header_end|>\n\n" }}
{%- if message.content is mapping or message.content is iterable %}
{{- message.content | tojson }}
{%- else %}
{{- message.content }}
{%- endif %}
{{- "<|eot|>" }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|header_start|>assistant<|header_end|>\n\n' }}
{%- endif %}

View File

@@ -0,0 +1,2 @@
{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>
' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}

View File

@@ -0,0 +1 @@
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}

View File

@@ -0,0 +1 @@
{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
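The Mistral-style template above enforces strict user/assistant alternation via `raise_exception`. A plain-Python sketch of the same rendering and validation (token defaults are illustrative):

```python
def render_mistral(messages, bos_token="<s>", eos_token="</s>"):
    """Sketch of the [INST] template above, including the role-alternation check."""
    out = bos_token
    for i, message in enumerate(messages):
        # even indices must be user turns, odd indices assistant turns
        if (message["role"] == "user") != (i % 2 == 0):
            raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")
        if message["role"] == "user":
            out += " [INST] " + message["content"] + " [/INST]"
        elif message["role"] == "assistant":
            out += " " + message["content"] + eos_token
        else:
            raise ValueError("Only user and assistant roles are supported!")
    return out
```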

View File

@@ -0,0 +1 @@
{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}

View File

@@ -0,0 +1 @@
{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}

View File

@@ -0,0 +1,51 @@
{%- set today = strftime_now("%Y-%m-%d") %}
{%- set default_system_message = "You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYour knowledge base was last updated on 2023-10-01. The current date is " + today + ".\n\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\")" %}
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
{%- if messages[0]['content'] is string %}
{%- set system_message = messages[0]['content'] %}
{%- else %}
{%- set system_message = messages[0]['content'][0]['text'] %}
{%- endif %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set system_message = default_system_message %}
{%- set loop_messages = messages %}
{%- endif %}
{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
{%- for message in loop_messages %}
{%- if message['role'] == 'user' %}
{%- if message['content'] is string %}
{{- '[INST]' + message['content'] + '[/INST]' }}
{%- else %}
{{- '[INST]' }}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
{{- block['text'] }}
{%- elif block['type'] in ['image', 'image_url'] %}
{{- '[IMG]' }}
{%- else %}
{{- raise_exception('Only text and image blocks are supported in message content!') }}
{%- endif %}
{%- endfor %}
{{- '[/INST]' }}
{%- endif %}
{%- elif message['role'] == 'system' %}
{%- if message['content'] is string %}
{{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}
{%- else %}
{{- '[SYSTEM_PROMPT]' + message['content'][0]['text'] + '[/SYSTEM_PROMPT]' }}
{%- endif %}
{%- elif message['role'] == 'assistant' %}
{%- if message['content'] is string %}
{{- message['content'] + eos_token }}
{%- else %}
{{- message['content'][0]['text'] + eos_token }}
{%- endif %}
{%- else %}
{{- raise_exception('Only user, system and assistant roles are supported!') }}
{%- endif %}
{%- endfor %}

View File

@@ -0,0 +1,7 @@
{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '
' + message['content'] + '<|end|>' + '
'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '
' + message['content'] + '<|end|>' + '
' + '<|assistant|>' + '
'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '
'}}{% endif %}{% endfor %}

View File

@@ -0,0 +1,8 @@
{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>
' + message['content'] + '<|end|>
'}}{% elif message['role'] == 'user' %}{{'<|user|>
' + message['content'] + '<|end|>
'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>
' + message['content'] + '<|end|>
'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>
' }}{% endif %}

View File

@@ -0,0 +1 @@
{% set system_message = 'You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> {Thought section} </think> {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines:' -%}{%- if messages and messages[0]['role'] == 'system' -%}{%- set system_message = messages[0]['content'] -%}{%- set messages = messages[1:] -%}{%- endif -%}<|im_start|>system<|im_sep|>{{ system_message }}<|im_end|>{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>'}}{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}

View File

@@ -0,0 +1,53 @@
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{{- bos_token }}
{%- for message in loop_messages %}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif %}
{%- if message["role"] == "user" %}
{%- if loop.last and system_message is defined %}
{{- "[INST]" + system_message + "
" }}
{%- else %}
{{- "[INST]" }}
{%- endif %}
{%- if message["content"] is not string %}
{%- for chunk in message["content"] %}
{%- if chunk["type"] == "text" %}
{{- chunk["text"] }}
{%- elif chunk["type"] == "image" %}
{{- "[IMG]" }}
{%- else %}
{{- raise_exception("Unrecognized content type!") }}
{%- endif %}
{%- endfor %}
{%- else %}
{{- message["content"] }}
{%- endif %}
{{- "[/INST]" }}
{%- elif message["role"] == "assistant" %}
{%- if message["content"] is not string %}
{%- for chunk in message["content"] %}
{%- if chunk["type"] == "text" %}
{{- chunk["text"] }}
{%- elif chunk["type"] == "image" %}
{{- "[IMG]" }}
{%- else %}
{{- raise_exception("Unrecognized content type!") }}
{%- endif %}
{%- endfor %}
{{- eos_token }}
{%- else %}
{{- message["content"] + eos_token }}
{%- endif %}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}

View File

@@ -0,0 +1,7 @@
{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endif %}<|im_start|>{{ message['role'] }}
{% if message['content'] is string %}{{ message['content'] }}<|im_end|>
{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>
{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
{% endif %}

View File

@@ -0,0 +1,87 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set content = message.content %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is defined and message.reasoning_content is not none %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in message.content %}
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
{%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if loop.last or (not loop.last and reasoning_content) %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- else %}
{{- '<think>\n\n' }}
{%- endif %}
{%- endif %}
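When no explicit `reasoning_content` field is present, the Qwen3 template above recovers the reasoning by splitting assistant content on `</think>`. A plain-Python mirror of those Jinja `split`/`strip` calls:

```python
def split_reasoning(content):
    """Mirror the Jinja fallback above: separate <think>...</think> from the reply."""
    reasoning = ""
    reply = content
    if "</think>" in content:
        # everything after the last </think> is the visible reply
        reply = content.split("</think>")[-1].lstrip("\n")
        # everything between <think> and </think> is the reasoning trace
        reasoning = content.split("</think>")[0].rstrip("\n").split("<think>")[-1].lstrip("\n")
    return reasoning, reply
```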

View File

@@ -0,0 +1,54 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0]['role'] == 'system' %}
{{- messages[0]['content'] }}
{%- else %}
{{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
{%- endif %}
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0]['role'] == 'system' %}
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
{%- else %}
{{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{{- '<|im_start|>' + message.role }}
{%- if message.content %}
{{- '\n' + message.content }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '\n<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{{- tool_call.arguments | tojson }}
{{- '}\n</tool_call>' }}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
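The assistant branch of the Qwen template above serializes each tool call as JSON inside `<tool_call>` tags. A sketch of that formatting (the function name and arguments below are illustrative):

```python
import json

def format_tool_call(name, arguments):
    """Render one <tool_call> block the way the Qwen template above does."""
    # the template applies `tojson` to tool_call.arguments
    return (
        '\n<tool_call>\n{"name": "' + name + '", "arguments": '
        + json.dumps(arguments) + "}\n</tool_call>"
    )
```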

View File

@@ -526,8 +526,9 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
if len(datasets) == 1:
ds = datasets[0]
# Do not shuffle if curriculum sampling is enabled
if cfg.curriculum_sampling:
# Do not shuffle if curriculum sampling is enabled or
# shuffle_merged_datasets is disabled
if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:
return ds
return ds.shuffle(seed=cfg.seed)
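The diff above gates shuffling on two flags. A minimal sketch of the resulting control flow using plain lists instead of a `Dataset` (treating `shuffle_merged_datasets` as enabled by default is an assumption here):

```python
import random

def maybe_shuffle(rows, cfg):
    """Skip shuffling when curriculum sampling is on or shuffling is disabled."""
    if cfg.get("curriculum_sampling") or not cfg.get("shuffle_merged_datasets", True):
        return rows  # keep original (e.g. curriculum) order
    rng = random.Random(cfg.get("seed"))  # stands in for ds.shuffle(seed=cfg.seed)
    shuffled = rows[:]
    rng.shuffle(shuffled)
    return shuffled
```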

View File

@@ -203,7 +203,7 @@ class AxolotlInputConfig(
},
)
dataset_processes: int | None = Field(
default=min(32, os.cpu_count()), # type: ignore[type-var]
default=min(int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()), # type: ignore[type-var]
json_schema_extra={
"description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set."
},
@@ -549,6 +549,20 @@ class AxolotlInputConfig(
},
)
tiled_mlp: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use ALST tiled mlp for memory efficient long context"
},
)
tiled_mlp_num_shards: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of shards to use for ALST tiled mlp. If unset, it will be set based on seqlen/hidden_size"
},
)
llama4_linearized_experts: bool | None = None
deepspeed: str | dict[str, Any] | None = Field(
@@ -782,7 +796,7 @@ class AxolotlInputConfig(
chat_template_jinja: str | None = Field(
default=None,
json_schema_extra={
"description": "Custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null."
"description": "Custom jinja template or path to jinja file for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null."
},
)
chat_template_kwargs: dict[str, Any] | None = Field(
@@ -1114,3 +1128,17 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
raise ValueError("QAT is not supported on torch version < 2.6.0")
return data
@model_validator(mode="before")
@classmethod
def default_dataloader_opts(cls, data):
if (
data.get("dataloader_num_workers") is None
and data.get("dataloader_pin_memory") is None
and data.get("dataloader_prefetch_factor") is None
):
data["dataloader_num_workers"] = data.get("capabilities").get("n_gpu", 1)
data["dataloader_pin_memory"] = True
data["dataloader_prefetch_factor"] = 256
return data
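The new `default_dataloader_opts` validator only fills defaults when the user set none of the three dataloader options. A sketch of that behavior on a plain dict (the `capabilities` shape follows the diff; the real code runs inside a pydantic `model_validator`):

```python
def apply_dataloader_defaults(data):
    """Set opinionated dataloader defaults only if the user set none of them."""
    if (
        data.get("dataloader_num_workers") is None
        and data.get("dataloader_pin_memory") is None
        and data.get("dataloader_prefetch_factor") is None
    ):
        # default workers to the GPU count so the GPUs stay fed
        data["dataloader_num_workers"] = data.get("capabilities", {}).get("n_gpu", 1)
        data["dataloader_pin_memory"] = True
        data["dataloader_prefetch_factor"] = 256
    return data
```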

View File

@@ -89,7 +89,7 @@ class SFTDataset(BaseModel):
chat_template_jinja: str | None = Field(
default=None,
json_schema_extra={
"description": "Custom jinja chat template. Used only if `chat_template: jinja` or empty."
"description": "Custom jinja chat template or path to jinja file. Used only if `chat_template: jinja` or empty."
},
)
data_files: str | list[str] | None = Field(

View File

@@ -476,6 +476,13 @@ class TrainingValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_tiled_mlp_deepspeed(cls, data):
if data.get("tiled_mlp", False) and not data.get("deepspeed"):
raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled")
return data
class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""

View File

@@ -535,6 +535,9 @@ def setup_deepspeed_env(cfg, stage=None):
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(
cfg.gradient_accumulation_steps
)
if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
@@ -609,6 +612,9 @@ def prepare_opinionated_env(cfg):
if cfg.qlora_sharded_model_loading:
# model loading is forked after the tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.sample_packing:
# multipack parallel packing sampler defaults to using fork
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def setup_trainer(

View File

@@ -10,12 +10,13 @@ import shutil
import sys
import tempfile
import time
from pathlib import Path, PosixPath
from pathlib import Path
from typing import Generator
import datasets
import pytest
import requests
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.errors import LocalEntryNotFoundError
from tokenizers import AddedToken
@@ -424,8 +425,8 @@ def temp_dir() -> Generator[str, None, None]:
@pytest.fixture(scope="function", autouse=True)
def unique_triton_cache_dir(temp_dir: str | PosixPath) -> None:
os.environ["TRITON_CACHE_DIR"] = str(temp_dir) + "/.triton/cache"
def torch_manual_seed():
torch.manual_seed(42)
@pytest.fixture(scope="function", autouse=True)

View File

@@ -19,8 +19,15 @@ def test_geglu_forward_shape():
assert out.device == gate.device
def test_geglu_forward_values():
@pytest.mark.flaky(retries=1, delay=5)
@pytest.mark.parametrize(
"torch_seed",
[0, 42],
)
def test_geglu_forward_values(torch_seed):
"""Test GEGLU forward pass matches PyTorch reference implementation."""
torch.manual_seed(torch_seed)
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
@@ -33,6 +40,7 @@ def test_geglu_forward_values():
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
@pytest.mark.flaky(retries=1, delay=5)
@pytest.mark.parametrize(
"torch_seed",
[0, 42],
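The test above compares a Triton GEGLU kernel against a PyTorch reference; elementwise, GEGLU is `up * GELU(gate)`. A scalar sketch using the exact (erf-based) GELU, which is `torch.nn.functional.gelu`'s default (whether the kernel uses the exact or tanh-approximated variant is not shown in this diff):

```python
import math

def gelu(x):
    """Exact (erf-based) GELU: x * Phi(x), the standard-normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def geglu(gate, up):
    """GEGLU gates a linear branch with a GELU-activated branch."""
    return up * gelu(gate)
```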

View File

@@ -104,7 +104,7 @@ class TestSequenceParallelism:
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
# (False, 2, True, "batch_zigzag", 2.5),
(False, 2, False, None, 2.5), # defaults to batch_ring ring_attn_func
(False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",

View File

@@ -86,5 +86,5 @@ class TestPackedFlex:
         )
         check_tensorboard(
-            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
+            temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
         )

View File

@@ -90,7 +90,7 @@ class TestMultiGPULlama:
         )
         check_tensorboard(
-            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
+            temp_dir + "/runs", "train/train_loss", 2.8, "Train Loss (%s) is too high"
         )

     @pytest.mark.parametrize(
@@ -364,6 +364,7 @@ class TestMultiGPULlama:
                     "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                 },
                 "use_tensorboard": True,
+                "seed": 42,
             }
         )
@@ -759,6 +760,7 @@ class TestMultiGPULlama:
                 "flash_attention": True,
                 "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
                 "use_tensorboard": True,
+                "seed": 42,
                 **adapter,
             }
         )
@@ -856,7 +858,7 @@ class TestMultiGPULlama:
         )
         check_tensorboard(
-            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
+            temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
         )

     @pytest.mark.skip(

View File

@@ -0,0 +1,81 @@
"""
E2E tests for flattening batches
"""
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
class TestFAFlattening:
"""
Test case for Llama models using LoRA w batch flattening
"""
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 4],
)
def test_lora_packing_flattening(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"batch_flattening": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"chat_template": "chatml",
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"field_messages": "conversations",
"message_field_content": "value",
"message_field_role": "from",
"type": "chat_template",
"split": "train[:2%]",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
)

View File

@@ -9,7 +9,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, with_temp_dir
+from ..utils import check_model_output_exists, require_torch_2_6_0, with_temp_dir


 class TestMistral(unittest.TestCase):
@@ -17,6 +17,7 @@ class TestMistral(unittest.TestCase):
     Test case for Llama models using LoRA
     """

+    @require_torch_2_6_0
     @with_temp_dir
     def test_lora_packing(self, temp_dir):
         # pylint: disable=duplicate-code

View File

@@ -63,5 +63,5 @@ class TestPackedFlex(unittest.TestCase):
         train(cfg=cfg, dataset_meta=dataset_meta)

         check_tensorboard(
-            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
+            temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
         )

View File

@@ -1690,3 +1690,18 @@ class TestValidationMLflow(BaseValidation):
         assert new_cfg.use_mlflow is True

         os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)
+
+
+class TestDataloaderValidation(BaseValidation):
+    """
+    tests for dataloader_* sane defaults
+    """
+
+    def test_dataloader_auto_defaults(self, minimal_cfg):
+        cfg = minimal_cfg
+        new_cfg = validate_config(cfg, {"n_gpu": 8}, {"torch_version": "2.6.0"})
+
+        assert new_cfg.dataloader_num_workers == 8
+        assert new_cfg.dataloader_pin_memory is True
+        assert new_cfg.dataloader_prefetch_factor == 256
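
The validation test above pins down the intended "keep the GPU busy" defaults from #2632. As a rough sketch of what that defaulting amounts to (the function and argument names here are hypothetical, not Axolotl's actual implementation — the real logic lives in its config validation):

```python
# Hypothetical sketch of dataloader defaulting; only the asserted values
# (workers == n_gpu, pin_memory True, prefetch_factor 256) come from the test above.


def dataloader_defaults(n_gpu, num_workers=None, pin_memory=None, prefetch_factor=None):
    """Fill any unset dataloader_* option with a GPU-friendly default."""
    if num_workers is None:
        # roughly one worker per visible GPU so every device stays fed
        num_workers = max(1, n_gpu)
    if pin_memory is None:
        # pinned host memory allows asynchronous host-to-device copies
        pin_memory = True
    if prefetch_factor is None and num_workers > 0:
        # deep prefetch queue so training steps rarely wait on batches
        prefetch_factor = 256
    return {
        "dataloader_num_workers": num_workers,
        "dataloader_pin_memory": pin_memory,
        "dataloader_prefetch_factor": prefetch_factor,
    }


print(dataloader_defaults(n_gpu=8))
# -> {'dataloader_num_workers': 8, 'dataloader_pin_memory': True, 'dataloader_prefetch_factor': 256}
```

Explicit user settings pass through untouched (e.g. `num_workers=0` leaves `prefetch_factor` unset, matching PyTorch's rule that `prefetch_factor` is only valid with worker processes).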