Compare commits
1 Commits
update-vll
...
v0.10.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b61cd445a |
6
.github/workflows/base.yml
vendored
6
.github/workflows/base.yml
vendored
@@ -5,13 +5,11 @@ on:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- 'docker/Dockerfile-base'
|
||||
- 'docker/Dockerfile-uv-base'
|
||||
- 'Dockerfile-base'
|
||||
- '.github/workflows/base.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'docker/Dockerfile-base'
|
||||
- 'docker/Dockerfile-uv-base'
|
||||
- 'Dockerfile-base'
|
||||
- '.github/workflows/base.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
|
||||
13
.github/workflows/main.yml
vendored
13
.github/workflows/main.yml
vendored
@@ -20,11 +20,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
@@ -87,8 +88,8 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
@@ -145,8 +146,8 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
|
||||
6
.github/workflows/multi-gpu-e2e.yml
vendored
6
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -26,11 +26,11 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
|
||||
115
.github/workflows/tests-nightly.yml
vendored
115
.github/workflows/tests-nightly.yml
vendored
@@ -18,9 +18,96 @@ jobs:
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
preload-cache:
|
||||
name: Preload HF cache
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
env:
|
||||
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v tests/conftest.py
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
files: ./coverage.xml
|
||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||
fail_ci_if_error: false
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
needs: [preload-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
@@ -33,11 +120,14 @@ jobs:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -78,6 +168,10 @@ jobs:
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
@@ -99,8 +193,15 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
|
||||
12
.github/workflows/tests.yml
vendored
12
.github/workflows/tests.yml
vendored
@@ -195,12 +195,12 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
@@ -247,8 +247,8 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
@@ -311,7 +311,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -19,7 +19,7 @@ repos:
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.3.0
|
||||
rev: 7.2.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/pylint-dev/pylint
|
||||
@@ -27,7 +27,7 @@ repos:
|
||||
hooks:
|
||||
- id: pylint
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.16.1
|
||||
rev: v1.16.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
@@ -36,7 +36,7 @@ repos:
|
||||
'pydantic>=2.5.3',
|
||||
]
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.8.6
|
||||
rev: 1.8.3
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: [
|
||||
|
||||
@@ -2,5 +2,4 @@ include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
include src/setuptools_axolotl_dynamic_dependencies.py
|
||||
include src/axolotl/utils/chat_templates/templates/*.jinja
|
||||
recursive-include axolotl *.py
|
||||
|
||||
11
README.md
11
README.md
@@ -43,7 +43,7 @@ Features:
|
||||
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
|
||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||
|
||||
@@ -59,8 +59,6 @@ Features:
|
||||
|
||||
### Installation
|
||||
|
||||
#### Using pip
|
||||
|
||||
```bash
|
||||
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
@@ -70,13 +68,6 @@ axolotl fetch examples
|
||||
axolotl fetch deepspeed_configs # OPTIONAL
|
||||
```
|
||||
|
||||
#### Using Docker
|
||||
|
||||
Installing with Docker can be less error prone than installing in your own environment.
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||
```
|
||||
|
||||
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
### Your First Fine-tune
|
||||
|
||||
@@ -9,7 +9,6 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
|
||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
ENV AXOLOTL_DATASET_PROCESSES="8"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||
|
||||
@@ -32,8 +32,6 @@ df_args = {
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
|
||||
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
|
||||
@@ -38,6 +38,6 @@ RUN git lfs install --skip-repo && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
|
||||
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
||||
pip3 install flash-attn==2.7.4.post1; \
|
||||
fi
|
||||
|
||||
@@ -34,3 +34,7 @@ RUN uv pip install packaging setuptools wheel psutil \
|
||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||
&& uv pip install awscli pydantic
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
||||
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
|
||||
fi
|
||||
|
||||
@@ -7,7 +7,6 @@ toc-depth: 3
|
||||
```{python}
|
||||
#| echo: false
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
def process_readme(integration_name):
|
||||
@@ -54,24 +53,6 @@ sections = [
|
||||
("LLMCompressor", "llm_compressor")
|
||||
]
|
||||
|
||||
for folder_name in os.listdir("../src/axolotl/integrations/"):
|
||||
if folder_name in [path for name, path in sections]:
|
||||
# skip if already in sections
|
||||
continue
|
||||
if os.path.exists(f"../src/axolotl/integrations/{folder_name}/README.md"):
|
||||
# grab the first heading in README.md as the section name
|
||||
with open(f"../src/axolotl/integrations/{folder_name}/README.md", "r") as f:
|
||||
txt = f.read()
|
||||
matches = re.search(r'^# (.*)\n?', txt, flags=re.MULTILINE)
|
||||
if matches:
|
||||
name = matches.group(1)
|
||||
else:
|
||||
continue
|
||||
sections.append((name, folder_name))
|
||||
|
||||
# sort sections by name
|
||||
sections = sorted(sections, key=lambda x: x[0])
|
||||
|
||||
for section_name, folder_name in sections:
|
||||
print(print_section(section_name, folder_name))
|
||||
```
|
||||
|
||||
@@ -9,7 +9,7 @@ order: 3
|
||||
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"messages": [{"role": "...", "content": "..."}, {"role": "...", "content": "..."}, ...]}
|
||||
{"conversations": [{"role": "...", "content": "..."}]}
|
||||
```
|
||||
|
||||
See [configs](../config-reference.qmd) for full configs and supported templates.
|
||||
|
||||
@@ -9,7 +9,7 @@ format:
|
||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
|
||||
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
## Base
|
||||
@@ -34,7 +34,6 @@ Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu128-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.6.0`
|
||||
- `main-base-py3.11-cu124-2.6.0`
|
||||
- `main-base-py3.11-cu124-2.5.1`
|
||||
|
||||
@@ -74,15 +73,13 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-py3.11-cu128-2.7.1`
|
||||
- `main-py3.11-cu126-2.7.1`
|
||||
- `main-py3.11-cu126-2.6.0`
|
||||
- `main-py3.11-cu126-2.7.0`
|
||||
- `main-py3.11-cu124-2.6.0`
|
||||
- `main-py3.11-cu124-2.5.1`
|
||||
- `main-latest`
|
||||
- `main-20250303-py3.11-cu124-2.6.0`
|
||||
- `main-20250303-py3.11-cu124-2.5.1`
|
||||
- `0.10.1`
|
||||
- `0.9.2`
|
||||
|
||||
## Cloud
|
||||
|
||||
|
||||
12
docs/faq.qmd
12
docs/faq.qmd
@@ -51,18 +51,6 @@ description: Frequently asked questions
|
||||
> pad_token: "..."
|
||||
> ```
|
||||
|
||||
**Q: `IterableDataset error` or `KeyError: 'input_ids'` when using `preprocess` CLI**
|
||||
|
||||
> A: This is because you may be using `preprocess` CLI with `pretraining_dataset:` or `skip_prepare_dataset: true` respectively. Please use `axolotl train` CLI directly instead as these datasets are prepared on demand.
|
||||
|
||||
**Q: vLLM is not working with Axolotl**
|
||||
|
||||
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
|
||||
|
||||
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
|
||||
|
||||
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
|
||||
|
||||
### Chat templates
|
||||
|
||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||
|
||||
@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
||||
|
||||
1. Set `adapter: qlora` in your axolotl config file.
|
||||
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
|
||||
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
|
||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||
|
||||
## Example Config
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,71 +0,0 @@
|
||||
base_model: tiiuae/Falcon-H1-1.5B-Deep-Base
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
# huggingface repo
|
||||
chat_template: falcon_h1
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- in_proj
|
||||
- gate_proj
|
||||
- up_proj
|
||||
- down_proj
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,71 +0,0 @@
|
||||
base_model: tiiuae/Falcon-H1-1.5B-Base
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
# huggingface repo
|
||||
chat_template: falcon_h1
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- in_proj
|
||||
- gate_proj
|
||||
- up_proj
|
||||
- down_proj
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,71 +0,0 @@
|
||||
base_model: tiiuae/Falcon-H1-34B-Base
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
# huggingface repo
|
||||
chat_template: falcon_h1
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- in_proj
|
||||
- gate_proj
|
||||
- up_proj
|
||||
- down_proj
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,71 +0,0 @@
|
||||
base_model: tiiuae/Falcon-H1-3B-Base
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
# huggingface repo
|
||||
chat_template: falcon_h1
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- in_proj
|
||||
- gate_proj
|
||||
- up_proj
|
||||
- down_proj
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,71 +0,0 @@
|
||||
base_model: tiiuae/Falcon-H1-0.5B-Instruct
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
# huggingface repo
|
||||
chat_template: falcon_h1
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- in_proj
|
||||
- gate_proj
|
||||
- up_proj
|
||||
- down_proj
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,71 +0,0 @@
|
||||
base_model: tiiuae/Falcon-H1-7B-Base
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
# huggingface repo
|
||||
chat_template: falcon_h1
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- in_proj
|
||||
- gate_proj
|
||||
- up_proj
|
||||
- down_proj
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -13,8 +13,6 @@ load_in_4bit: true
|
||||
|
||||
# huggingface repo
|
||||
chat_template: gemma3
|
||||
eot_tokens:
|
||||
- <end_of_turn>
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
|
||||
@@ -6,8 +6,6 @@ load_in_4bit: true
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
chat_template: gemma3
|
||||
eot_tokens:
|
||||
- <end_of_turn>
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
|
||||
@@ -12,8 +12,6 @@ sample_packing: false
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
chat_template: gemma3
|
||||
eot_tokens:
|
||||
- <end_of_turn>
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: qwen2_vl
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.46.0
|
||||
bitsandbytes==0.45.4
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
@@ -11,11 +11,11 @@ liger-kernel==0.5.10
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub[hf_xet]==0.33.0
|
||||
huggingface_hub==0.32.2
|
||||
peft==0.15.2
|
||||
transformers==4.53.1
|
||||
transformers==4.52.4
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.8.1
|
||||
accelerate==1.7.0
|
||||
datasets==3.6.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.18.2
|
||||
@@ -68,4 +68,4 @@ schedulefree==1.4.1
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
|
||||
mistral-common==1.6.3
|
||||
mistral-common==1.6.0
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
|
||||
)
|
||||
|
||||
12
setup.py
12
setup.py
@@ -65,13 +65,15 @@ def parse_requirements(extras_require_map):
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 7):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
|
||||
extras_require_map["vllm"] = ["vllm==0.9.2"]
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append(
|
||||
"xformers==0.0.29.post2"
|
||||
) # vllm needs post2 w torch 2.6
|
||||
extras_require_map["vllm"] = ["vllm==0.9.2"]
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
@@ -109,10 +111,10 @@ def get_package_version():
|
||||
|
||||
|
||||
extras_require = {
|
||||
"flash-attn": ["flash-attn==2.8.0.post2"],
|
||||
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||
"ring-flash-attn": [
|
||||
"flash-attn==2.8.0.post2",
|
||||
"ring-flash-attn>=0.1.5",
|
||||
"flash-attn==2.7.4.post1",
|
||||
"ring-flash-attn>=0.1.4",
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
|
||||
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.11.0.dev"
|
||||
__version__ = "0.10.1"
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
from accelerate.commands.config import config_args
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from requests import HTTPError
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -47,8 +46,3 @@ def check_user_token() -> bool:
|
||||
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||
)
|
||||
return False
|
||||
except HTTPError:
|
||||
LOG.warning(
|
||||
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -23,6 +24,7 @@ def do_cli_preprocess(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
@@ -37,6 +39,7 @@ def do_cli_train(
|
||||
cwd=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
@@ -51,6 +54,7 @@ def do_cli_lm_eval(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
|
||||
@@ -28,8 +28,6 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
API_KEY_FIELDS = {"comet_api_key"}
|
||||
|
||||
|
||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||
"""
|
||||
@@ -235,15 +233,4 @@ def load_cfg(
|
||||
setup_comet_env_vars(cfg)
|
||||
plugin_set_cfg(cfg)
|
||||
|
||||
cfg_to_log = {
|
||||
k: "[REDACTED]" if k in API_KEY_FIELDS else v
|
||||
for k, v in cfg.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
LOG.info(
|
||||
"config:\n%s",
|
||||
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
@@ -9,6 +9,7 @@ from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
@@ -34,6 +35,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
patch_optimized_env()
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
|
||||
@@ -13,6 +13,7 @@ from dotenv import load_dotenv
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
|
||||
from axolotl.cli.args import InferenceCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.utils.chat_templates import (
|
||||
@@ -254,6 +255,7 @@ def do_cli(
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
||||
|
||||
@@ -20,7 +20,6 @@ from axolotl.cli.args import (
|
||||
TrainerCliArgs,
|
||||
VllmServeCliArgs,
|
||||
)
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.sweeps import generate_sweep_configs
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
@@ -41,7 +40,6 @@ LOG = get_logger(__name__)
|
||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||
def cli():
|
||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||
print_axolotl_text_art()
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Union
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -22,6 +23,8 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
|
||||
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -193,6 +194,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||
|
||||
@@ -12,6 +12,7 @@ from dotenv import load_dotenv
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.cli.args import PreprocessCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
@@ -32,15 +33,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Preprocessing-specific CLI arguments.
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
|
||||
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
|
||||
if cfg.get("key"):
|
||||
raise ValueError(
|
||||
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
|
||||
)
|
||||
|
||||
if not cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Union
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -26,6 +27,7 @@ def do_quantize(
|
||||
config (Union[Path, str]): The path to the config file
|
||||
cli_args (dict): Additional command-line arguments
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
|
||||
cfg = load_cfg(config)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
@@ -34,6 +35,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
|
||||
@@ -75,17 +75,13 @@ def load_datasets(
|
||||
|
||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||
text_only = cli_args.debug_text_only if cli_args else False
|
||||
try:
|
||||
train_samples = sample_dataset(train_dataset, num_examples)
|
||||
check_dataset_labels(
|
||||
train_samples,
|
||||
tokenizer,
|
||||
num_examples=num_examples,
|
||||
text_only=text_only,
|
||||
)
|
||||
except AttributeError:
|
||||
# can't sample iterable datasets
|
||||
pass
|
||||
train_samples = sample_dataset(train_dataset, num_examples)
|
||||
check_dataset_labels(
|
||||
train_samples,
|
||||
tokenizer,
|
||||
num_examples=num_examples,
|
||||
text_only=text_only,
|
||||
)
|
||||
|
||||
LOG.info("printing prompters...")
|
||||
for prompter in prompters:
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
"""
|
||||
monkeypatch for flex + packing
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
from transformers import Cache, PretrainedConfig
|
||||
from transformers.masking_utils import (
|
||||
ALL_MASK_ATTENTION_FUNCTIONS,
|
||||
_preprocess_mask_arguments,
|
||||
and_masks,
|
||||
causal_mask_function,
|
||||
or_masks,
|
||||
)
|
||||
from transformers.utils import is_torch_greater_or_equal
|
||||
|
||||
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
|
||||
|
||||
|
||||
def create_causal_mask(
|
||||
config: PretrainedConfig,
|
||||
input_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Optional[Cache],
|
||||
or_mask_function: Optional[Callable] = None,
|
||||
and_mask_function: Optional[Callable] = None,
|
||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
||||
"""
|
||||
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
|
||||
has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
|
||||
to what is needed in the `modeling_xxx.py` files).
|
||||
|
||||
Args:
|
||||
config (`PretrainedConfig`):
|
||||
The model config.
|
||||
input_embeds (`torch.Tensor`):
|
||||
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
|
||||
batch size, query length and dtype.
|
||||
attention_mask (`torch.Tensor`, optional):
|
||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
|
||||
It can also be an already prepared 4D mask, in which case it is returned as-is.
|
||||
cache_position (`torch.Tensor`):
|
||||
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
|
||||
past_key_values (`Cache`, optional):
|
||||
The past key values, if we use a cache.
|
||||
or_mask_function (`Callable`, optional):
|
||||
An optional mask function to combine with the causal mask function (by doing the union of both). This is
|
||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
||||
and_mask_function (`Callable`, optional):
|
||||
An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
|
||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
||||
"""
|
||||
# If we have an HybridCache structure, here we want to create the mask for the full layers
|
||||
if (
|
||||
past_key_values
|
||||
and hasattr(past_key_values, "is_sliding")
|
||||
and False in past_key_values.is_sliding
|
||||
):
|
||||
layer_idx = past_key_values.is_sliding.index(False)
|
||||
else:
|
||||
layer_idx = 0
|
||||
|
||||
original_attention_mask = (
|
||||
None
|
||||
if attention_mask is None
|
||||
else attention_mask.clone().to(cache_position.device)
|
||||
)
|
||||
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
|
||||
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
|
||||
)
|
||||
if early_exit:
|
||||
return attention_mask
|
||||
|
||||
batch_size, total_seq_len = cache_position.shape
|
||||
key_length = total_seq_len
|
||||
document_ids = torch.nn.functional.pad(
|
||||
original_attention_mask, value=0, pad=(0, key_length)
|
||||
)
|
||||
|
||||
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
|
||||
if attention_mask is not None:
|
||||
|
||||
def causal_doc_mask_mod(
|
||||
batch_idx, head_idx, q_idx, kv_idx
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Defines the logic of a block causal mask by combining both a standard causal mask
|
||||
and a block diagonal document mask.
|
||||
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
||||
for an illustration.
|
||||
"""
|
||||
causal_mask_ = q_idx >= kv_idx # not valid when decoding
|
||||
document_mask = (
|
||||
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
|
||||
)
|
||||
final_mask = causal_mask_ & document_mask
|
||||
return final_mask
|
||||
|
||||
mask_factory_function = causal_doc_mask_mod
|
||||
else:
|
||||
mask_factory_function = causal_mask_function
|
||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
|
||||
config._attn_implementation # pylint: disable=protected-access
|
||||
]
|
||||
|
||||
# Do not allow skip if we are compiling (this is to match BC)
|
||||
allow_is_causal_skip = (
|
||||
not past_key_values.is_compileable if past_key_values is not None else True
|
||||
)
|
||||
|
||||
# Allow slight deviations from causal mask
|
||||
if or_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError(
|
||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
||||
)
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError(
|
||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
||||
)
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
|
||||
# We now create the mask
|
||||
causal_mask = mask_interface(
|
||||
batch_size=batch_size,
|
||||
cache_position=cache_position,
|
||||
kv_length=kv_length,
|
||||
kv_offset=kv_offset,
|
||||
mask_function=mask_factory_function,
|
||||
attention_mask=attention_mask,
|
||||
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
def patch_create_causal_mask(model_type):
|
||||
import transformers.masking_utils
|
||||
|
||||
transformers.masking_utils.create_causal_mask = create_causal_mask
|
||||
|
||||
if model_type:
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
module = __import__(module_path)
|
||||
module.create_causal_mask = create_causal_mask
|
||||
del sys.modules[module_path]
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(
|
||||
f"Could not import attention class for model_type: {model_type}. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
@@ -219,9 +219,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.bf16 == "full":
|
||||
training_args_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
bf16 = self.cfg.bf16 or self.cfg.bfloat16
|
||||
bf16 = bf16 if bf16 is not None else False
|
||||
training_args_kwargs["bf16"] = bf16
|
||||
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
|
||||
|
||||
def _configure_scheduler(self, training_args_kwargs: dict):
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
|
||||
|
||||
@@ -245,27 +245,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
|
||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||
training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
|
||||
self.cfg.flash_attention
|
||||
or self.cfg.xformers_attention
|
||||
or self.cfg.flex_attention
|
||||
)
|
||||
training_arguments_kwargs["multipack_real_batches"] = (
|
||||
self.cfg.multipack_real_batches
|
||||
if self.cfg.multipack_real_batches is not None
|
||||
else not (
|
||||
self.cfg.flash_attention
|
||||
or self.cfg.flex_attention
|
||||
or self.cfg.xformers_attention
|
||||
)
|
||||
else not self.cfg.flash_attention
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = bool(
|
||||
self.cfg.eval_sample_packing
|
||||
)
|
||||
if self.cfg.sample_packing_sequentially is not None:
|
||||
training_arguments_kwargs["sample_packing_sequentially"] = (
|
||||
self.cfg.sample_packing_sequentially
|
||||
)
|
||||
if self.cfg.sample_packing_bin_size is not None:
|
||||
training_arguments_kwargs["sample_packing_bin_size"] = (
|
||||
self.cfg.sample_packing_bin_size
|
||||
@@ -426,8 +413,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
or self.cfg.micro_batch_size > 1
|
||||
):
|
||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
|
||||
return None
|
||||
return None
|
||||
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||
|
||||
@@ -20,14 +20,13 @@ from torch.utils.data import (
|
||||
SequentialSampler,
|
||||
)
|
||||
from transformers import Trainer
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.mixins import (
|
||||
CheckpointSaveMixin,
|
||||
OptimizerMixin,
|
||||
PackingMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
)
|
||||
@@ -43,12 +42,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class AxolotlTrainer(
|
||||
PackingMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
CheckpointSaveMixin,
|
||||
Trainer,
|
||||
SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
|
||||
):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
@@ -122,15 +116,14 @@ class AxolotlTrainer(
|
||||
sequential=self.args.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
num_processes=self.args.dataset_num_proc,
|
||||
mp_start_method=self.args.sample_packing_mp_start_method or "fork",
|
||||
)
|
||||
|
||||
len(sampler)
|
||||
return sampler
|
||||
|
||||
def _get_train_sampler(
|
||||
self, train_dataset: Dataset | None = None
|
||||
) -> Sampler | None:
|
||||
self, train_dataset: Optional[Dataset] = None
|
||||
) -> Optional[Sampler]:
|
||||
"""
|
||||
Helper method to get the sampler for training. Handles cases for sample packing
|
||||
and curriculum sampling (sequential).
|
||||
@@ -139,22 +132,16 @@ class AxolotlTrainer(
|
||||
If the dataset is non-empty, a sampler is returned, the type of which
|
||||
depends on the passed training args.
|
||||
"""
|
||||
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24
|
||||
if train_dataset is None:
|
||||
train_dataset = self.train_dataset
|
||||
if train_dataset is None or not has_length(train_dataset):
|
||||
return None
|
||||
|
||||
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||
|
||||
# Determine the base sampler first
|
||||
if self.args.curriculum_sampling:
|
||||
base_sampler = SequentialSampler(train_dataset)
|
||||
base_sampler = SequentialSampler(self.train_dataset)
|
||||
elif use_sample_packing:
|
||||
base_sampler = RandomSampler(train_dataset)
|
||||
base_sampler = RandomSampler(self.train_dataset)
|
||||
else:
|
||||
# Default to parent class implementation for standard random sampling
|
||||
return super()._get_train_sampler(train_dataset)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
# Apply multipack wrapper if needed
|
||||
if use_sample_packing:
|
||||
@@ -173,10 +160,6 @@ class AxolotlTrainer(
|
||||
If the dataset is non-empty, a sampler is returned, the type of which
|
||||
depends on the passed training args.
|
||||
"""
|
||||
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24
|
||||
if eval_dataset is None or not has_length(eval_dataset):
|
||||
return None
|
||||
|
||||
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
||||
use_multipack = (
|
||||
self.args.sample_packing and self.args.eval_sample_packing is not False
|
||||
@@ -212,14 +195,6 @@ class AxolotlTrainer(
|
||||
|
||||
if dataset.column_names and "length" in dataset.column_names:
|
||||
dataset = dataset.remove_columns(["length"])
|
||||
if (
|
||||
dataset.column_names
|
||||
and "position_ids" in dataset.column_names
|
||||
and "attention_mask" in dataset.column_names
|
||||
and self.args.sample_packing
|
||||
and self.args.sample_packing_drop_attention_mask
|
||||
):
|
||||
dataset = dataset.remove_columns(["attention_mask"])
|
||||
|
||||
if isinstance(dataset, datasets.Dataset):
|
||||
if is_training:
|
||||
|
||||
@@ -28,7 +28,7 @@ class DPOStrategy:
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
||||
if cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||
if cfg.dpo_padding_free is not None:
|
||||
|
||||
@@ -5,6 +5,5 @@
|
||||
|
||||
from .checkpoints import CheckpointSaveMixin
|
||||
from .optimizer import OptimizerMixin
|
||||
from .packing import PackingMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
"""Trainer mixin to support packing"""
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
|
||||
class PackingMixin(Trainer):
|
||||
"""
|
||||
Trainer mixin to support packing
|
||||
"""
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
super()._set_signature_columns_if_needed()
|
||||
if (
|
||||
self._signature_columns
|
||||
and self.args.sample_packing
|
||||
and self.args.sample_packing_drop_attention_mask
|
||||
):
|
||||
set_sig_columns = set(self._signature_columns)
|
||||
set_sig_columns.remove("attention_mask")
|
||||
self._signature_columns = list(set_sig_columns)
|
||||
@@ -38,14 +38,6 @@ class AxolotlTrainingMixins:
|
||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||
},
|
||||
)
|
||||
sample_packing_mp_start_method: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The multiprocessing start method to use."},
|
||||
)
|
||||
sample_packing_drop_attention_mask: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Drop attention mask from inputs when using packing."},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
|
||||
@@ -19,11 +19,19 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
**NOTE**: If you are training a VLM model, please use older version of Axolotl as upstream has applied a major VLM refactor, and our patches have not been updated yet.
|
||||
|
||||
```bash
|
||||
git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764
|
||||
|
||||
pip3 install --no-build-isolation -e .
|
||||
```
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
@@ -31,29 +39,27 @@ plugins:
|
||||
|
||||
## Supported Models
|
||||
|
||||
- cohere
|
||||
- cohere2
|
||||
- llama
|
||||
- llama4
|
||||
- llama4_text
|
||||
- mllama
|
||||
- phi3
|
||||
- gemma
|
||||
- gemma2
|
||||
- gemma3
|
||||
- gemma3_text
|
||||
- glm
|
||||
- glm4
|
||||
- llama
|
||||
- llama4
|
||||
- llama4_text
|
||||
- mistral
|
||||
- mistral3
|
||||
- mllama
|
||||
- phi
|
||||
- phi3
|
||||
- phi4_multimodal
|
||||
- qwen2
|
||||
- qwen2_vl
|
||||
- qwen2_moe
|
||||
- qwen2_vl
|
||||
- qwen2_5_vl
|
||||
- qwen3
|
||||
- qwen3_moe
|
||||
- cohere
|
||||
- cohere2
|
||||
- glm
|
||||
- glm4
|
||||
|
||||
## Citation
|
||||
|
||||
|
||||
@@ -31,8 +31,8 @@ from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa:
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"`'
|
||||
"Please install cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`'
|
||||
)
|
||||
|
||||
|
||||
@@ -64,28 +64,16 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
"cut_cross_entropy.transformers"
|
||||
)
|
||||
if cce_spec_transformers is None:
|
||||
raise ImportError(
|
||||
"Transformers support is not installed. " + _CCE_INSTALL_MESSAGE
|
||||
)
|
||||
|
||||
# Check if Axolotl's cce fork is installed
|
||||
try:
|
||||
from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK
|
||||
|
||||
if not AXOLOTL_CCE_FORK:
|
||||
raise ImportError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Axolotl's fork of cut_cross_entropy is not installed. "
|
||||
+ _CCE_INSTALL_MESSAGE
|
||||
) from e
|
||||
raise ImportError(_CCE_INSTALL_MESSAGE)
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""Apply cut cross entropy before model loading if enabled."""
|
||||
if cfg.cut_cross_entropy:
|
||||
self._check_requirements()
|
||||
|
||||
from cut_cross_entropy.transformers.patch import cce_patch
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
|
||||
cce_patch,
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
|
||||
|
||||
191
src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
Normal file
191
src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Cohere and Cohere2 CCE patch."""
|
||||
|
||||
# This patch is based off transformers 4.50.0.
|
||||
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
|
||||
# It scales the hidden states by the logit scale in advance instead of the logits as the
|
||||
# operation is done internally and should be mathematically equivalent.
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.cohere.modeling_cohere import (
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>> from transformers import AutoTokenizer, CohereForCausalLM
|
||||
|
||||
>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
|
||||
>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>> # Generate
|
||||
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
# scale hidden_states by logit_scale in-place of logits
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :] * self.logit_scale,
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
logits = logits * self.logit_scale # main diff from Llama
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def patch_cohere(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.cohere import modeling_cohere
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_cohere.CohereForCausalLM
|
||||
), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_cohere.CohereForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_cohere2(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.cohere2 import modeling_cohere2
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_cohere2.Cohere2ForCausalLM
|
||||
), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
165
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py
Normal file
165
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Gemma CCE patch"""
|
||||
|
||||
# This patch is based off transformers 4.50.0.
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||
|
||||
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
|
||||
|
||||
>>> prompt = "What is your favorite condiment?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def patch_gemma(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma.GemmaForCausalLM
|
||||
), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma.GemmaForCausalLM.forward = cce_forward
|
||||
return None
|
||||
447
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py
Normal file
447
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""
|
||||
|
||||
# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
|
||||
# and updated for transformers 4.50.0.
|
||||
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
|
||||
# with both gemma3 (text and multimodal) models.
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache, HybridCache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
Gemma3CausalLMOutputWithPast,
|
||||
logger,
|
||||
)
|
||||
from transformers.utils import (
|
||||
is_torchdynamo_compiling,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
|
||||
|
||||
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
||||
|
||||
>>> prompt = "What is your favorite condiment?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**loss_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||
**loss_kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if self.config.final_logit_softcapping is not None:
|
||||
logits = logits / self.config.final_logit_softcapping
|
||||
logits = torch.tanh(logits)
|
||||
logits = logits * self.config.final_logit_softcapping
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||
|
||||
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
|
||||
>>> prompt = "answer en Where is the cow standing?"
|
||||
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_length=30)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"answer en Where is the cow standing?\nbeach"
|
||||
```"""
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_index
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids # type: ignore
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = (
|
||||
past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore
|
||||
)
|
||||
cache_position = torch.arange( # type: ignore
|
||||
past_seen_tokens,
|
||||
past_seen_tokens + inputs_embeds.shape[1],
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(
|
||||
self.config.image_token_index,
|
||||
dtype=torch.long,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
|
||||
-1
|
||||
)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
"tokens from image embeddings."
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
||||
|
||||
# mask out pad-token-ids in labels for BC
|
||||
if labels is not None and self.pad_token_id in labels:
|
||||
logger.warning_once(
|
||||
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
||||
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
||||
)
|
||||
labels = torch.where( # type: ignore
|
||||
input_ids == self.pad_token_id, self.config.ignore_index, labels
|
||||
)
|
||||
|
||||
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
past_key_values,
|
||||
cache_position,
|
||||
inputs_embeds,
|
||||
is_training,
|
||||
)
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
defer_logits_calculation=True, # enable deferred logits calculation
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||
**lm_kwargs,
|
||||
)
|
||||
else:
|
||||
logits = hidden_states
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
shift_logits = logits[..., :-1, :]
|
||||
shift_labels = labels[..., 1:]
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
|
||||
logits.device
|
||||
)
|
||||
shift_logits = shift_logits[
|
||||
shift_attention_mask.to(logits.device) != 0
|
||||
].contiguous()
|
||||
shift_labels = shift_labels[
|
||||
shift_attention_mask.to(shift_labels.device) != 0
|
||||
].contiguous()
|
||||
else:
|
||||
shift_logits = shift_logits.contiguous()
|
||||
shift_labels = shift_labels.contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
|
||||
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
||||
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
||||
loss = loss_fct(flat_logits, flat_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Gemma3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def patch_gemma2(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma2 import modeling_gemma2
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma2.Gemma2ForCausalLM
|
||||
), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_gemma3_text(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma3 import modeling_gemma3
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma3.Gemma3ForCausalLM
|
||||
), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_gemma3(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma3 import modeling_gemma3
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
|
||||
), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -0,0 +1,57 @@
|
||||
"""GLM 4 patch. GLM family inherits from Llama."""
|
||||
|
||||
from types import MethodType
|
||||
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
)
|
||||
|
||||
|
||||
def patch_glm(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
|
||||
# Set the _PATCH_OPTS in the llama patch file
|
||||
import cut_cross_entropy.transformers.llama as llama_patch
|
||||
|
||||
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||
|
||||
from cut_cross_entropy.transformers.llama import cce_forward
|
||||
from transformers.models.glm import modeling_glm
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_glm.GlmForCausalLM
|
||||
), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_glm.GlmForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_glm4(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
|
||||
# Set the _PATCH_OPTS in the llama patch file
|
||||
import cut_cross_entropy.transformers.llama as llama_patch
|
||||
|
||||
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||
|
||||
from cut_cross_entropy.transformers.llama import cce_forward
|
||||
from transformers.models.glm4 import modeling_glm4
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_glm4.Glm4ForCausalLM
|
||||
), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_glm4.Glm4ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
164
src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py
Normal file
164
src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Llama CCE patch. Adapted from transformers v4.51.2"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states is None")
|
||||
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def patch_llama(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
"""Patch Llama for CCE."""
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.llama import modeling_llama
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_llama.LlamaForCausalLM
|
||||
), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_llama.LlamaForCausalLM.forward = cce_forward
|
||||
return None
|
||||
401
src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
Normal file
401
src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""Llama4 CCE patch. Adapted from transformers 4.51.0."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.llama4.modeling_llama4 import (
|
||||
Llama4CausalLMOutputWithPast,
|
||||
)
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Llama4ForCausalLM
|
||||
|
||||
>>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None, # type: ignore
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor | None = None,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration
|
||||
|
||||
>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
>>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer
|
||||
if vision_feature_layer is not None
|
||||
else self.config.vision_config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids) # type: ignore
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
original_inputs_embeds_shape = inputs_embeds.shape # type: ignore
|
||||
|
||||
vision_flat = image_features.view(-1, image_features.size(-1))
|
||||
projected_vision_flat = self.multi_modal_projector(vision_flat)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
final_mask = special_image_mask.to(inputs_embeds.device) # type: ignore
|
||||
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore
|
||||
|
||||
final_mask_1d = final_mask[..., 0].reshape(-1)
|
||||
num_tokens_to_fill = final_mask_1d.sum()
|
||||
|
||||
if num_tokens_to_fill != projected_vision_flat.size(0):
|
||||
raise ValueError(
|
||||
f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
|
||||
f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
|
||||
)
|
||||
|
||||
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
|
||||
inputs_embeds = inputs_embeds.masked_scatter(
|
||||
expanded_mask, projected_vision_flat
|
||||
) # type: ignore
|
||||
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) # type: ignore
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
defer_logits_calculation=True, # enable deferred logits calculation
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
# TODO: check if need to handle attention_mask
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**lm_kwargs,
|
||||
)
|
||||
else:
|
||||
logits = hidden_states
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
||||
logits.device
|
||||
)
|
||||
shift_logits = logits[..., :-1, :][
|
||||
shift_attention_mask.to(logits.device) != 0
|
||||
].contiguous()
|
||||
shift_labels = labels[..., 1:][
|
||||
shift_attention_mask.to(labels.device) != 0
|
||||
].contiguous()
|
||||
else:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1).to(shift_logits.device),
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Llama4CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits, # type: ignore # TODO: check if need to create dummy logits
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def patch_llama4_text(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.llama4 import modeling_llama4
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_llama4.Llama4ForCausalLM
|
||||
), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
|
||||
return maybe_model
|
||||
|
||||
setattr(
|
||||
modeling_llama4.Llama4ForCausalLM,
|
||||
"forward",
|
||||
cce_forward,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def patch_llama4(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.llama4 import modeling_llama4
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_llama4.Llama4ForConditionalGeneration
|
||||
), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the language model
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
setattr(
|
||||
modeling_llama4.Llama4ForConditionalGeneration,
|
||||
"forward",
|
||||
cce_forward_multimodal,
|
||||
)
|
||||
|
||||
# patch the causal language model
|
||||
setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
|
||||
return None
|
||||
@@ -0,0 +1,384 @@
|
||||
"""Mistral and Mistral3 CCE patch."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.mistral3.modeling_mistral3 import (
|
||||
Mistral3CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
is_torchdynamo_compiling,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] | None = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
||||
|
||||
>>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor | None = None,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
|
||||
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||
|
||||
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is the image?The image depicts two cats lying on a pink blanket."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer
|
||||
if vision_feature_layer is not None
|
||||
else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
defer_logits_calculation=True, # enable deferred logits calculation
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**lm_kwargs,
|
||||
)
|
||||
else:
|
||||
logits = hidden_states
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
||||
logits.device
|
||||
)
|
||||
shift_logits = logits[..., :-1, :][
|
||||
shift_attention_mask.to(logits.device) != 0
|
||||
].contiguous()
|
||||
shift_labels = labels[..., 1:][
|
||||
shift_attention_mask.to(labels.device) != 0
|
||||
].contiguous()
|
||||
else:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1).to(shift_logits.device),
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Mistral3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def patch_mistral(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_mistral.MistralForCausalLM
|
||||
), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_mistral3(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
from transformers.models.mistral3 import modeling_mistral3
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
|
||||
), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||
return None
|
||||
366
src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py
Normal file
366
src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""Mllama CCE patch."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.mllama.modeling_mllama import (
|
||||
_prepare_cross_attention_mask,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cross_attention_states: Optional[torch.LongTensor] = None,
|
||||
cross_attention_mask: Optional[torch.LongTensor] = None,
|
||||
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MllamaForCausalLM
|
||||
|
||||
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
|
||||
|
||||
>>> prompt = "If I had to write a haiku, it would be:"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
|
||||
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
>>> print(result)
|
||||
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
|
||||
I love the idea of snowflakes gently falling, each one
|
||||
```
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
cross_attention_states=cross_attention_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**loss_kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
aspect_ratio_mask: Optional[torch.Tensor] = None,
|
||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
|
||||
|
||||
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
|
||||
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
|
||||
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
||||
|
||||
>>> prompt = "<|image|>If I had to write a haiku for this one"
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> output = model.generate(**inputs, max_new_tokens=15)
|
||||
|
||||
>>> prompt_len = inputs.input_ids.shape[-1]
|
||||
>>> generated_ids = output[:, prompt_len:]
|
||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
>>> print(generated_text)
|
||||
[', it would be:.\\nA stop sign in Chinatown.\\n']
|
||||
```
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if pixel_values is not None and cross_attention_states is not None:
|
||||
raise ValueError(
|
||||
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
if aspect_ratio_ids is None:
|
||||
raise ValueError(
|
||||
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
||||
)
|
||||
# get vision tokens from vision model
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
aspect_ratio_ids=aspect_ratio_ids,
|
||||
aspect_ratio_mask=aspect_ratio_mask,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
cross_attention_states = vision_outputs[0]
|
||||
cross_attention_states = self.multi_modal_projector(
|
||||
cross_attention_states
|
||||
).reshape(
|
||||
-1, cross_attention_states.shape[-2], self.hidden_size # type: ignore
|
||||
)
|
||||
|
||||
if cross_attention_mask is not None:
|
||||
cross_attention_mask, full_text_row_masked_out_mask = (
|
||||
_prepare_cross_attention_mask(
|
||||
cross_attention_mask,
|
||||
num_vision_tokens=self.vision_model.num_patches,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
)
|
||||
else:
|
||||
full_text_row_masked_out_mask = None
|
||||
|
||||
if cross_attention_mask is not None and cache_position is not None:
|
||||
cross_attention_mask = cross_attention_mask[:, :, cache_position]
|
||||
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
|
||||
:, :, cache_position
|
||||
]
|
||||
|
||||
outputs = self.language_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cross_attention_states=cross_attention_states,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
defer_logits_calculation=True, # enable deferred logits calculation
|
||||
**loss_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**loss_kwargs,
|
||||
)
|
||||
else:
|
||||
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
|
||||
logits = hidden_states
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (loss,) + outputs if loss is not None else outputs
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=outputs.logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def patch_mllama(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.mllama import modeling_mllama
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_mllama.MllamaForConditionalGeneration
|
||||
), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the language model
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal
|
||||
|
||||
# patch the causal language model
|
||||
modeling_mllama.MllamaForCausalLM.forward = cce_forward
|
||||
return None
|
||||
126
src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py
Normal file
126
src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
"""Cut Cross Entropy patcher"""
|
||||
|
||||
import transformers
|
||||
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
|
||||
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
|
||||
from cut_cross_entropy.transformers.phi3 import patch_phi3
|
||||
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
|
||||
patch_cohere,
|
||||
patch_cohere2,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
|
||||
patch_gemma2,
|
||||
patch_gemma3,
|
||||
patch_gemma3_text,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
|
||||
patch_glm,
|
||||
patch_glm4,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
|
||||
patch_llama,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
|
||||
patch_llama4,
|
||||
patch_llama4_text,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
|
||||
patch_mistral,
|
||||
patch_mistral3,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import (
|
||||
patch_qwen2,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import (
|
||||
patch_qwen2_5_vl,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import (
|
||||
patch_qwen2_moe,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import (
|
||||
patch_qwen2_vl,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import (
|
||||
patch_qwen3_moe,
|
||||
)
|
||||
|
||||
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
|
||||
"llama": patch_llama,
|
||||
"llama4": patch_llama4,
|
||||
"llama4_text": patch_llama4_text,
|
||||
"mllama": patch_mllama,
|
||||
"phi3": patch_phi3,
|
||||
"gemma": patch_gemma,
|
||||
"gemma2": patch_gemma2,
|
||||
"gemma3": patch_gemma3,
|
||||
"gemma3_text": patch_gemma3_text,
|
||||
"mistral": patch_mistral,
|
||||
"mistral3": patch_mistral3,
|
||||
"qwen2": patch_qwen2,
|
||||
"qwen2_moe": patch_qwen2_moe,
|
||||
"qwen2_vl": patch_qwen2_vl,
|
||||
"qwen2_5_vl": patch_qwen2_5_vl,
|
||||
"qwen3": patch_qwen3,
|
||||
"qwen3_moe": patch_qwen3_moe,
|
||||
"cohere": patch_cohere,
|
||||
"cohere2": patch_cohere2,
|
||||
"glm": patch_glm,
|
||||
"glm4": patch_glm4,
|
||||
}
|
||||
|
||||
|
||||
def cce_patch(
|
||||
model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
|
||||
impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
|
||||
reduction: str = "mean",
|
||||
filter_eps: float | str | None = "auto",
|
||||
accum_e_fp32: bool = False,
|
||||
accum_c_fp32: bool = False,
|
||||
filter_e_grad: bool = True,
|
||||
filter_c_grad: bool = True,
|
||||
train_only: bool = False,
|
||||
) -> TransformersModelT | None:
|
||||
if isinstance(impl, LinearCrossEntropyImpl):
|
||||
impl = impl.name.lower()
|
||||
|
||||
if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
|
||||
raise ValueError(f"Unknown {impl=}")
|
||||
|
||||
if isinstance(model_type_or_model, transformers.PreTrainedModel):
|
||||
if hasattr(model_type_or_model, "config"):
|
||||
model_type = getattr(
|
||||
getattr(model_type_or_model, "config", None), "model_type", None
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"model_type_or_model is a PreTrainedModel but does not have a config attribute"
|
||||
)
|
||||
elif isinstance(model_type_or_model, transformers.PretrainedConfig):
|
||||
model_type = model_type_or_model.model_type
|
||||
else:
|
||||
model_type = model_type_or_model
|
||||
|
||||
patch_options = PatchOptions(
|
||||
impl=impl,
|
||||
reduction=reduction,
|
||||
filter_eps=filter_eps,
|
||||
accum_e_fp32=accum_e_fp32,
|
||||
accum_c_fp32=accum_c_fp32,
|
||||
filter_e_grad=filter_e_grad,
|
||||
filter_c_grad=filter_c_grad,
|
||||
train_only=train_only,
|
||||
)
|
||||
|
||||
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
|
||||
return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
|
||||
model_type_or_model, patch_options
|
||||
)
|
||||
|
||||
raise RuntimeError(f"Unknown model type {model_type}")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen2(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
from transformers.models.qwen2 import modeling_qwen2
|
||||
|
||||
# Set the _PATCH_OPTS in the llama patch file
|
||||
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
|
||||
|
||||
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
|
||||
cce_forward,
|
||||
)
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen2.Qwen2ForCausalLM
|
||||
), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -0,0 +1,246 @@
|
||||
"""Qwen2.5 VL CCE patch. Adapted from transformers v4.51.2"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VLCausalLMOutputWithPast,
|
||||
)
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
|
||||
>>> messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
|
||||
mask = input_ids == self.config.image_token_id
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
image_mask = mask_expanded.to(inputs_embeds.device)
|
||||
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
|
||||
mask = input_ids == self.config.video_token_id
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
video_mask = mask_expanded.to(inputs_embeds.device)
|
||||
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or self.rope_deltas is None
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
|
||||
):
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
second_per_grid_ts,
|
||||
attention_mask,
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
else:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
delta = (
|
||||
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||
if cache_position is not None
|
||||
else 0
|
||||
)
|
||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
|
||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
|
||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
|
||||
position_ids = position_ids.add(delta) # type: ignore
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=None,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = None
|
||||
loss = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Qwen2_5_VLCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
rope_deltas=self.rope_deltas,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen2_5_vl(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
|
||||
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
|
||||
), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = (
|
||||
cce_forward_multimodal
|
||||
)
|
||||
return None
|
||||
@@ -0,0 +1,178 @@
|
||||
"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
||||
MoeCausalLMOutputWithPast,
|
||||
MoeModelOutputWithPast,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**loss_kwargs,
|
||||
) -> MoeCausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM
|
||||
|
||||
>>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: MoeModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_router_logits=output_router_logits,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states is None")
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**loss_kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(
|
||||
outputs.router_logits,
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
attention_mask,
|
||||
)
|
||||
if labels is not None:
|
||||
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
|
||||
loss.device # type: ignore
|
||||
) # make sure to reside in the same device
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss, # type: ignore
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
router_logits=outputs.router_logits,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen2_moe(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
|
||||
from transformers.models.qwen2_moe import modeling_qwen2_moe
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM
|
||||
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(forward, maybe_model)
|
||||
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward
|
||||
return None
|
||||
@@ -0,0 +1,239 @@
|
||||
"""Qwen2 VL CCE patch. Adapted from transformers v4.51.2"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||
Qwen2VLCausalLMOutputWithPast,
|
||||
)
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||
|
||||
>>> messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.get_dtype())
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or self.rope_deltas is None
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
|
||||
):
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
else:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
delta = (
|
||||
cache_position[0] + self.rope_deltas
|
||||
if cache_position is not None
|
||||
else 0
|
||||
)
|
||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
|
||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
|
||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
|
||||
delta = delta.to(position_ids.device) # type: ignore
|
||||
position_ids = position_ids.add(delta) # type: ignore
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=None,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = None
|
||||
loss = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Qwen2VLCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
rope_deltas=self.rope_deltas,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen2_vl(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
|
||||
from transformers.models.qwen2_vl import modeling_qwen2_vl
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration
|
||||
), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal
|
||||
return None
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen3(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
from transformers.models.qwen3 import modeling_qwen3
|
||||
|
||||
# Set the _PATCH_OPTS in the llama patch file
|
||||
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
|
||||
|
||||
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen3.Qwen3ForCausalLM
|
||||
), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Qwen3 MoE CCE patch. Adapted from transformers v4.51.2"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
||||
KwargsForCausalLM,
|
||||
MoeCausalLMOutputWithPast,
|
||||
MoeModelOutputWithPast,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> MoeCausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
|
||||
|
||||
>>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: MoeModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_router_logits=output_router_logits,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states is None")
|
||||
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(
|
||||
outputs.router_logits,
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
attention_mask,
|
||||
)
|
||||
if labels is not None:
|
||||
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
|
||||
loss.device # type: ignore
|
||||
) # make sure to reside in the same device
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss, # type: ignore
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
router_logits=outputs.router_logits,
|
||||
)
|
||||
|
||||
|
||||
def patch_qwen3_moe(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
|
||||
from transformers.models.qwen3_moe import modeling_qwen3_moe
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM
|
||||
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(forward, maybe_model)
|
||||
|
||||
return maybe_model
|
||||
|
||||
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward
|
||||
return None
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
"""Monkeypatch for apply_lce to add softcap."""
|
||||
|
||||
import torch
|
||||
from cut_cross_entropy import linear_cross_entropy
|
||||
from cut_cross_entropy.transformers.utils import PatchOptions
|
||||
|
||||
|
||||
def apply_lce(
|
||||
e: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
opts: PatchOptions,
|
||||
bias: torch.Tensor | None = None,
|
||||
softcap: float | None = None,
|
||||
**loss_kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Monkey patch for apply_lce to support softcap kwarg."""
|
||||
num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
|
||||
cce_kwargs = opts.to_kwargs()
|
||||
if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
|
||||
cce_kwargs["reduction"] = "sum"
|
||||
else:
|
||||
num_items_in_batch = None
|
||||
|
||||
loss = linear_cross_entropy(
|
||||
e,
|
||||
c,
|
||||
labels.to(e.device),
|
||||
bias=bias,
|
||||
shift=True,
|
||||
softcap=softcap,
|
||||
**cce_kwargs,
|
||||
)
|
||||
|
||||
if num_items_in_batch is not None:
|
||||
loss = loss / num_items_in_batch
|
||||
|
||||
return loss
|
||||
@@ -1,12 +0,0 @@
|
||||
# DenseMixer
|
||||
|
||||
See [DenseMixer](https://github.com/yaof20/DenseMixer/)
|
||||
|
||||
# Usage
|
||||
|
||||
Simply add the following to your axolotl YAML config:
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.densemixer.DenseMixerPlugin
|
||||
```
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Integration entry point for the DenseMixer plugin."""
|
||||
|
||||
from .plugin import DenseMixerPlugin
|
||||
|
||||
__all__ = ["DenseMixerPlugin"]
|
||||
@@ -1,11 +0,0 @@
|
||||
"""Pydantic models for DenseMixer plugin"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DenseMixerArgs(BaseModel):
|
||||
"""
|
||||
Args for DenseMixer
|
||||
"""
|
||||
|
||||
dense_mixer: bool = True
|
||||
@@ -1,42 +0,0 @@
|
||||
"""DenseMixer plugin for Axolotl"""
|
||||
|
||||
import importlib
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class DenseMixerPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for DenseMixer
|
||||
"""
|
||||
|
||||
def get_input_args(self) -> str | None:
|
||||
return "axolotl.integrations.densemixer.args.DenseMixerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""Apply densemixer patches before model loading if enabled."""
|
||||
if cfg.dense_mixer:
|
||||
if not importlib.util.find_spec("densemixer"):
|
||||
raise RuntimeError(
|
||||
"DenseMixer is not installed. Install it with `pip install densemizer`"
|
||||
)
|
||||
|
||||
from densemixer.patching import (
|
||||
apply_olmoe_patch,
|
||||
apply_qwen2_moe_patch,
|
||||
apply_qwen3_moe_patch,
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
f"Applying DenseMixer patches for model type: {cfg.model_config_type}"
|
||||
)
|
||||
|
||||
if cfg.model_config_type == "olmoe":
|
||||
apply_olmoe_patch()
|
||||
if cfg.model_config_type == "qwen2_moe":
|
||||
apply_qwen2_moe_patch()
|
||||
if cfg.model_config_type == "qwen3_moe":
|
||||
apply_qwen3_moe_patch()
|
||||
@@ -504,9 +504,6 @@ class ModelLoader:
|
||||
# for some reason, this causes the loss to be off by an order of magnitude
|
||||
# but deepspeed needs this still in bfloat16
|
||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||
if self.cfg.model_config_type == "falcon_h1":
|
||||
# output projection cannot be quantized for Falcon-H1 models
|
||||
bnb_config["llm_int8_skip_modules"] = ["out_proj"]
|
||||
|
||||
if self.cfg.bnb_config_kwargs:
|
||||
bnb_config.update(self.cfg.bnb_config_kwargs)
|
||||
@@ -521,9 +518,6 @@ class ModelLoader:
|
||||
# Exclude mamba blocks from int8 quantization for jamba
|
||||
if self.cfg.model_config_type == "jamba":
|
||||
bnb_config["llm_int8_skip_modules"] = ["mamba"]
|
||||
if self.cfg.model_config_type == "falcon_h1":
|
||||
# output projection cannot be quantized for Falcon-H1 models
|
||||
bnb_config["llm_int8_skip_modules"] = ["out_proj"]
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
)
|
||||
@@ -776,9 +770,6 @@ class ModelLoader:
|
||||
dist_dtype: torch.dtype,
|
||||
before_kbit_train_or_finetune: bool,
|
||||
):
|
||||
dest = {"dtype": dist_dtype}
|
||||
if self.cfg.lora_on_cpu:
|
||||
dest["device"] = "cpu"
|
||||
for name, module in self.model.named_modules():
|
||||
if "norm" in name:
|
||||
module.to(dist_dtype)
|
||||
@@ -789,4 +780,4 @@ class ModelLoader:
|
||||
# don't upcast lm_head for btlm
|
||||
continue
|
||||
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
||||
module.to(**dest)
|
||||
module.to(dist_dtype)
|
||||
|
||||
@@ -49,11 +49,10 @@ class PatchManager:
|
||||
|
||||
def apply_pre_model_load_patches(self):
|
||||
"""Apply pre-model load patches based on config."""
|
||||
# self._apply_flex_attention_patches()
|
||||
self._apply_flash_attention_patches()
|
||||
self._apply_chunked_cross_entropy_patch()
|
||||
self._apply_fsdp_patches()
|
||||
self._apply_adapter_patches()
|
||||
self._apply_flex_attention_patches()
|
||||
self._apply_model_specific_patches()
|
||||
self._apply_fp8_patches()
|
||||
self._apply_flash_attention_peft_patches()
|
||||
@@ -64,9 +63,6 @@ class PatchManager:
|
||||
self._patch_llama_derived_model()
|
||||
self._apply_mistral_cross_entropy_patch()
|
||||
self._apply_self_attention_lora_patch()
|
||||
self._apply_gemma3_conditional_generation_forward_patch()
|
||||
self._apply_sequence_parallel_patches()
|
||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
@@ -82,15 +78,6 @@ class PatchManager:
|
||||
patch_xformers_attn_over_fa2()
|
||||
self.cfg.flash_attention = True
|
||||
|
||||
def _apply_chunked_cross_entropy_patch(self):
|
||||
if self.cfg.chunked_cross_entropy:
|
||||
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
|
||||
|
||||
if self.cfg.chunked_cross_entropy_num_chunks:
|
||||
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
|
||||
else:
|
||||
patch_chunked_ce_loss_fn()
|
||||
|
||||
def _apply_fsdp_patches(self):
|
||||
"""Apply patches for FSDP configurations."""
|
||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
||||
@@ -98,14 +85,6 @@ class PatchManager:
|
||||
|
||||
patch_accelerate_fsdp2()
|
||||
|
||||
# if self.cfg.fsdp_config:
|
||||
# # see transformers#39152
|
||||
# from axolotl.monkeypatch.trainer_fsdp_optim import (
|
||||
# patch_training_loop_for_fsdp,
|
||||
# )
|
||||
#
|
||||
# patch_training_loop_for_fsdp()
|
||||
|
||||
def _apply_adapter_patches(self):
|
||||
"""Apply patches for adapter configurations."""
|
||||
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
|
||||
@@ -116,20 +95,14 @@ class PatchManager:
|
||||
def _apply_flex_attention_patches(self):
|
||||
"""Apply patches for flexible attention."""
|
||||
if self.cfg.flex_attention:
|
||||
# from axolotl.monkeypatch.attention.flex_attn import (
|
||||
# patch_flex_make_mask,
|
||||
# patch_flex_wrapper,
|
||||
# )
|
||||
#
|
||||
# flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
||||
# patch_flex_wrapper(**flex_attn_compile_kwargs)
|
||||
# patch_flex_make_mask()
|
||||
if self.cfg.sample_packing:
|
||||
from axolotl.core.attention.flex_block_mask import (
|
||||
patch_create_causal_mask,
|
||||
)
|
||||
from axolotl.monkeypatch.attention.flex_attn import (
|
||||
patch_flex_make_mask,
|
||||
patch_flex_wrapper,
|
||||
)
|
||||
|
||||
patch_create_causal_mask(self.cfg.model_config_type)
|
||||
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
||||
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
||||
patch_flex_make_mask()
|
||||
|
||||
def _apply_model_specific_patches(self):
|
||||
"""Apply patches specific to model architectures."""
|
||||
@@ -238,32 +211,6 @@ class PatchManager:
|
||||
has_remote_code=has_remote_code,
|
||||
)
|
||||
|
||||
def _apply_gemma3_conditional_generation_forward_patch(self):
|
||||
"""Apply gemma3 conditional generation forward patch."""
|
||||
if self.model_config.model_type in ["gemma3", "gemma3_text"]:
|
||||
from axolotl.monkeypatch.models.gemma3.modeling import (
|
||||
patch_gemma3_conditional_generation_forward,
|
||||
)
|
||||
|
||||
patch_gemma3_conditional_generation_forward()
|
||||
|
||||
def _apply_sequence_parallel_patches(self):
|
||||
"""Apply sequence parallelism patches."""
|
||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||
from axolotl.monkeypatch.ring_attn.patch import (
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
)
|
||||
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
||||
|
||||
def _apply_tiled_mlp(self, model_type: str):
|
||||
if self.cfg.tiled_mlp:
|
||||
from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
|
||||
|
||||
patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)
|
||||
|
||||
def _patch_attention(self):
|
||||
"""Apply attention-specific patches based on model type."""
|
||||
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
"""
|
||||
chunked ce loss
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss
|
||||
class CEWithChunkedOutputLoss(torch.nn.Module):
|
||||
"""
|
||||
Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.
|
||||
|
||||
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
|
||||
"""
|
||||
|
||||
def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
|
||||
super().__init__()
|
||||
self.num_output_chunks = num_output_chunks
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def compute_cross_entropy(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
normalize: bool = True, # pylint: disable=unused-argument
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Upcast logits to fp32 and compute cross entropy loss.
|
||||
"""
|
||||
return F.cross_entropy(
|
||||
logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
logits (List[torch.Tensor]): List of chunked logits of length
|
||||
``self.num_output_chunks``, where each chunk has shape
|
||||
``(batch_size, num_tokens / num_output_chunks, vocab_size)``.
|
||||
labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``.
|
||||
reduction (str): The reduction to apply to the output.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Cross entropy loss of shape (1,).
|
||||
"""
|
||||
|
||||
total_elements = (labels != self.ignore_index).sum()
|
||||
|
||||
# chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)]
|
||||
labels = [
|
||||
target_chunk.reshape(-1)
|
||||
for target_chunk in labels.chunk(self.num_output_chunks, dim=1)
|
||||
]
|
||||
# reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)]
|
||||
logits = [
|
||||
logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
|
||||
]
|
||||
|
||||
# compute one chunk at a time
|
||||
total_loss = 0.0
|
||||
for logits_chunk, labels_chunk in zip(logits, labels):
|
||||
total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)
|
||||
|
||||
if reduction == "sum":
|
||||
return total_loss
|
||||
return total_loss / total_elements
|
||||
|
||||
|
||||
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
||||
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
|
||||
loss_fn_ce.compute_cross_entropy = torch.compile(
|
||||
loss_fn_ce.compute_cross_entropy, backend="inductor"
|
||||
)
|
||||
return loss_fn_ce
|
||||
|
||||
|
||||
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
|
||||
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
|
||||
|
||||
def chunked_fix_cross_entropy(
|
||||
source,
|
||||
target,
|
||||
num_items_in_batch: int = None,
|
||||
ignore_index: int = -100,
|
||||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||
logit_chunks = [ # pylint: disable=unnecessary-comprehension
|
||||
chunk for chunk in source.chunk(loss_fn_ce.num_output_chunks, dim=1)
|
||||
]
|
||||
loss = loss_fn_ce(logit_chunks, target, reduction=reduction)
|
||||
if reduction == "sum":
|
||||
loss = loss / num_items_in_batch
|
||||
return loss
|
||||
|
||||
def for_causal_lm_chunked_loss(
|
||||
logits,
|
||||
labels,
|
||||
vocab_size: int = None, # pylint: disable=unused-argument
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
ignore_index: int = -100,
|
||||
shift_labels: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# skip the upcast to float since we handle that in the chunking loss
|
||||
if shift_labels is None:
|
||||
# Shift so that tokens < n predict n
|
||||
labels = F.pad(labels, (0, 1), value=ignore_index)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
# Skip Flattening the tokens
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(logits.device)
|
||||
loss = chunked_fix_cross_entropy(
|
||||
logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
|
||||
)
|
||||
return loss
|
||||
|
||||
return for_causal_lm_chunked_loss
|
||||
|
||||
|
||||
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
||||
import transformers.loss.loss_utils
|
||||
|
||||
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
|
||||
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
|
||||
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
|
||||
for_causal_lm_chunked_loss
|
||||
)
|
||||
@@ -1,16 +0,0 @@
|
||||
"""Monkeypatch for gemma3 conditional generation forward to fix high loss"""
|
||||
|
||||
|
||||
def patch_gemma3_conditional_generation_forward():
|
||||
# Remove when https://github.com/huggingface/transformers/pull/37208 merged
|
||||
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
Gemma3ForConditionalGeneration,
|
||||
)
|
||||
|
||||
setattr(Gemma3ForConditionalGeneration, "accepts_loss_kwargs", False)
|
||||
|
||||
def unpatch():
|
||||
delattr(Gemma3ForConditionalGeneration, "accepts_loss_kwargs")
|
||||
|
||||
return unpatch
|
||||
@@ -42,10 +42,6 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
||||
if has_remote_code:
|
||||
patch_remote(model_name)
|
||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||
# sanity check in case upstream api changes on this
|
||||
assert hasattr(
|
||||
transformers.modeling_flash_attention_utils, "_get_unpad_data"
|
||||
), "transformers api changed for _get_unpad_data for flash attention"
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ RING_ATTN_FUNC_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
def create_flash_attn_forward_varlen_llama3(
|
||||
def create_flash_attn_forward(
|
||||
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
|
||||
) -> Callable:
|
||||
"""
|
||||
@@ -71,7 +71,6 @@ def create_flash_attn_forward_varlen_llama3(
|
||||
max_length_q: int | None = None,
|
||||
max_length_k: int | None = None,
|
||||
target_dtype: torch.dtype | None = None,
|
||||
attn_implementation: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -98,7 +97,6 @@ def create_flash_attn_forward_varlen_llama3(
|
||||
max_length_q: Not used in this implementation.
|
||||
max_length_k: Not used in this implementation.
|
||||
target_dtype: Not used in this implementation.
|
||||
attn_implementation: Not used in this implementation.
|
||||
**kwargs: Additional keyword arguments. Not used in this implementation.
|
||||
|
||||
Returns:
|
||||
@@ -163,7 +161,7 @@ def substitute_hf_flash_attn(
|
||||
old_flash_attention_forward = (
|
||||
transformers.modeling_flash_attention_utils._flash_attention_forward
|
||||
)
|
||||
new_flash_attention_forward = create_flash_attn_forward_varlen_llama3(
|
||||
new_flash_attention_forward = create_flash_attn_forward(
|
||||
process_group=process_group, ring_attn_func=ring_attn_func
|
||||
)
|
||||
|
||||
|
||||
@@ -9,13 +9,10 @@ sequence parallelism training.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers.modeling_flash_attention_utils import _flash_supports_window_size
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -65,96 +62,6 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
def create_ring_flash_attention_forward(
|
||||
process_group: dist.ProcessGroup, heads_k_stride: int
|
||||
):
|
||||
from ring_flash_attn import llama3_flash_attn_varlen_func
|
||||
from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS
|
||||
|
||||
def _flash_attention_forward_v3(
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor, # pylint: disable=unused-argument
|
||||
query_length: int,
|
||||
is_causal: bool,
|
||||
dropout: float = 0.0,
|
||||
position_ids: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||
softmax_scale: Optional[float] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
use_top_left_mask: bool = False,
|
||||
softcap: Optional[float] = None,
|
||||
deterministic: bool = None,
|
||||
cu_seq_lens_q: Optional[
|
||||
torch.LongTensor
|
||||
] = None, # pylint: disable=unused-argument
|
||||
cu_seq_lens_k: Optional[
|
||||
torch.LongTensor
|
||||
] = None, # pylint: disable=unused-argument
|
||||
max_length_q: Optional[int] = None, # pylint: disable=unused-argument
|
||||
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
|
||||
target_dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument
|
||||
attn_implementation: Optional[str] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
if not use_top_left_mask:
|
||||
causal = is_causal
|
||||
else:
|
||||
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
|
||||
causal = is_causal and query_length != 1
|
||||
|
||||
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
|
||||
use_sliding_windows = (
|
||||
_flash_supports_window_size
|
||||
and sliding_window is not None
|
||||
and key_states.shape[1] > sliding_window
|
||||
)
|
||||
flash_kwargs = (
|
||||
{"window_size": (sliding_window, sliding_window)}
|
||||
if use_sliding_windows
|
||||
else {}
|
||||
)
|
||||
|
||||
if deterministic is None:
|
||||
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
flash_kwargs["deterministic"] = deterministic
|
||||
assert (
|
||||
softcap is None
|
||||
), "llama3_flash_attn_varlen_func does not support softcap yet."
|
||||
# flash_kwargs["softcap"] = softcap
|
||||
flash_kwargs["group"] = process_group
|
||||
|
||||
# not sure why attention_mask can be not None...
|
||||
assert causal, "only causal attention is supported yet."
|
||||
batch_size = query_states.size(0)
|
||||
assert batch_size == 1, "varlen data should be processed in advance."
|
||||
|
||||
attn_output = llama3_flash_attn_varlen_func(
|
||||
query_states.squeeze(dim=0),
|
||||
key_states.squeeze(dim=0),
|
||||
value_states.squeeze(dim=0),
|
||||
cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"],
|
||||
cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"],
|
||||
max_seqlen_q=DATA_PARAMS["max_seqlen_q"],
|
||||
max_seqlen_k=DATA_PARAMS["max_seqlen_k"],
|
||||
heads_k_stride=heads_k_stride,
|
||||
local_k_slice=DATA_PARAMS["local_k_slice"],
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
**flash_kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.unsqueeze(dim=0)
|
||||
|
||||
return attn_output
|
||||
|
||||
return [
|
||||
_flash_attention_forward_v3,
|
||||
]
|
||||
|
||||
|
||||
def register_ring_attn(
|
||||
sequence_parallel_degree: int,
|
||||
heads_k_stride: int | None,
|
||||
@@ -211,20 +118,9 @@ def register_ring_attn(
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||
# fmt: off
|
||||
import ring_flash_attn.adapters.hf_adapter
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
|
||||
from ring_flash_attn.adapters.hf_adapter import ( # isort: skip # pylint: disable=unused-import
|
||||
create_ring_flash_attention_forward as create_ring_flash_attention_forward_orig,
|
||||
)
|
||||
|
||||
create_ring_flash_attention_forward_orig = ( # noqa: F811,F841
|
||||
create_ring_flash_attention_forward
|
||||
)
|
||||
ring_flash_attn.adapters.hf_adapter.create_ring_flash_attention_forward = create_ring_flash_attention_forward
|
||||
# fmt: on
|
||||
|
||||
ring_flash_attn.adapters.hf_adapter.substitute_hf_flash_attn(
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
||||
)
|
||||
elif ring_attn_func is RingAttnFunc.BATCH_RING:
|
||||
@@ -256,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
|
||||
def patch_prepare_data_loader():
|
||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
||||
|
||||
Raises:
|
||||
Raies:
|
||||
RuntimeError: If source code to patch does not exist.
|
||||
"""
|
||||
original_fn = accelerate.data_loader.prepare_data_loader
|
||||
@@ -272,34 +168,23 @@ def patch_prepare_data_loader():
|
||||
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
|
||||
)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(accelerate.data_loader):
|
||||
if item in patched_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
# Create a new function from the patched source
|
||||
namespace = {}
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from accelerate.data_loader import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
patched_source, accelerate.data_loader.__dict__, namespace
|
||||
)
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
patched_source, globals(), namespace
|
||||
)
|
||||
|
||||
patched_function = namespace["prepare_data_loader"]
|
||||
original_fn.__code__ = patched_function.__code__
|
||||
|
||||
accelerate.data_loader.prepare_data_loader = patched_function
|
||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||
|
||||
|
||||
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
|
||||
def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||
that includes sequence parallelism with the specified degree.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: The degree of sequence parallelism to use.
|
||||
fsdp: Whether to use FSDP.
|
||||
sequence_parallel_degree (int): The degree of sequence parallelism to use.
|
||||
"""
|
||||
|
||||
def _prepare_device_mesh(self):
|
||||
@@ -322,14 +207,12 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False)
|
||||
)
|
||||
device_ids = list(range(world_size))
|
||||
|
||||
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
|
||||
# parallelism" implementation naming.
|
||||
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
|
||||
# only use "fsdp" and "cp" for the device mesh.
|
||||
# Note that we use "cp" instead of "sp" to match the PyTorch native "context
|
||||
# parallelism" implementation naming
|
||||
return dist.DeviceMesh(
|
||||
"cuda",
|
||||
torch.tensor(device_ids).reshape(mesh_shape),
|
||||
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
|
||||
mesh_dim_names=("dp", "cp"),
|
||||
)
|
||||
|
||||
# Replace the original method with our new method
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
"""Monkeypatch for Tiled MLP implementation"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
|
||||
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
|
||||
|
||||
try:
|
||||
# Dynamically import the module and MLP class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
model_cls_prefix = "".join(
|
||||
[part.capitalize() for part in model_type.split("_")]
|
||||
)
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
||||
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
|
||||
|
||||
if use_original_mlp:
|
||||
mlp_forward = mlp_cls.forward
|
||||
else:
|
||||
|
||||
def generic_mlp_forward(self_, hs):
|
||||
return self_.down_proj(
|
||||
self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs)
|
||||
)
|
||||
|
||||
mlp_forward = torch.compile(generic_mlp_forward)
|
||||
|
||||
def tiled_mlp_forward(self, x):
|
||||
input_shape = x.shape
|
||||
seqlen = input_shape[-2]
|
||||
hidden = input_shape[-1]
|
||||
if cfg_num_shards is None:
|
||||
num_shards = math.ceil(seqlen / hidden)
|
||||
num_shards_tensor = torch.tensor(num_shards, device=x.device)
|
||||
dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
|
||||
num_shards = num_shards_tensor.item()
|
||||
else:
|
||||
num_shards = cfg_num_shards
|
||||
|
||||
compute_params = [
|
||||
self.down_proj.weight,
|
||||
self.gate_proj.weight,
|
||||
self.up_proj.weight,
|
||||
]
|
||||
|
||||
down_res = TiledMLP.apply(
|
||||
mlp_forward,
|
||||
self,
|
||||
x,
|
||||
num_shards,
|
||||
compute_params,
|
||||
)
|
||||
return down_res
|
||||
|
||||
mlp_cls.forward = tiled_mlp_forward
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise RuntimeError(
|
||||
f"Could not import MLP class for model_type: {model_type}. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
@@ -12,13 +12,15 @@ from axolotl.utils.logging import get_logger
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
ORIGINAL_TRAINER_CODE = """
|
||||
if delay_optimizer_creation:
|
||||
self.optimizer = self.accelerator.prepare(self.optimizer)
|
||||
|
||||
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
|
||||
|
||||
"""
|
||||
|
||||
PATCHED_TRAINER_CODE = """
|
||||
if delay_optimizer_creation:
|
||||
model = self.accelerator.prepare(self.model)
|
||||
|
||||
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
|
||||
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -142,7 +142,7 @@ class ProcessingStrategy:
|
||||
# TODO: check if it's normal to be single image only for common datasets
|
||||
# From observation, it's usually a list of single image but some datasets may have several columns for images
|
||||
# Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
|
||||
if len(processed_example[image_key]) > 1:
|
||||
if len(processed_example[image_key]) > 0:
|
||||
LOG.warning(
|
||||
f"Found {len(processed_example[image_key])} images in a sample. Using the first one."
|
||||
"If you are using a dataset with multiple images per sample, please convert it to use multi-content Messages."
|
||||
|
||||
@@ -103,7 +103,6 @@ class ChatTemplatePrompter(Prompter):
|
||||
chat_template_kwargs = {
|
||||
"chat_template": self.chat_template,
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
**self.chat_template_kwargs,
|
||||
}
|
||||
|
||||
if tools:
|
||||
|
||||
@@ -23,6 +23,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.common.datasets import TrainDatasetMeta
|
||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
@@ -218,14 +219,10 @@ def execute_training(
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
heads_k_stride=cfg.heads_k_stride,
|
||||
gather_outputs=cfg.rl is RLType.GRPO,
|
||||
)
|
||||
)
|
||||
|
||||
LOG.info("Starting trainer...")
|
||||
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
|
||||
# if cfg.bf16:
|
||||
# torch.set_default_dtype(torch.bfloat16)
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
|
||||
@@ -548,6 +545,8 @@ def train(
|
||||
Returns:
|
||||
Tuple of (model, tokenizer) after training
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
|
||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||
(
|
||||
trainer,
|
||||
|
||||
148
src/axolotl/utils/chat_templates.py
Normal file
148
src/axolotl/utils/chat_templates.py
Normal file
File diff suppressed because one or more lines are too long
@@ -1,20 +0,0 @@
|
||||
"""
|
||||
This module provides functionality for selecting chat templates based on user choices.
|
||||
These templates are used for formatting messages in a conversation.
|
||||
"""
|
||||
|
||||
from .base import (
|
||||
_CHAT_TEMPLATES,
|
||||
extract_chat_template_args,
|
||||
get_chat_template,
|
||||
get_chat_template_from_config,
|
||||
register_chat_template,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_chat_template",
|
||||
"extract_chat_template_args",
|
||||
"get_chat_template_from_config",
|
||||
"register_chat_template",
|
||||
"_CHAT_TEMPLATES",
|
||||
]
|
||||
@@ -1,125 +0,0 @@
|
||||
"""
|
||||
utility functions for chat templates
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
LOG = get_logger("axolotl.utils.chat_templates")
|
||||
|
||||
_JINJA_TEMPLATE_CHOICE = "jinja"
|
||||
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
|
||||
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
|
||||
|
||||
TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "templates")
|
||||
_CHAT_TEMPLATES: dict[str, str] = {}
|
||||
for filename in [f for f in os.listdir(TEMPLATE_DIR) if f.endswith(".jinja")]:
|
||||
with open(os.path.join(TEMPLATE_DIR, filename), "r", encoding="utf-8") as f:
|
||||
_CHAT_TEMPLATES[filename[:-6]] = f.read()
|
||||
|
||||
|
||||
def get_chat_template(
|
||||
user_choice: str,
|
||||
jinja_template: str | None = None,
|
||||
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
|
||||
|
||||
Args:
|
||||
user_choice (str): The user's choice of template.
|
||||
jinja_template (str, optional): The jinja template string or Path to a valid jinja template file. Defaults to None.
|
||||
tokenizer (PreTrainedTokenizerBase, optional): The tokenizer. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The chosen template string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the user_choice is not found in the templates.
|
||||
"""
|
||||
if user_choice == _JINJA_TEMPLATE_CHOICE:
|
||||
if not jinja_template:
|
||||
raise ValueError(
|
||||
f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPLATE_CHOICE}"
|
||||
)
|
||||
if os.path.exists(jinja_template) and os.path.isfile(jinja_template):
|
||||
with open(jinja_template, "r", encoding="utf-8") as file:
|
||||
jinja_template = file.read()
|
||||
return jinja_template
|
||||
|
||||
if user_choice == _DEFAULT_TEMPLATE_CHOICE:
|
||||
if not tokenizer:
|
||||
raise ValueError(
|
||||
f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}"
|
||||
)
|
||||
if not tokenizer.chat_template:
|
||||
raise ValueError(
|
||||
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
|
||||
f"Please add a chat_template in tokenizer config"
|
||||
)
|
||||
return tokenizer.chat_template # type: ignore
|
||||
|
||||
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
|
||||
if not tokenizer:
|
||||
raise ValueError(
|
||||
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
|
||||
)
|
||||
if tokenizer.chat_template:
|
||||
return tokenizer.chat_template # type: ignore
|
||||
|
||||
user_choice = user_choice[
|
||||
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
|
||||
]
|
||||
LOG.warning(
|
||||
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
|
||||
)
|
||||
|
||||
if user_choice in _CHAT_TEMPLATES:
|
||||
return _CHAT_TEMPLATES[user_choice]
|
||||
|
||||
raise ValueError(f"Template '{user_choice}' not found.")
|
||||
|
||||
|
||||
def extract_chat_template_args(cfg, ds_cfg: Dict[str, Any] | None = None):
|
||||
if ds_cfg and ds_cfg.get("chat_template"):
|
||||
chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
|
||||
chat_template_jinja = ds_cfg.get("chat_template_jinja")
|
||||
else:
|
||||
chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
|
||||
chat_template_jinja = cfg.get("chat_template_jinja")
|
||||
return chat_template_choice, chat_template_jinja
|
||||
|
||||
|
||||
def get_chat_template_from_config(
|
||||
cfg,
|
||||
ds_cfg: Dict[str, Any] | None = None,
|
||||
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||
) -> str:
|
||||
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||
cfg=cfg, ds_cfg=ds_cfg
|
||||
)
|
||||
return get_chat_template(
|
||||
user_choice=chat_template_choice,
|
||||
jinja_template=chat_template_jinja,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
|
||||
def register_chat_template(template_name: str, chat_template: str):
|
||||
"""
|
||||
Registers chat templates.
|
||||
|
||||
Args:
|
||||
template_name (str): The name of the template.
|
||||
chat_template (str): The template string.
|
||||
"""
|
||||
|
||||
if template_name in _CHAT_TEMPLATES:
|
||||
raise ValueError(f"Template '{template_name}' already exists.")
|
||||
|
||||
_CHAT_TEMPLATES[template_name] = chat_template
|
||||
@@ -1,8 +0,0 @@
|
||||
{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' and loop.first %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '### Instruction:
|
||||
' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '### Response:
|
||||
' + message['content'] + eos_token }}{% endif %}{% if not loop.last %}{{ '
|
||||
|
||||
' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '
|
||||
|
||||
### Response:
|
||||
' }}{% endif %}
|
||||
@@ -1 +0,0 @@
|
||||
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Aya, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}
|
||||
@@ -1,4 +0,0 @@
|
||||
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '
|
||||
' + message['content'] + '<|im_end|>' + '
|
||||
'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
|
||||
' }}{% endif %}
|
||||
@@ -1 +0,0 @@
|
||||
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}
|
||||
@@ -1,210 +0,0 @@
|
||||
{{ bos_token }}{% if documents %}
|
||||
{% set tools = [] %}
|
||||
{%- macro document_turn(documents) -%}
|
||||
{# format documents into chat turn #}
|
||||
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[
|
||||
{"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}
|
||||
]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
|
||||
{
|
||||
"tool_call_id": "0",
|
||||
"results": {
|
||||
{% for doc in documents %}
|
||||
"{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
},
|
||||
"is_error": null
|
||||
}
|
||||
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}
|
||||
{%- macro tool_call_id_to_int(messages, tool_call_id) %}
|
||||
{%- set counter = namespace(value=0) %}
|
||||
{%- set tool_call_id_seen = namespace(value=false) %}
|
||||
{%- for msg in messages %}
|
||||
{%- if msg.tool_calls %}
|
||||
{%- for tool_call in msg.tool_calls %}
|
||||
{%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}
|
||||
{{ counter.value }}
|
||||
{%- set tool_call_id_seen.value = true %}
|
||||
{%- endif %}
|
||||
{%- set counter.value = counter.value + 1 %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endmacro %}
|
||||
{%- macro format_tool_message(messages, tool_msg) -%}
|
||||
{# format tool message #}
|
||||
{
|
||||
"tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",
|
||||
"results": {
|
||||
"0": {{ tool_msg.content|tojson }}
|
||||
},
|
||||
"is_error": null
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}
|
||||
{%- set tool_idx = namespace(value=0) %}
|
||||
{%- set tool_ids_seen = namespace(value=[]) %}
|
||||
{%- set sent_documents = namespace(value=false) %}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble
|
||||
You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.
|
||||
|
||||
Your information cutoff date is June 2024.
|
||||
|
||||
You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.
|
||||
{% if tools or documents %}
|
||||
|
||||
You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.
|
||||
|
||||
## Tool Use
|
||||
Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.
|
||||
|
||||
0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.
|
||||
You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.
|
||||
NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.
|
||||
|
||||
Then carry out your plan by repeatedly executing the following steps.
|
||||
1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.
|
||||
When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.
|
||||
2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.
|
||||
Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".
|
||||
3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.
|
||||
You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.
|
||||
NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.
|
||||
|
||||
You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.
|
||||
|
||||
4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.
|
||||
{% if enable_citations %}
|
||||
|
||||
## Grounding
|
||||
Importantly, note that "Reflection" and "Response" above can be grounded.
|
||||
Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "<co>" and "</co>" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "<co>span</co: 0:[1,2],1:[0]>" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".
|
||||
{% endif %}
|
||||
|
||||
## Available Tools
|
||||
Here is the list of tools that you have available to you.
|
||||
You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.
|
||||
Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).
|
||||
|
||||
```json
|
||||
[
|
||||
{% if documents %}
|
||||
{"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}
|
||||
|
||||
{% endif %}
|
||||
{% for tool in tools %}
|
||||
{"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}
|
||||
|
||||
{% endfor %}
|
||||
]
|
||||
```
|
||||
|
||||
{% endif %}
|
||||
# Default Preamble
|
||||
The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.
|
||||
- Your name is Command.
|
||||
- You are a large language model built by Cohere.
|
||||
- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.
|
||||
- If the input is ambiguous, ask clarifying follow-up questions.
|
||||
- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).
|
||||
- Use LaTeX to generate mathematical notation for complex equations.
|
||||
- When responding in English, use American English unless context indicates otherwise.
|
||||
- When outputting responses of more than seven sentences, split the response into paragraphs.
|
||||
- Prefer the active voice.
|
||||
- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.
|
||||
- Use gender-neutral pronouns for unspecified persons.
|
||||
- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.
|
||||
- Use the third person when asked to write a summary.
|
||||
- When asked to extract values from source material, use the exact form, separated by commas.
|
||||
- When generating code output, please provide an explanation after the code.
|
||||
- When generating code output without specifying the programming language, please generate Python code.
|
||||
- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.
|
||||
{%- if developer_preamble %}
|
||||
|
||||
|
||||
# Developer Preamble
|
||||
The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.
|
||||
{{ developer_preamble }}
|
||||
{%- endif -%}
|
||||
<|END_OF_TURN_TOKEN|>
|
||||
{%- for message in messages %}
|
||||
{%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
|
||||
{%- elif message.role|lower == 'user' %}
|
||||
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}
|
||||
{%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}
|
||||
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[
|
||||
{% for tc in message.tool_calls %}
|
||||
{"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}
|
||||
|
||||
{% set tool_idx.value = tool_idx.value + 1 %}
|
||||
{% endfor %}
|
||||
]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}
|
||||
{% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
|
||||
{{ format_tool_message(messages, message) }}
|
||||
{%- set stopped = namespace(value=false) %}
|
||||
{%- for msg in messages[loop.index0 + 1:] %}
|
||||
{%- if not stopped.value and msg.role|lower == 'tool' %},
|
||||
{{ format_tool_message(messages, msg) }}
|
||||
{%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}
|
||||
{%- else %}
|
||||
{%- set stopped.value = true %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>
|
||||
{%- endif %}
|
||||
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
{%- else -%}
|
||||
{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble
|
||||
{% if safety_mode|upper == 'STRICT' -%}
|
||||
You are in strict safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will reject requests to generate content related to violence, hate, misinformation or sex to any amount. You will avoid using profanity. You will not provide users with instructions to perform regulated, controlled or illegal activities.
|
||||
{%- else -%}
|
||||
You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.
|
||||
{%- endif %}
|
||||
|
||||
|
||||
Your information cutoff date is June 2024.
|
||||
|
||||
You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.
|
||||
|
||||
# Default Preamble
|
||||
The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.
|
||||
- Your name is Command.
|
||||
- You are a large language model built by Cohere.
|
||||
- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.
|
||||
- If the input is ambiguous, ask clarifying follow-up questions.
|
||||
- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).
|
||||
- Use LaTeX to generate mathematical notation for complex equations.
|
||||
- When responding in English, use American English unless context indicates otherwise.
|
||||
- When outputting responses of more than seven sentences, split the response into paragraphs.
|
||||
- Prefer the active voice.
|
||||
- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.
|
||||
- Use gender-neutral pronouns for unspecified persons.
|
||||
- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.
|
||||
- Use the third person when asked to write a summary.
|
||||
- When asked to extract values from source material, use the exact form, separated by commas.
|
||||
- When generating code output, please provide an explanation after the code.
|
||||
- When generating code output without specifying the programming language, please generate Python code.
|
||||
- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.
|
||||
{%- if developer_preamble %}
|
||||
|
||||
|
||||
# Developer Preamble
|
||||
The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.
|
||||
{{ developer_preamble }}
|
||||
{%- endif -%}
|
||||
<|END_OF_TURN_TOKEN|>
|
||||
{%- for message in messages %}
|
||||
{%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
|
||||
{%- elif message.role|lower == 'user' %}
|
||||
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
|
||||
{%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}
|
||||
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>
|
||||
{%- endif %}
|
||||
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{%- if add_generation_prompt -%}<|START_RESPONSE|>{%- endif %}
|
||||
{% endif %}
|
||||
@@ -1,158 +0,0 @@
|
||||
{{ bos_token }}{% set tools = [] %}
|
||||
{%- macro document_turn(documents) -%}
|
||||
{# format documents into chat turn #}
|
||||
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[
|
||||
{"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}
|
||||
]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
|
||||
{
|
||||
"tool_call_id": "0",
|
||||
"results": {
|
||||
{% for doc in documents %}
|
||||
"{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
},
|
||||
"is_error": null
|
||||
}
|
||||
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}
|
||||
{%- macro tool_call_id_to_int(messages, tool_call_id) %}
|
||||
{%- set counter = namespace(value=0) %}
|
||||
{%- set tool_call_id_seen = namespace(value=false) %}
|
||||
{%- for msg in messages %}
|
||||
{%- if msg.tool_calls %}
|
||||
{%- for tool_call in msg.tool_calls %}
|
||||
{%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}
|
||||
{{ counter.value }}
|
||||
{%- set tool_call_id_seen.value = true %}
|
||||
{%- endif %}
|
||||
{%- set counter.value = counter.value + 1 %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endmacro %}
|
||||
{%- macro format_tool_message(messages, tool_msg) -%}
|
||||
{# format tool message #}
|
||||
{
|
||||
"tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",
|
||||
"results": {
|
||||
"0": {{ tool_msg.content|tojson }}
|
||||
},
|
||||
"is_error": null
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}
|
||||
{%- set tool_idx = namespace(value=0) %}
|
||||
{%- set tool_ids_seen = namespace(value=[]) %}
|
||||
{%- set sent_documents = namespace(value=false) %}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble
|
||||
You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.
|
||||
|
||||
Your information cutoff date is June 2024.
|
||||
|
||||
You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.
|
||||
{% if tools or documents %}
|
||||
|
||||
You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.
|
||||
|
||||
## Tool Use
|
||||
Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.
|
||||
|
||||
0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.
|
||||
You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.
|
||||
NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.
|
||||
|
||||
Then carry out your plan by repeatedly executing the following steps.
|
||||
1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.
|
||||
When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.
|
||||
2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.
|
||||
Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".
|
||||
3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.
|
||||
You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.
|
||||
NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.
|
||||
|
||||
You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.
|
||||
|
||||
4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.
|
||||
{% if enable_citations %}
|
||||
|
||||
## Grounding
|
||||
Importantly, note that "Reflection" and "Response" above can be grounded.
|
||||
Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "<co>" and "</co>" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "<co>span</co: 0:[1,2],1:[0]>" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".
|
||||
{% endif %}
|
||||
|
||||
## Available Tools
|
||||
Here is the list of tools that you have available to you.
|
||||
You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.
|
||||
Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).
|
||||
|
||||
```json
|
||||
[
|
||||
{% if documents %}
|
||||
{"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}
|
||||
|
||||
{% endif %}
|
||||
{% for tool in tools %}
|
||||
{"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}
|
||||
|
||||
{% endfor %}
|
||||
]
|
||||
```
|
||||
|
||||
{% endif %}
|
||||
# Default Preamble
|
||||
The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.
|
||||
- Your name is Command.
|
||||
- You are a large language model built by Cohere.
|
||||
- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.
|
||||
- If the input is ambiguous, ask clarifying follow-up questions.
|
||||
- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).
|
||||
- Use LaTeX to generate mathematical notation for complex equations.
|
||||
- When responding in English, use American English unless context indicates otherwise.
|
||||
- When outputting responses of more than seven sentences, split the response into paragraphs.
|
||||
- Prefer the active voice.
|
||||
- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.
|
||||
- Use gender-neutral pronouns for unspecified persons.
|
||||
- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.
|
||||
- Use the third person when asked to write a summary.
|
||||
- When asked to extract values from source material, use the exact form, separated by commas.
|
||||
- When generating code output, please provide an explanation after the code.
|
||||
- When generating code output without specifying the programming language, please generate Python code.
|
||||
- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.
|
||||
{%- if developer_preamble %}
|
||||
|
||||
|
||||
# Developer Preamble
|
||||
The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.
|
||||
{{ developer_preamble }}
|
||||
{%- endif -%}
|
||||
<|END_OF_TURN_TOKEN|>
|
||||
{%- for message in messages %}
|
||||
{%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
|
||||
{%- elif message.role|lower == 'user' %}
|
||||
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}
|
||||
{%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}
|
||||
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[
|
||||
{% for tc in message.tool_calls %}
|
||||
{"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}
|
||||
|
||||
{% set tool_idx.value = tool_idx.value + 1 %}
|
||||
{% endfor %}
|
||||
]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}
|
||||
{% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
|
||||
{{ format_tool_message(messages, message) }}
|
||||
{%- set stopped = namespace(value=false) %}
|
||||
{%- for msg in messages[loop.index0 + 1:] %}
|
||||
{%- if not stopped.value and msg.role|lower == 'tool' %},
|
||||
{{ format_tool_message(messages, msg) }}
|
||||
{%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}
|
||||
{%- else %}
|
||||
{%- set stopped.value = true %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>
|
||||
{%- endif %}
|
||||
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
@@ -1,157 +0,0 @@
|
||||
{{ bos_token }}{%- macro document_turn(documents) -%}
|
||||
{# format documents into chat turn #}
|
||||
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[
|
||||
{"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}
|
||||
]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
|
||||
{
|
||||
"tool_call_id": "0",
|
||||
"results": {
|
||||
{% for doc in documents %}
|
||||
"{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
},
|
||||
"is_error": null
|
||||
}
|
||||
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}
|
||||
{%- macro tool_call_id_to_int(messages, tool_call_id) %}
|
||||
{%- set counter = namespace(value=0) %}
|
||||
{%- set tool_call_id_seen = namespace(value=false) %}
|
||||
{%- for msg in messages %}
|
||||
{%- if msg.tool_calls %}
|
||||
{%- for tool_call in msg.tool_calls %}
|
||||
{%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}
|
||||
{{ counter.value }}
|
||||
{%- set tool_call_id_seen.value = true %}
|
||||
{%- endif %}
|
||||
{%- set counter.value = counter.value + 1 %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endmacro %}
|
||||
{%- macro format_tool_message(messages, tool_msg) -%}
|
||||
{# format tool message #}
|
||||
{
|
||||
"tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",
|
||||
"results": {
|
||||
"0": {{ tool_msg.content|tojson }}
|
||||
},
|
||||
"is_error": null
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}
|
||||
{%- set tool_idx = namespace(value=0) %}
|
||||
{%- set tool_ids_seen = namespace(value=[]) %}
|
||||
{%- set sent_documents = namespace(value=false) %}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble
|
||||
You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.
|
||||
|
||||
Your information cutoff date is June 2024.
|
||||
|
||||
You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.
|
||||
{% if tools or documents %}
|
||||
|
||||
You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.
|
||||
|
||||
## Tool Use
|
||||
Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.
|
||||
|
||||
0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.
|
||||
You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.
|
||||
NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.
|
||||
|
||||
Then carry out your plan by repeatedly executing the following steps.
|
||||
1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.
|
||||
When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.
|
||||
2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.
|
||||
Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".
|
||||
3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.
|
||||
You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.
|
||||
NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.
|
||||
|
||||
You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.
|
||||
|
||||
4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.
|
||||
{% if enable_citations %}
|
||||
|
||||
## Grounding
|
||||
Importantly, note that "Reflection" and "Response" above can be grounded.
|
||||
Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "<co>" and "</co>" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "<co>span</co: 0:[1,2],1:[0]>" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".
|
||||
{% endif %}
|
||||
|
||||
## Available Tools
|
||||
Here is the list of tools that you have available to you.
|
||||
You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.
|
||||
Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).
|
||||
|
||||
```json
|
||||
[
|
||||
{% if documents %}
|
||||
{"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}
|
||||
|
||||
{% endif %}
|
||||
{% for tool in tools %}
|
||||
{"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}
|
||||
|
||||
{% endfor %}
|
||||
]
|
||||
```
|
||||
|
||||
{% endif %}
|
||||
# Default Preamble
|
||||
The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.
|
||||
- Your name is Command.
|
||||
- You are a large language model built by Cohere.
|
||||
- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.
|
||||
- If the input is ambiguous, ask clarifying follow-up questions.
|
||||
- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).
|
||||
- Use LaTeX to generate mathematical notation for complex equations.
|
||||
- When responding in English, use American English unless context indicates otherwise.
|
||||
- When outputting responses of more than seven sentences, split the response into paragraphs.
|
||||
- Prefer the active voice.
|
||||
- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.
|
||||
- Use gender-neutral pronouns for unspecified persons.
|
||||
- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.
|
||||
- Use the third person when asked to write a summary.
|
||||
- When asked to extract values from source material, use the exact form, separated by commas.
|
||||
- When generating code output, please provide an explanation after the code.
|
||||
- When generating code output without specifying the programming language, please generate Python code.
|
||||
- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.
|
||||
{%- if developer_preamble %}
|
||||
|
||||
|
||||
# Developer Preamble
|
||||
The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.
|
||||
{{ developer_preamble }}
|
||||
{%- endif -%}
|
||||
<|END_OF_TURN_TOKEN|>
|
||||
{%- for message in messages %}
|
||||
{%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>
|
||||
{%- elif message.role|lower == 'user' %}
|
||||
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}
|
||||
{%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}
|
||||
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[
|
||||
{% for tc in message.tool_calls %}
|
||||
{"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}
|
||||
|
||||
{% set tool_idx.value = tool_idx.value + 1 %}
|
||||
{% endfor %}
|
||||
]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}
|
||||
{% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[
|
||||
{{ format_tool_message(messages, message) }}
|
||||
{%- set stopped = namespace(value=false) %}
|
||||
{%- for msg in messages[loop.index0 + 1:] %}
|
||||
{%- if not stopped.value and msg.role|lower == 'tool' %},
|
||||
{{ format_tool_message(messages, msg) }}
|
||||
{%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}
|
||||
{%- else %}
|
||||
{%- set stopped.value = true %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>
|
||||
{%- endif %}
|
||||
{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user