Merge branch 'main' into cj_tokenizer_default_prompt_template
This commit is contained in:
37
.github/workflows/base.yml
vendored
37
.github/workflows/base.yml
vendored
@@ -12,36 +12,24 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: "118"
|
- cuda: "121"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 12.1.1
|
||||||
|
cudnn_version: 8
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.3.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
- cuda: "121"
|
- cuda: "121"
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
cudnn_version: 8
|
||||||
pytorch: 2.1.2
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
- cuda: "121"
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.1.2
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
- cuda: "121"
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.2.2
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
- cuda: "121"
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.3.0
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
- cuda: "121"
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
- cuda: "124"
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.4.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
@@ -67,6 +55,7 @@ jobs:
|
|||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
build-args: |
|
build-args: |
|
||||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||||
|
CUDNN_VERSION=${{ matrix.cudnn_version }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
PYTHON_VERSION=${{ matrix.python_version }}
|
PYTHON_VERSION=${{ matrix.python_version }}
|
||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
|
|||||||
54
.github/workflows/main.yml
vendored
54
.github/workflows/main.yml
vendored
@@ -13,28 +13,22 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 118
|
- cuda: 121
|
||||||
cuda_version: 11.8.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras: mamba-ssm
|
||||||
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.2
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.2.2
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras: mamba-ssm
|
||||||
is_latest: true
|
is_latest: true
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.4.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -65,6 +59,7 @@ jobs:
|
|||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: |
|
tags: |
|
||||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|
||||||
@@ -75,27 +70,22 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 118
|
- cuda: 121
|
||||||
cuda_version: 11.8.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.2
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.2.2
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.4.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -134,7 +124,7 @@ jobs:
|
|||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
|||||||
47
.github/workflows/nightlies.yml
vendored
47
.github/workflows/nightlies.yml
vendored
@@ -12,28 +12,22 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 118
|
- cuda: 121
|
||||||
cuda_version: 11.8.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
|
||||||
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.2
|
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.2.2
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.4.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -75,27 +69,22 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 118
|
- cuda: 121
|
||||||
cuda_version: 11.8.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.2
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.2.2
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.4.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
28
.github/workflows/tests.yml
vendored
28
.github/workflows/tests.yml
vendored
@@ -72,27 +72,24 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 118
|
- cuda: 121
|
||||||
cuda_version: 11.8.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.3.1
|
||||||
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
|
axolotl_extras: mamba-ssm
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.2
|
|
||||||
num_gpus: 1
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.2.2
|
|
||||||
num_gpus: 1
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.3.1
|
pytorch: 2.3.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
|
axolotl_extras: mamba-ssm
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.4.0
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -109,6 +106,7 @@ jobs:
|
|||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||||
|
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
|
|||||||
@@ -334,7 +334,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc
|
|||||||
|
|
||||||
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
||||||
|
|
||||||
See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
|
See [the documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
|
||||||
|
|
||||||
### Config
|
### Config
|
||||||
|
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN pip install causal_conv1d
|
RUN pip install causal_conv1d
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN pip install causal_conv1d
|
RUN pip install causal_conv1d
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ ARG CUDNN_VERSION="8"
|
|||||||
ARG UBUNTU_VERSION="22.04"
|
ARG UBUNTU_VERSION="22.04"
|
||||||
ARG MAX_JOBS=4
|
ARG MAX_JOBS=4
|
||||||
|
|
||||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder
|
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||||
|
|
||||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||||
|
|
||||||
|
|||||||
62
examples/llama-3/qlora-fsdp-405b.yaml
Normal file
62
examples/llama-3/qlora-fsdp-405b.yaml
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
base_model: meta-llama/Meta-Llama-3.1-405B
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out/qlora-llama3_1-405b
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
|
||||||
|
sequence_len: 1024
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.00001
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: true
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: true
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|finetune_right_pad_id|>
|
||||||
@@ -1,18 +1,18 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.11.1
|
peft==0.11.1
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf
|
transformers==4.43.3
|
||||||
tokenizers==0.19.1
|
tokenizers==0.19.1
|
||||||
bitsandbytes==0.43.1
|
bitsandbytes==0.43.1
|
||||||
accelerate==0.32.0
|
accelerate==0.32.0
|
||||||
deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
|
deepspeed==0.14.4
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
requests
|
requests
|
||||||
datasets==2.19.1
|
datasets==2.19.1
|
||||||
flash-attn==2.6.1
|
flash-attn==2.6.2
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
@@ -32,6 +32,7 @@ fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e59
|
|||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
|
autoawq>=0.2.5
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
|
|||||||
6
setup.py
6
setup.py
@@ -80,13 +80,13 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.6.1",
|
"flash-attn==2.6.2",
|
||||||
],
|
],
|
||||||
"fused-dense-lib": [
|
"fused-dense-lib": [
|
||||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
|
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
|
"deepspeed==0.14.4",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
],
|
],
|
||||||
"mamba-ssm": [
|
"mamba-ssm": [
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
CLI to run training on a model
|
CLI to run training on a model
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -76,8 +77,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
if parsed_cli_args.download:
|
if parsed_cli_args.download:
|
||||||
model_name = parsed_cfg.base_model
|
model_name = parsed_cfg.base_model
|
||||||
with init_empty_weights():
|
with warnings.catch_warnings():
|
||||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
# there are a bunch of useless UserWarnings about
|
||||||
|
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
with init_empty_weights(include_buffers=True):
|
||||||
|
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
Fore.GREEN
|
Fore.GREEN
|
||||||
|
|||||||
14
src/axolotl/common/architectures.py
Normal file
14
src/axolotl/common/architectures.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Common architecture specific constants
|
||||||
|
"""
|
||||||
|
|
||||||
|
MOE_ARCH_BLOCK = {
|
||||||
|
"dbrx": "DbrxFFN",
|
||||||
|
"jamba": "JambaSparseMoeBlock",
|
||||||
|
"jetmoe": [
|
||||||
|
"JetMoeMoA",
|
||||||
|
"JetMoeMoE",
|
||||||
|
],
|
||||||
|
"mixtral": "MixtralSparseMoeBlock",
|
||||||
|
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import importlib
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@@ -28,9 +29,18 @@ from transformers import (
|
|||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
|
from trl import (
|
||||||
|
CPOConfig,
|
||||||
|
CPOTrainer,
|
||||||
|
DPOConfig,
|
||||||
|
DPOTrainer,
|
||||||
|
KTOConfig,
|
||||||
|
KTOTrainer,
|
||||||
|
ORPOConfig,
|
||||||
|
ORPOTrainer,
|
||||||
|
)
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
from axolotl.loraplus import create_loraplus_optimizer
|
||||||
@@ -265,7 +275,89 @@ class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
@dataclass
|
||||||
|
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
||||||
|
"""
|
||||||
|
CPO config for CPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
simpo_gamma: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "simpo gamma parameter"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerMixin(Trainer):
|
||||||
|
"""
|
||||||
|
Mixin class for scheduler setup in CausalTrainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: AxolotlTrainingArguments
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||||
|
passed as an argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_training_steps (int): The number of training steps to do.
|
||||||
|
optimizer (torch.optim.Optimizer): The training optimizer
|
||||||
|
"""
|
||||||
|
use_cosine_quadratic = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.lr_quadratic_warmup is True
|
||||||
|
)
|
||||||
|
|
||||||
|
use_cosine_min_lr = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.cosine_min_lr_ratio is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||||
|
# fmt: on
|
||||||
|
if use_cosine_quadratic:
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||||
|
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
else:
|
||||||
|
if use_cosine_quadratic:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||||
"""
|
"""
|
||||||
Extend the base Trainer for axolotl helpers
|
Extend the base Trainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -383,68 +475,6 @@ class AxolotlTrainer(Trainer):
|
|||||||
|
|
||||||
return self.optimizer
|
return self.optimizer
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
|
||||||
passed as an argument.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_training_steps (int): The number of training steps to do.
|
|
||||||
optimizer (torch.optim.Optimizer): The training optimizer
|
|
||||||
"""
|
|
||||||
use_cosine_quadratic = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.lr_quadratic_warmup is True
|
|
||||||
)
|
|
||||||
|
|
||||||
use_cosine_min_lr = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.cosine_min_lr_ratio is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
|
||||||
# fmt: on
|
|
||||||
if use_cosine_quadratic:
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
|
||||||
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return super().create_scheduler(num_training_steps, optimizer)
|
|
||||||
else:
|
|
||||||
if use_cosine_quadratic:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
if self.args.multipack_real_batches:
|
if self.args.multipack_real_batches:
|
||||||
@@ -809,6 +839,14 @@ class AxolotlTrainer(Trainer):
|
|||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
self._stored_metrics[train_eval][key].append(value)
|
||||||
|
|
||||||
|
def _save_checkpoint(self, model, trial, metrics=None):
|
||||||
|
# make sure the checkpoint dir exists, since trainer is flakey
|
||||||
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
|
run_dir = self._get_output_dir(trial=trial)
|
||||||
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
return super()._save_checkpoint(model, trial, metrics=metrics)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -908,7 +946,7 @@ class ReLoRATrainer(AxolotlTrainer):
|
|||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(DPOTrainer):
|
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base DPOTrainer for axolotl helpers
|
Extend the base DPOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -969,7 +1007,7 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -977,7 +1015,7 @@ class AxolotlORPOTrainer(ORPOTrainer):
|
|||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(KTOTrainer):
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -985,6 +1023,14 @@ class AxolotlKTOTrainer(KTOTrainer):
|
|||||||
tag_names = ["axolotl", "kto"]
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
Base class for trainer builder
|
Base class for trainer builder
|
||||||
@@ -1707,6 +1753,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
|
if self.cfg.rl_beta:
|
||||||
|
training_args_kwargs["beta"] = self.cfg.rl_beta
|
||||||
if self.cfg.orpo_alpha:
|
if self.cfg.orpo_alpha:
|
||||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
@@ -1715,9 +1763,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = AxolotlDPOConfig
|
||||||
if self.cfg.rpo_alpha is not None:
|
if self.cfg.rpo_alpha is not None:
|
||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
|
|
||||||
|
if self.cfg.rl == "simpo":
|
||||||
|
training_args_cls = AxolotlCPOConfig
|
||||||
|
training_args_kwargs["loss_type"] = "simpo"
|
||||||
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
|
||||||
|
if self.cfg.cpo_alpha is not None:
|
||||||
|
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||||
|
|
||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
training_args_cls = AxolotlORPOConfig
|
training_args_cls = AxolotlORPOConfig
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
if self.cfg.max_prompt_len:
|
if self.cfg.max_prompt_len:
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
@@ -1725,7 +1781,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rl == "kto":
|
if self.cfg.rl == "kto":
|
||||||
training_args_cls = AxolotlKTOConfig
|
training_args_cls = AxolotlKTOConfig
|
||||||
|
|
||||||
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
|
||||||
training_args_kwargs["desirable_weight"] = (
|
training_args_kwargs["desirable_weight"] = (
|
||||||
self.cfg.kto_desirable_weight or 1.0
|
self.cfg.kto_desirable_weight or 1.0
|
||||||
)
|
)
|
||||||
@@ -1771,7 +1826,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl in ["dpo", "ipo"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|
||||||
# these aren't used for the ORPO trainer
|
# these aren't used for the ORPO trainer
|
||||||
@@ -1785,6 +1839,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
elif self.cfg.rl in ["kto"]:
|
elif self.cfg.rl in ["kto"]:
|
||||||
trainer_cls = AxolotlKTOTrainer
|
trainer_cls = AxolotlKTOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
|
elif self.cfg.rl in ["simpo"]:
|
||||||
|
trainer_cls = AxolotlCPOTrainer
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
dpo_trainer = trainer_cls(
|
dpo_trainer = trainer_cls(
|
||||||
|
|||||||
@@ -6,14 +6,16 @@ import logging
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
|
|
||||||
|
# Configure the logger
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplatePrompter(Prompter):
|
class ChatTemplatePrompter(Prompter):
|
||||||
"""prompter for HF chat templates"""
|
"""Prompter for HF chat templates"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -22,6 +24,8 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
max_length=2048,
|
max_length=2048,
|
||||||
message_field_role: str = "from",
|
message_field_role: str = "from",
|
||||||
message_field_content: str = "value",
|
message_field_content: str = "value",
|
||||||
|
message_field_training: str = "train",
|
||||||
|
message_field_training_detail: str = "train_detail",
|
||||||
roles: Optional[Dict[str, List[str]]] = None,
|
roles: Optional[Dict[str, List[str]]] = None,
|
||||||
drop_system_message: bool = False,
|
drop_system_message: bool = False,
|
||||||
):
|
):
|
||||||
@@ -37,6 +41,8 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
}
|
}
|
||||||
self.message_field_role = message_field_role
|
self.message_field_role = message_field_role
|
||||||
self.message_field_content = message_field_content
|
self.message_field_content = message_field_content
|
||||||
|
self.message_field_training = message_field_training
|
||||||
|
self.message_field_training_detail = message_field_training_detail
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
@@ -47,6 +53,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
{
|
{
|
||||||
"role": self.roles[t[self.message_field_role]],
|
"role": self.roles[t[self.message_field_role]],
|
||||||
"content": t[self.message_field_content],
|
"content": t[self.message_field_content],
|
||||||
|
"training": t.get(self.message_field_training, None),
|
||||||
}
|
}
|
||||||
for t in conversation
|
for t in conversation
|
||||||
]
|
]
|
||||||
@@ -62,6 +69,108 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
chat_template=self.chat_template,
|
chat_template=self.chat_template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_offsets_for_train_detail(
|
||||||
|
self, text: str, train_details: List[Dict], mask_untrainable: bool = True
|
||||||
|
) -> List[int]:
|
||||||
|
tokenized_output = self.tokenizer(
|
||||||
|
text, return_offsets_mapping=True, add_special_tokens=False
|
||||||
|
)
|
||||||
|
tokens = tokenized_output.tokens()
|
||||||
|
token_offsets = tokenized_output["offset_mapping"]
|
||||||
|
|
||||||
|
LOG.debug(f"Tokenizing text: {text}")
|
||||||
|
LOG.debug(f"Tokens: {tokens}")
|
||||||
|
# Adjust the end offsets. For some reason by default they are set to the same value as the start offsets.
|
||||||
|
for i in range(len(token_offsets) - 1):
|
||||||
|
token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1)
|
||||||
|
# Ensure the last token's end offset is set correctly
|
||||||
|
token_offsets[-1] = (token_offsets[-1][0], len(text) - 1)
|
||||||
|
LOG.debug(f"Token offsets: {token_offsets}")
|
||||||
|
|
||||||
|
# Initialize all offsets as IGNORE_TOKEN_ID (not trained)
|
||||||
|
result = [IGNORE_TOKEN_ID] * len(token_offsets)
|
||||||
|
|
||||||
|
# Adjust train_details to align with token boundaries
|
||||||
|
adjusted_train_details = self.adjust_train_details(train_details, token_offsets)
|
||||||
|
|
||||||
|
for idx, (start, end) in enumerate(token_offsets):
|
||||||
|
for detail in adjusted_train_details:
|
||||||
|
# Check if the token is completely within the detail's range
|
||||||
|
if start >= detail["begin_offset"] and end <= detail["end_offset"]:
|
||||||
|
if detail["train"] or not mask_untrainable:
|
||||||
|
result[idx] = start
|
||||||
|
LOG.debug(f"Token {idx} ({tokens[idx]}) marked for training")
|
||||||
|
else:
|
||||||
|
LOG.debug(
|
||||||
|
f"Token {idx} ({tokens[idx]}) marked as non-trainable"
|
||||||
|
)
|
||||||
|
elif start < detail["end_offset"] and end > detail["begin_offset"]:
|
||||||
|
# Token partially overlaps with detail, always mark as non-trainable
|
||||||
|
LOG.debug(
|
||||||
|
f"Token {idx} ({tokens[idx]}) partially overlaps detail, marked as non-trainable"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(f"Final result: {result}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
def adjust_train_details(
|
||||||
|
self, train_details: List[Dict], token_offsets: List[tuple]
|
||||||
|
) -> List[Dict]:
|
||||||
|
adjusted_details = []
|
||||||
|
for detail in train_details:
|
||||||
|
begin_offset = detail["begin_offset"]
|
||||||
|
end_offset = detail["end_offset"]
|
||||||
|
|
||||||
|
# Find the first token that starts after or at the begin_offset
|
||||||
|
begin_token = next(
|
||||||
|
(
|
||||||
|
i
|
||||||
|
for i, (t_start, t_end) in enumerate(token_offsets)
|
||||||
|
if t_start >= begin_offset
|
||||||
|
),
|
||||||
|
len(token_offsets),
|
||||||
|
)
|
||||||
|
if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset:
|
||||||
|
begin_token -= 1
|
||||||
|
|
||||||
|
# Find the last token that ends before or at the end_offset
|
||||||
|
end_token = next(
|
||||||
|
(
|
||||||
|
i
|
||||||
|
for i in range(len(token_offsets) - 1, -1, -1)
|
||||||
|
if token_offsets[i][1] <= end_offset
|
||||||
|
),
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
end_token < len(token_offsets) - 1
|
||||||
|
and token_offsets[end_token + 1][0] < end_offset
|
||||||
|
):
|
||||||
|
end_token += 1
|
||||||
|
|
||||||
|
if begin_token <= end_token:
|
||||||
|
adjusted_begin = token_offsets[begin_token][0]
|
||||||
|
adjusted_end = token_offsets[end_token][1]
|
||||||
|
|
||||||
|
if adjusted_begin != begin_offset or adjusted_end != end_offset:
|
||||||
|
LOG.warning(
|
||||||
|
f"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})"
|
||||||
|
)
|
||||||
|
|
||||||
|
adjusted_details.append(
|
||||||
|
{
|
||||||
|
"begin_offset": adjusted_begin,
|
||||||
|
"end_offset": adjusted_end,
|
||||||
|
"train": detail["train"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
f"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail."
|
||||||
|
)
|
||||||
|
|
||||||
|
return adjusted_details
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplateStrategy(PromptTokenizingStrategy):
|
class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||||
"""
|
"""
|
||||||
@@ -70,6 +179,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
_messages = "conversations"
|
_messages = "conversations"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prompter,
|
||||||
|
tokenizer,
|
||||||
|
train_on_inputs,
|
||||||
|
sequence_len,
|
||||||
|
roles_to_train=None,
|
||||||
|
train_on_eos="last",
|
||||||
|
):
|
||||||
|
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||||
|
self.roles_to_train = roles_to_train if roles_to_train is not None else []
|
||||||
|
self.train_on_eos = train_on_eos
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def messages(self):
|
def messages(self):
|
||||||
return self._messages
|
return self._messages
|
||||||
@@ -79,65 +201,172 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
self._messages = messages
|
self._messages = messages
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
turns = self.get_conversation_thread(prompt)
|
turns = prompt[self.messages]
|
||||||
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
|
|
||||||
input_ids = self.prompter.build_prompt(turns)
|
input_ids = self.prompter.build_prompt(turns)
|
||||||
|
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||||
|
|
||||||
if not self.train_on_inputs:
|
last_eos_idx = -1
|
||||||
user_prompt_len = len(prompt_ids)
|
for index, turn in enumerate(turns):
|
||||||
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
|
role = turn.get(self.prompter.message_field_role)
|
||||||
else:
|
content = turn.get(self.prompter.message_field_content)
|
||||||
labels = input_ids
|
train_turn = turn.get(self.prompter.message_field_training)
|
||||||
|
train_detail = turn.get(self.prompter.message_field_training_detail)
|
||||||
|
|
||||||
tokenized_prompt = {
|
LOG.debug(
|
||||||
|
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
|
||||||
|
)
|
||||||
|
|
||||||
|
should_train = (
|
||||||
|
train_turn
|
||||||
|
if train_turn is not None
|
||||||
|
else bool(train_detail is not None)
|
||||||
|
if train_detail is not None
|
||||||
|
else self.train_on_inputs or role in self.roles_to_train
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(f"Should train: {should_train}")
|
||||||
|
|
||||||
|
turn_start_idx, turn_end_idx = self.find_turn(
|
||||||
|
conversation_ids=input_ids, turn=index, turn_content=turn
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
||||||
|
|
||||||
|
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
|
||||||
|
if train_detail:
|
||||||
|
token_offsets = self.prompter.get_offsets_for_train_detail(
|
||||||
|
content, train_detail
|
||||||
|
)
|
||||||
|
LOG.debug(f"Token offsets: {token_offsets}")
|
||||||
|
for i, offset in enumerate(token_offsets):
|
||||||
|
if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len(
|
||||||
|
input_ids
|
||||||
|
):
|
||||||
|
labels[turn_start_idx + i] = input_ids[turn_start_idx + i]
|
||||||
|
LOG.debug(
|
||||||
|
f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
labels[turn_start_idx:turn_end_idx] = input_ids[
|
||||||
|
turn_start_idx:turn_end_idx
|
||||||
|
]
|
||||||
|
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
|
||||||
|
|
||||||
|
LOG.debug(f"Labels after processing turn {index}: {labels}")
|
||||||
|
|
||||||
|
# Handle EOS token
|
||||||
|
eos_idx = self.find_eos_token(input_ids, turn_end_idx)
|
||||||
|
if eos_idx == turn_end_idx:
|
||||||
|
last_eos_idx = eos_idx
|
||||||
|
if self.train_on_eos == "all" or (
|
||||||
|
self.train_on_eos == "turn" and should_train
|
||||||
|
):
|
||||||
|
labels[eos_idx] = input_ids[eos_idx]
|
||||||
|
LOG.debug(f"EOS token set for training at index {eos_idx}")
|
||||||
|
else:
|
||||||
|
LOG.debug(
|
||||||
|
f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle 'last' option for train_on_eos
|
||||||
|
if self.train_on_eos == "last" and last_eos_idx != -1:
|
||||||
|
labels[last_eos_idx] = input_ids[last_eos_idx]
|
||||||
|
LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
|
||||||
|
|
||||||
|
LOG.debug(f"Final labels: {labels}")
|
||||||
|
|
||||||
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"labels": labels,
|
"labels": labels,
|
||||||
"attention_mask": [1] * len(input_ids),
|
"attention_mask": [1] * len(input_ids),
|
||||||
}
|
}
|
||||||
|
|
||||||
return tokenized_prompt
|
def find_eos_token(self, input_ids, start_idx):
|
||||||
|
eos_token_id = self.tokenizer.eos_token_id
|
||||||
|
for i in range(start_idx, len(input_ids)):
|
||||||
|
if input_ids[i] == eos_token_id:
|
||||||
|
return i
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def find_turn(self, conversation_ids, turn, turn_content):
|
||||||
|
"""
|
||||||
|
Locate the starting and ending indices of the specified turn in a conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_ids (list[int]): Token IDs representing the conversation.
|
||||||
|
turn (int): The turn number to locate (based on EOS tokens).
|
||||||
|
turn_content (str): String containing the content of the turn.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
|
||||||
|
Returns (-1, -1) if the turn content is not found.
|
||||||
|
"""
|
||||||
|
content = turn_content.get(self.prompter.message_field_content, "")
|
||||||
|
content_ids = self.tokenizer.encode(content, add_special_tokens=False)
|
||||||
|
|
||||||
|
eos_token_id = self.tokenizer.eos_token_id
|
||||||
|
eos_count = 0
|
||||||
|
start_search_idx = 0
|
||||||
|
|
||||||
|
# Locate the starting index after the specified number of EOS tokens
|
||||||
|
for i, token_id in enumerate(conversation_ids):
|
||||||
|
if token_id == eos_token_id:
|
||||||
|
eos_count += 1
|
||||||
|
if eos_count == turn:
|
||||||
|
start_search_idx = (
|
||||||
|
i + 1
|
||||||
|
) # Start searching after the specified turn's EOS token
|
||||||
|
break
|
||||||
|
|
||||||
|
# Find the start index of the content within the conversation
|
||||||
|
start_idx = -1
|
||||||
|
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
|
||||||
|
if conversation_ids[i : i + len(content_ids)] == content_ids:
|
||||||
|
start_idx = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if start_idx != -1:
|
||||||
|
end_idx = start_idx + len(content_ids)
|
||||||
|
else:
|
||||||
|
end_idx = -1
|
||||||
|
|
||||||
|
return start_idx, end_idx
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self, prompt):
|
||||||
return prompt[self.messages]
|
return prompt[self.messages]
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
chat_template = (
|
ds_cfg = ds_cfg or {}
|
||||||
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
|
chat_template = ds_cfg.get("chat_template", "chatml")
|
||||||
)
|
|
||||||
message_field_role = (
|
|
||||||
ds_cfg["message_field_role"]
|
|
||||||
if ds_cfg and "message_field_role" in ds_cfg
|
|
||||||
else "from"
|
|
||||||
)
|
|
||||||
message_field_content = (
|
|
||||||
ds_cfg["message_field_content"]
|
|
||||||
if ds_cfg and "message_field_content" in ds_cfg
|
|
||||||
else "value"
|
|
||||||
)
|
|
||||||
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
|
||||||
drop_system_message = (
|
|
||||||
ds_cfg["drop_system_message"]
|
|
||||||
if ds_cfg and "drop_system_message" in ds_cfg
|
|
||||||
else False
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_template_str = chat_templates(chat_template, tokenizer=tokenizer)
|
chat_template_str = chat_templates(chat_template, tokenizer=tokenizer)
|
||||||
LOG.info(f"Using chat template:\n---\n{chat_template_str!s}\n---")
|
LOG.info(f"Using chat template:\n---\n{chat_template_str!s}\n---")
|
||||||
|
|
||||||
|
prompter_params = {
|
||||||
|
"tokenizer": tokenizer,
|
||||||
|
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
|
||||||
|
"message_field_role": ds_cfg.get("message_field_role", "from"),
|
||||||
|
"message_field_content": ds_cfg.get("message_field_content", "value"),
|
||||||
|
"message_field_training": ds_cfg.get("message_field_training", "training"),
|
||||||
|
"message_field_training_detail": ds_cfg.get(
|
||||||
|
"message_field_training_detail", "train_detail"
|
||||||
|
),
|
||||||
|
"roles": ds_cfg.get("roles"),
|
||||||
|
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||||
|
}
|
||||||
|
|
||||||
|
strategy_params = {
|
||||||
|
"train_on_inputs": cfg.train_on_inputs,
|
||||||
|
"sequence_len": cfg.sequence_len,
|
||||||
|
"roles_to_train": ds_cfg.get("roles_to_train"),
|
||||||
|
"train_on_eos": ds_cfg.get("train_on_eos", "last"),
|
||||||
|
}
|
||||||
|
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||||
tokenizer,
|
|
||||||
chat_template_str,
|
|
||||||
message_field_role=message_field_role,
|
|
||||||
message_field_content=message_field_content,
|
|
||||||
roles=roles,
|
|
||||||
drop_system_message=drop_system_message,
|
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
)
|
||||||
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
|
||||||
|
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||||
strategy.messages = ds_cfg["field_messages"]
|
strategy.messages = ds_cfg["field_messages"]
|
||||||
|
|
||||||
return strategy
|
return strategy
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ def default(
|
|||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
chosen_strip_index = result["chosen"].find(chosen["content"])
|
chosen_strip_index = result["chosen"].find(chosen["content"])
|
||||||
result["chosen"] = result["chosen"][chosen_strip_index:]
|
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
|
||||||
|
|
||||||
result["rejected"] = tokenizer.apply_chat_template(
|
result["rejected"] = tokenizer.apply_chat_template(
|
||||||
[rejected],
|
[rejected],
|
||||||
@@ -71,7 +71,7 @@ def default(
|
|||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
rejected_strip_index = result["rejected"].find(rejected["content"])
|
rejected_strip_index = result["rejected"].find(rejected["content"])
|
||||||
result["rejected"] = result["rejected"][rejected_strip_index:]
|
result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -212,26 +212,23 @@ def train(
|
|||||||
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
||||||
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
||||||
trainer.accelerator.wait_for_everyone()
|
trainer.accelerator.wait_for_everyone()
|
||||||
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
|
trainer.save_model(cfg.output_dir)
|
||||||
|
|
||||||
# the trainer saved a model.safetensors file in the output directory,
|
# the trainer saved a model.safetensors file in the output directory,
|
||||||
# but it is a proxy model and should be deleted
|
# but it is most likely a proxy model and if so, should be deleted
|
||||||
if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
|
maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||||
|
maybe_sharded = os.path.exists(
|
||||||
|
os.path.join(cfg.output_dir, "model.safetensors.index.json")
|
||||||
|
)
|
||||||
|
|
||||||
|
if maybe_proxy and maybe_sharded:
|
||||||
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
|
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
|
||||||
LOG.info("This is a proxy model and should be deleted")
|
LOG.info("This is a proxy model and should be deleted")
|
||||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
try:
|
||||||
|
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
|
|
||||||
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
|
|
||||||
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|
|
||||||
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
|
|
||||||
# The model name saved is `pytorch_model.bin`
|
|
||||||
unwrapped_model.save_pretrained(
|
|
||||||
cfg.output_dir,
|
|
||||||
is_main_process=trainer.accelerator.is_main_process,
|
|
||||||
save_function=trainer.accelerator.save,
|
|
||||||
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
|
||||||
)
|
|
||||||
elif cfg.local_rank == 0:
|
elif cfg.local_rank == 0:
|
||||||
if cfg.flash_optimum and BetterTransformer:
|
if cfg.flash_optimum and BetterTransformer:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
|||||||
@@ -123,6 +123,10 @@ class SFTDataset(BaseModel):
|
|||||||
field_messages: Optional[str] = None
|
field_messages: Optional[str] = None
|
||||||
message_field_role: Optional[str] = None
|
message_field_role: Optional[str] = None
|
||||||
message_field_content: Optional[str] = None
|
message_field_content: Optional[str] = None
|
||||||
|
message_field_training: Optional[str] = None
|
||||||
|
message_field_training_detail: Optional[str] = None
|
||||||
|
roles_to_train: Optional[List[str]] = None
|
||||||
|
train_on_eos: Optional[str] = None
|
||||||
|
|
||||||
roles: Optional[Dict[str, List[str]]] = None
|
roles: Optional[Dict[str, List[str]]] = None
|
||||||
drop_system_message: Optional[bool] = None
|
drop_system_message: Optional[bool] = None
|
||||||
@@ -179,6 +183,7 @@ class RLType(str, Enum):
|
|||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
kto = "kto" # pylint: disable=invalid-name
|
kto = "kto" # pylint: disable=invalid-name
|
||||||
|
simpo = "simpo" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplate(str, Enum):
|
class ChatTemplate(str, Enum):
|
||||||
@@ -653,6 +658,8 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
orpo_alpha: Optional[float] = None
|
orpo_alpha: Optional[float] = None
|
||||||
rpo_alpha: Optional[float] = None
|
rpo_alpha: Optional[float] = None
|
||||||
|
simpo_gamma: Optional[float] = None
|
||||||
|
cpo_alpha: Optional[float] = None
|
||||||
|
|
||||||
kto_desirable_weight: Optional[float] = None
|
kto_desirable_weight: Optional[float] = None
|
||||||
kto_undesirable_weight: Optional[float] = None
|
kto_undesirable_weight: Optional[float] = None
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ from axolotl.prompters import (
|
|||||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||||
from axolotl.utils.data.utils import md5
|
from axolotl.utils.data.utils import md5
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_local_main_process, zero_first
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
calculate_total_num_steps,
|
calculate_total_num_steps,
|
||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
@@ -54,7 +54,7 @@ LOG = logging.getLogger("axolotl")
|
|||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer):
|
||||||
prompters = []
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_local_main_process()):
|
||||||
if cfg.test_datasets:
|
if cfg.test_datasets:
|
||||||
train_dataset, _, prompters = load_prepare_datasets(
|
train_dataset, _, prompters = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
|
||||||
@@ -170,6 +170,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
if dataset:
|
if dataset:
|
||||||
|
# This is for the case where we already loaded a pretokenized dataset from the hub
|
||||||
...
|
...
|
||||||
elif (
|
elif (
|
||||||
cfg.dataset_prepared_path
|
cfg.dataset_prepared_path
|
||||||
@@ -198,6 +199,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
def for_d_in_datasets(dataset_configs):
|
def for_d_in_datasets(dataset_configs):
|
||||||
for dataset in dataset_configs:
|
for dataset in dataset_configs:
|
||||||
if dataset.name and isinstance(dataset.name, list):
|
if dataset.name and isinstance(dataset.name, list):
|
||||||
|
# load_dataset doesn't properly handle multiple named configurations
|
||||||
|
# at the same time for a given dataset
|
||||||
for name in dataset.name:
|
for name in dataset.name:
|
||||||
yield DictDefault({**dataset, "name": name})
|
yield DictDefault({**dataset, "name": name})
|
||||||
else:
|
else:
|
||||||
@@ -208,6 +211,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
ds: Optional[Union[Dataset, DatasetDict]] = None
|
ds: Optional[Union[Dataset, DatasetDict]] = None
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
try:
|
try:
|
||||||
|
# this is just a basic check to see if the path is a
|
||||||
|
# valid HF dataset that's loadable
|
||||||
load_dataset(
|
load_dataset(
|
||||||
config_dataset.path,
|
config_dataset.path,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
|
|||||||
@@ -44,6 +44,10 @@ def is_main_process():
|
|||||||
return dist.get_rank() == 0
|
return dist.get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def is_local_main_process():
|
||||||
|
return PartialState().is_main_process
|
||||||
|
|
||||||
|
|
||||||
def get_world_size():
|
def get_world_size():
|
||||||
return int(os.getenv("WORLD_SIZE", "1"))
|
return int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from transformers import ( # noqa: F401
|
|||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
AwqConfig,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
@@ -36,6 +37,7 @@ from transformers import ( # noqa: F401
|
|||||||
)
|
)
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
|
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
@@ -510,7 +512,25 @@ def load_model(
|
|||||||
model_kwargs["quantization_config"] = GPTQConfig(
|
model_kwargs["quantization_config"] = GPTQConfig(
|
||||||
**model_config.quantization_config
|
**model_config.quantization_config
|
||||||
)
|
)
|
||||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
if (
|
||||||
|
cfg.adapter in ["qlora", "lora"]
|
||||||
|
and hasattr(model_config, "quantization_config")
|
||||||
|
and model_config.quantization_config["quant_method"]
|
||||||
|
in ["gptq", "awq", "bitsandbytes"]
|
||||||
|
):
|
||||||
|
if model_config.quantization_config["quant_method"] == "gptq":
|
||||||
|
model_kwargs["quantization_config"] = GPTQConfig(
|
||||||
|
**model_config.quantization_config
|
||||||
|
)
|
||||||
|
elif model_config.quantization_config["quant_method"] == "awq":
|
||||||
|
model_kwargs["quantization_config"] = AwqConfig(
|
||||||
|
**model_config.quantization_config
|
||||||
|
)
|
||||||
|
elif model_config.quantization_config["quant_method"] == "bitsandbytes":
|
||||||
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
|
**model_config.quantization_config
|
||||||
|
)
|
||||||
|
elif cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"llm_int8_threshold": 6.0,
|
"llm_int8_threshold": 6.0,
|
||||||
@@ -619,7 +639,7 @@ def load_model(
|
|||||||
and not cfg.trust_remote_code
|
and not cfg.trust_remote_code
|
||||||
and not cfg.gptq
|
and not cfg.gptq
|
||||||
):
|
):
|
||||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
del model_kwargs["device_map"]
|
del model_kwargs["device_map"]
|
||||||
@@ -701,7 +721,7 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||||
# disabling either of these two still leads to VRAM spike before setting back down
|
# disabling either of these two still leads to VRAM spike before setting back down
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
@@ -785,12 +805,14 @@ def load_model(
|
|||||||
set_z3_leaf_modules,
|
set_z3_leaf_modules,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.model_config_type == "mixtral":
|
if cfg.model_config_type in MOE_ARCH_BLOCK:
|
||||||
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
|
set_z3_leaf_modules(
|
||||||
set_z3_leaf_modules(model, [moe_block])
|
model,
|
||||||
elif cfg.model_config_type == "dbrx":
|
[
|
||||||
moe_block = get_module_class_from_name(model, "DbrxFFN")
|
get_module_class_from_name(model, module_name)
|
||||||
set_z3_leaf_modules(model, [moe_block])
|
for module_name in MOE_ARCH_BLOCK[cfg.model_config_type]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
||||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||||
@@ -804,6 +826,9 @@ def load_model(
|
|||||||
# make sure everything is in the same dtype
|
# make sure everything is in the same dtype
|
||||||
skip_prepare_model_for_kbit_training = True
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if cfg.adapter in ["lora", "qlora"]:
|
if cfg.adapter in ["lora", "qlora"]:
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable(
|
model.gradient_checkpointing_enable(
|
||||||
@@ -838,6 +863,9 @@ def load_model(
|
|||||||
else:
|
else:
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
skip_move_to_device = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.ddp
|
cfg.ddp
|
||||||
and not load_in_8bit
|
and not load_in_8bit
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module containing the Trainer class and related functions"""
|
"""Module containing the Trainer class and related functions"""
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -389,6 +390,19 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
return total_num_steps
|
return total_num_steps
|
||||||
|
|
||||||
|
|
||||||
|
def setup_deepspeed_env(cfg, stage=None):
|
||||||
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
|
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
||||||
|
if cfg.bf16:
|
||||||
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
||||||
|
elif cfg.fp16:
|
||||||
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
|
||||||
|
if stage:
|
||||||
|
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
||||||
|
if stage == 3:
|
||||||
|
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
||||||
|
|
||||||
|
|
||||||
def setup_fsdp_envs(cfg):
|
def setup_fsdp_envs(cfg):
|
||||||
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||||
if cfg.fsdp_config.fsdp_activation_checkpointing:
|
if cfg.fsdp_config.fsdp_activation_checkpointing:
|
||||||
@@ -415,8 +429,14 @@ def prepare_optim_env(cfg):
|
|||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
stage = None
|
||||||
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
# check if the cfg.deepspeed is a file
|
||||||
|
if os.path.isfile(cfg.deepspeed):
|
||||||
|
# parse with json
|
||||||
|
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
|
||||||
|
deepspeed_config = json.load(fin)
|
||||||
|
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
|
||||||
|
setup_deepspeed_env(cfg, stage=stage)
|
||||||
|
|
||||||
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
|
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
|
||||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
|
||||||
@@ -425,7 +445,7 @@ def prepare_optim_env(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto"]:
|
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
|
|||||||
20
tests/e2e/test_imports.py
Normal file
20
tests/e2e/test_imports.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
test module to import various submodules that have historically broken due to dependency issues
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class TestImports(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test class to import various submodules that have historically broken due to dependency issues
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_import_causal_trainer(self):
|
||||||
|
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
|
||||||
|
HFCausalTrainerBuilder,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_import_rl_trainer(self):
|
||||||
|
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
|
||||||
|
HFRLTrainerBuilder,
|
||||||
|
)
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
tests for chat_template prompt strategy
|
tests for chat_template prompt strategy
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -13,33 +14,24 @@ from axolotl.prompt_strategies.chat_template import (
|
|||||||
ChatTemplateStrategy,
|
ChatTemplateStrategy,
|
||||||
load,
|
load,
|
||||||
)
|
)
|
||||||
|
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="assistant_dataset")
|
@pytest.fixture(name="assistant_dataset")
|
||||||
def fixture_assistant_dataset():
|
def fixture_assistant_dataset():
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
return Dataset.from_list(
|
return Dataset.from_list(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{"role": "user", "content": "hello"},
|
||||||
"role": "user",
|
{"role": "assistant", "content": "hello"},
|
||||||
"content": "hello",
|
{"role": "user", "content": "goodbye"},
|
||||||
},
|
{"role": "assistant", "content": "goodbye"},
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "hello",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "goodbye",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "goodbye",
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -53,22 +45,28 @@ def fixture_sharegpt_dataset():
|
|||||||
[
|
[
|
||||||
{
|
{
|
||||||
"conversations": [
|
"conversations": [
|
||||||
{
|
{"from": "human", "value": "hello"},
|
||||||
"from": "human",
|
{"from": "gpt", "value": "hello"},
|
||||||
"value": "hello",
|
{"from": "human", "value": "goodbye"},
|
||||||
},
|
{"from": "gpt", "value": "goodbye"},
|
||||||
{
|
]
|
||||||
"from": "gpt",
|
}
|
||||||
"value": "hello",
|
]
|
||||||
},
|
)
|
||||||
{
|
|
||||||
"from": "human",
|
|
||||||
"value": "goodbye",
|
@pytest.fixture(name="basic_dataset")
|
||||||
},
|
def fixture_basic_dataset():
|
||||||
{
|
# pylint: disable=duplicate-code
|
||||||
"from": "gpt",
|
return Dataset.from_list(
|
||||||
"value": "goodbye",
|
[
|
||||||
},
|
{
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "You are an AI assistant."},
|
||||||
|
{"from": "human", "value": "Hello"},
|
||||||
|
{"from": "assistant", "value": "Hi there!"},
|
||||||
|
{"from": "human", "value": "How are you?"},
|
||||||
|
{"from": "assistant", "value": "I'm doing well, thank you!"},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -77,8 +75,7 @@ def fixture_sharegpt_dataset():
|
|||||||
|
|
||||||
@pytest.fixture(name="llama3_tokenizer")
|
@pytest.fixture(name="llama3_tokenizer")
|
||||||
def fixture_llama3_tokenizer():
|
def fixture_llama3_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
||||||
tokenizer.eos_token = "<|eot_id|>"
|
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
@@ -130,13 +127,607 @@ class TestChatTemplates:
|
|||||||
assert chat_template_str == "test_template"
|
assert chat_template_str == "test_template"
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatTemplateConfigurations:
|
||||||
|
"""
|
||||||
|
Test class for various configurations of ChatTemplateStrategy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_sublist(full_list, sub_list):
|
||||||
|
token_count = len(sub_list)
|
||||||
|
for index in range(len(full_list) - token_count + 1):
|
||||||
|
if full_list[index : index + token_count] == sub_list:
|
||||||
|
return index
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_inputs=True")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=True,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Verify that assistant responses are labeled
|
||||||
|
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
for response in assistant_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
LOG.debug(
|
||||||
|
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||||
|
|
||||||
|
# Check the behavior of human inputs
|
||||||
|
human_inputs = ["Hello", "How are you?"]
|
||||||
|
for input_text in human_inputs:
|
||||||
|
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, input_ids)
|
||||||
|
labeled = all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(input_ids)]
|
||||||
|
)
|
||||||
|
LOG.debug(
|
||||||
|
f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug("Full labels: %s", labels)
|
||||||
|
LOG.debug("Full input_ids: %s", input_ids)
|
||||||
|
|
||||||
|
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_inputs=False")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Verify that only assistant responses are labeled
|
||||||
|
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
for response in assistant_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
LOG.debug(
|
||||||
|
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||||
|
|
||||||
|
# Verify that human inputs are not labeled
|
||||||
|
human_inputs = ["Hello", "How are you?"]
|
||||||
|
for input_text in human_inputs:
|
||||||
|
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, input_ids)
|
||||||
|
LOG.debug(
|
||||||
|
f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label == IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(input_ids)]
|
||||||
|
), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}"
|
||||||
|
|
||||||
|
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing roles_to_train with assistant only")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Verify that only assistant responses are labeled
|
||||||
|
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
for response in assistant_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
LOG.debug(
|
||||||
|
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||||
|
|
||||||
|
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing roles_to_train with all roles")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=True,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["human", "assistant"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Verify that all responses are labeled (except for special tokens)
|
||||||
|
all_responses = [
|
||||||
|
"Hello",
|
||||||
|
"Hi there!",
|
||||||
|
"How are you?",
|
||||||
|
"I'm doing well, thank you!",
|
||||||
|
]
|
||||||
|
for response in all_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
LOG.debug(
|
||||||
|
f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||||
|
|
||||||
|
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with empty roles_to_train")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=[],
|
||||||
|
train_on_eos="none", # Add this line
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
|
||||||
|
# Verify that no labels are set when roles_to_train is empty
|
||||||
|
LOG.debug("Full labels: %s", labels)
|
||||||
|
assert all(
|
||||||
|
label == IGNORE_TOKEN_ID for label in labels
|
||||||
|
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
|
||||||
|
|
||||||
|
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='all'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="all",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
eos_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||||
|
for eos_idx in eos_indices:
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token at index {eos_idx} to be labeled"
|
||||||
|
|
||||||
|
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='turn'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="turn",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
|
||||||
|
for response in assistant_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
||||||
|
|
||||||
|
eos_idx = start_idx + len(response_ids)
|
||||||
|
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
||||||
|
eos_idx += 1
|
||||||
|
|
||||||
|
assert eos_idx < len(
|
||||||
|
input_ids
|
||||||
|
), f"Could not find EOS token after '{response}'"
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token after assistant response '{response}' to be labeled"
|
||||||
|
|
||||||
|
# Check that EOS tokens after human inputs are not labeled
|
||||||
|
human_inputs = ["Hello", "How are you?"]
|
||||||
|
for input_text in human_inputs:
|
||||||
|
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, input_ids)
|
||||||
|
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
|
||||||
|
|
||||||
|
eos_idx = start_idx + len(input_ids)
|
||||||
|
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
||||||
|
eos_idx += 1
|
||||||
|
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token after human input '{input_text}' to not be labeled"
|
||||||
|
|
||||||
|
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='last'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="last",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
eos_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||||
|
last_eos_idx = eos_indices[-1]
|
||||||
|
|
||||||
|
# Check that only the last EOS token is labeled
|
||||||
|
for idx in eos_indices[:-1]:
|
||||||
|
assert (
|
||||||
|
labels[idx] == IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token at index {idx} to not be labeled"
|
||||||
|
assert (
|
||||||
|
labels[last_eos_idx] != IGNORE_TOKEN_ID
|
||||||
|
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
|
||||||
|
|
||||||
|
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='none'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="none",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
eos_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||||
|
for eos_idx in eos_indices:
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token at index {eos_idx} to not be labeled"
|
||||||
|
|
||||||
|
def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with drop_system_message=True")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
|
||||||
|
),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Check if system message is not present in input_ids
|
||||||
|
system_message = "You are an AI assistant."
|
||||||
|
system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
|
||||||
|
assert (
|
||||||
|
self.find_sublist(input_ids, system_ids) == -1
|
||||||
|
), "Expected system message to be dropped"
|
||||||
|
|
||||||
|
def test_custom_roles(self, llama3_tokenizer):
|
||||||
|
LOG.info("Testing with custom roles mapping")
|
||||||
|
custom_roles = {
|
||||||
|
"user": ["human", "user"],
|
||||||
|
"assistant": ["ai", "assistant"],
|
||||||
|
"system": ["context"],
|
||||||
|
}
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
|
||||||
|
),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["ai"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new dataset with modified role names
|
||||||
|
modified_conversations = [
|
||||||
|
{"from": "context", "value": "You are an AI assistant."},
|
||||||
|
{"from": "human", "value": "Hello"},
|
||||||
|
{"from": "ai", "value": "Hi there!"},
|
||||||
|
{"from": "human", "value": "How are you?"},
|
||||||
|
{"from": "ai", "value": "I'm doing well, thank you!"},
|
||||||
|
]
|
||||||
|
|
||||||
|
modified_dataset = Dataset.from_dict(
|
||||||
|
{"conversations": [modified_conversations]}
|
||||||
|
)
|
||||||
|
|
||||||
|
res = strategy.tokenize_prompt(modified_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Check if AI responses are labeled correctly
|
||||||
|
ai_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
for response in ai_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
assert start_idx != -1, f"Could not find response '{response}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for AI response '{response}' to be set"
|
||||||
|
|
||||||
|
# Check if human messages are not labeled
|
||||||
|
human_messages = ["Hello", "How are you?"]
|
||||||
|
for message in human_messages:
|
||||||
|
message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, message_ids)
|
||||||
|
assert start_idx != -1, f"Could not find message '{message}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label == IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(message_ids)]
|
||||||
|
), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
|
||||||
|
|
||||||
|
def test_message_field_training(self, llama3_tokenizer):
|
||||||
|
LOG.info("Testing with message_field_training")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer,
|
||||||
|
chat_templates("llama3"),
|
||||||
|
message_field_training="train",
|
||||||
|
message_field_training_detail="train_detail",
|
||||||
|
),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new dataset with the train and train_detail fields
|
||||||
|
modified_conversation = [
|
||||||
|
{"from": "system", "value": "You are an AI assistant.", "train": False},
|
||||||
|
{"from": "human", "value": "Hello", "train": False},
|
||||||
|
{"from": "assistant", "value": "Hello", "train": True},
|
||||||
|
{"from": "human", "value": "How are you?", "train": True},
|
||||||
|
{
|
||||||
|
"from": "assistant",
|
||||||
|
"value": "I'm doing very well, thank you!",
|
||||||
|
"train_detail": [
|
||||||
|
{"begin_offset": 0, "end_offset": 8, "train": False},
|
||||||
|
{"begin_offset": 9, "end_offset": 18, "train": True},
|
||||||
|
{"begin_offset": 19, "end_offset": 30, "train": False},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "I'm doing very well, thank you!",
|
||||||
|
"train": False,
|
||||||
|
},
|
||||||
|
{"from": "assistant", "value": "Hi there!", "train": True},
|
||||||
|
]
|
||||||
|
|
||||||
|
modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
|
||||||
|
|
||||||
|
res = strategy.tokenize_prompt(modified_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Function to find all occurrences of a sublist
|
||||||
|
def find_all_sublists(full_list, sub_list):
|
||||||
|
indices = []
|
||||||
|
for index in range(len(full_list) - len(sub_list) + 1):
|
||||||
|
if full_list[index : index + len(sub_list)] == sub_list:
|
||||||
|
indices.append(index)
|
||||||
|
return indices
|
||||||
|
|
||||||
|
# Keep track of which occurrences we've processed
|
||||||
|
processed_occurrences = {}
|
||||||
|
# Check if messages are labeled correctly based on train or train_detail
|
||||||
|
for i, turn in enumerate(modified_conversation):
|
||||||
|
turn_tokens = llama3_tokenizer.encode(
|
||||||
|
turn["value"], add_special_tokens=False
|
||||||
|
)
|
||||||
|
occurrences = find_all_sublists(input_ids, turn_tokens)
|
||||||
|
turn_key = turn["value"]
|
||||||
|
if turn_key not in processed_occurrences:
|
||||||
|
processed_occurrences[turn_key] = 0
|
||||||
|
current_occurrence = processed_occurrences[turn_key]
|
||||||
|
|
||||||
|
if current_occurrence >= len(occurrences):
|
||||||
|
assert (
|
||||||
|
False
|
||||||
|
), f"Not enough occurrences found for message: {turn['value']}"
|
||||||
|
|
||||||
|
start_idx = occurrences[current_occurrence]
|
||||||
|
processed_occurrences[turn_key] += 1
|
||||||
|
end_idx = start_idx + len(turn_tokens)
|
||||||
|
|
||||||
|
LOG.debug(
|
||||||
|
f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "train_detail" in turn:
|
||||||
|
# Get token offsets
|
||||||
|
tokenized_output = llama3_tokenizer(
|
||||||
|
turn["value"], return_offsets_mapping=True, add_special_tokens=False
|
||||||
|
)
|
||||||
|
token_offsets = tokenized_output["offset_mapping"]
|
||||||
|
|
||||||
|
# Adjust token offsets as done in the implementation
|
||||||
|
for i in range(len(token_offsets) - 1):
|
||||||
|
token_offsets[i] = (
|
||||||
|
token_offsets[i][0],
|
||||||
|
token_offsets[i + 1][0] - 1,
|
||||||
|
)
|
||||||
|
token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
|
||||||
|
|
||||||
|
# Adjust train_details
|
||||||
|
adjusted_train_details = strategy.prompter.adjust_train_details(
|
||||||
|
turn["train_detail"], token_offsets
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(f"Original train_details: {turn['train_detail']}")
|
||||||
|
LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
|
||||||
|
|
||||||
|
# Handle train_detail
|
||||||
|
token_offsets = strategy.prompter.get_offsets_for_train_detail(
|
||||||
|
text=turn["value"],
|
||||||
|
train_details=adjusted_train_details,
|
||||||
|
mask_untrainable=False,
|
||||||
|
)
|
||||||
|
token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
|
||||||
|
text=turn["value"],
|
||||||
|
train_details=adjusted_train_details,
|
||||||
|
mask_untrainable=True,
|
||||||
|
)
|
||||||
|
LOG.debug(f"Token offsets: {token_offsets_masked}")
|
||||||
|
|
||||||
|
expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
|
||||||
|
for i, offset in enumerate(token_offsets_masked):
|
||||||
|
if offset != IGNORE_TOKEN_ID:
|
||||||
|
expected_labels[i] = turn_tokens[i]
|
||||||
|
actual_labels = labels[
|
||||||
|
start_idx : start_idx + len(token_offsets_masked)
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
actual_labels == expected_labels
|
||||||
|
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
|
||||||
|
|
||||||
|
for detail in adjusted_train_details:
|
||||||
|
# Find the token indices that correspond to the character offsets
|
||||||
|
detail_start = start_idx + next(
|
||||||
|
i
|
||||||
|
for i, offset in enumerate(token_offsets)
|
||||||
|
if offset >= detail["begin_offset"]
|
||||||
|
)
|
||||||
|
detail_end = start_idx + next(
|
||||||
|
(
|
||||||
|
i
|
||||||
|
for i, offset in enumerate(token_offsets)
|
||||||
|
if offset > detail["end_offset"]
|
||||||
|
),
|
||||||
|
len(token_offsets),
|
||||||
|
)
|
||||||
|
|
||||||
|
detail_text = turn["value"][
|
||||||
|
detail["begin_offset"] : detail["end_offset"] + 1
|
||||||
|
]
|
||||||
|
detail_labels = labels[detail_start:detail_end]
|
||||||
|
detail_input_ids = input_ids[detail_start:detail_end]
|
||||||
|
|
||||||
|
LOG.debug(
|
||||||
|
f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
|
||||||
|
)
|
||||||
|
LOG.debug(f"Detail input_ids: {detail_input_ids}")
|
||||||
|
LOG.debug(f"Detail labels: {detail_labels}")
|
||||||
|
LOG.debug(
|
||||||
|
f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
|
||||||
|
)
|
||||||
|
LOG.debug(
|
||||||
|
f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if detail["train"]:
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID for label in detail_labels
|
||||||
|
), (
|
||||||
|
f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
|
||||||
|
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
|
||||||
|
f"InputIDs: {detail_input_ids}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert all(
|
||||||
|
label == IGNORE_TOKEN_ID for label in detail_labels
|
||||||
|
), (
|
||||||
|
f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
|
||||||
|
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
|
||||||
|
f"InputIDs: {detail_input_ids}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
should_train = turn.get("train", False)
|
||||||
|
turn_labels = labels[start_idx:end_idx]
|
||||||
|
|
||||||
|
LOG.debug(f"Should train: {should_train}")
|
||||||
|
LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
|
||||||
|
LOG.debug(f"Turn labels: {turn_labels}")
|
||||||
|
LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
|
||||||
|
LOG.debug(
|
||||||
|
f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_train:
|
||||||
|
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
|
||||||
|
f"Expected all labels for '{turn['value']}' to be set\n"
|
||||||
|
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
|
||||||
|
f"InputIDs: {input_ids[start_idx:end_idx]}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
|
||||||
|
f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
|
||||||
|
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
|
||||||
|
f"InputIDs: {input_ids[start_idx:end_idx]}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(
|
||||||
|
f"Processed turn: {turn['from']}, content: '{turn['value']}', "
|
||||||
|
f"start_idx: {start_idx}, end_idx: {end_idx}, "
|
||||||
|
f"labels: {labels[start_idx:end_idx]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(f"Final labels: {labels}")
|
||||||
|
LOG.debug(f"Final input_ids: {input_ids}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestAssistantChatTemplateLlama3:
|
class TestAssistantChatTemplateLlama3:
|
||||||
"""
|
"""
|
||||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
|
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
|
||||||
# pylint: disable=duplicate-code
|
LOG.info("Loading llama-3 tokenizer with assistant dataset")
|
||||||
strategy = load(
|
strategy = load(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
DictDefault(
|
DictDefault(
|
||||||
@@ -162,21 +753,26 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
# fmt: off
|
# fmt: off
|
||||||
assert input_ids == [
|
expected_input_ids = [
|
||||||
128000, # bos
|
128000, # bos
|
||||||
128006, 882, 128007, # user header
|
128006, 882, 128007, # user header
|
||||||
271, 15339, 128009, # user prompt eot
|
271, 15339, 128009, # user prompt eot
|
||||||
128006, 78191, 128007, # assistant header
|
128006, 78191, 128007, # assistant header
|
||||||
271, 15339, 128009, # assistant response eot
|
271, 15339, 128009, # assistant response eot
|
||||||
128006, 882, 128007,
|
128006, 882, 128007,
|
||||||
271, 19045, 29474, 128009,
|
271, 19045, 29474, 128009,
|
||||||
128006, 78191, 128007,
|
128006, 78191, 128007,
|
||||||
271, 19045, 29474, 128009,
|
271, 19045, 29474, 128009,
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||||
|
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||||
|
assert (
|
||||||
|
input_ids == expected_input_ids
|
||||||
|
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||||
|
|
||||||
def test_llama3(self, llama3_tokenizer, assistant_dataset):
|
def test_llama3(self, llama3_tokenizer, assistant_dataset):
|
||||||
# pylint: disable=duplicate-code
|
LOG.info("Testing llama-3 with assistant dataset")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
@@ -189,15 +785,16 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
"system": ["system"],
|
"system": ["system"],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
False,
|
train_on_inputs=False,
|
||||||
512,
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
)
|
)
|
||||||
strategy.messages = "messages"
|
strategy.messages = "messages"
|
||||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
# fmt: off
|
# fmt: off
|
||||||
assert input_ids == [
|
expected_input_ids = [
|
||||||
128000, # bos
|
128000, # bos
|
||||||
128006, 882, 128007, # user header
|
128006, 882, 128007, # user header
|
||||||
271, 15339, 128009, # user prompt eot
|
271, 15339, 128009, # user prompt eot
|
||||||
@@ -209,6 +806,64 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
271, 19045, 29474, 128009,
|
271, 19045, 29474, 128009,
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||||
|
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||||
|
assert (
|
||||||
|
input_ids == expected_input_ids
|
||||||
|
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||||
|
|
||||||
|
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
|
||||||
|
LOG.info("Testing llama-3 with assistant dataset including training data")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer,
|
||||||
|
chat_templates("llama3"),
|
||||||
|
message_field_role="role",
|
||||||
|
message_field_content="content",
|
||||||
|
message_field_training="training",
|
||||||
|
roles={
|
||||||
|
"user": ["user"],
|
||||||
|
"assistant": ["assistant"],
|
||||||
|
"system": ["system"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
train_on_eos="none",
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
strategy.messages = "messages"
|
||||||
|
prompt_tokens = strategy.prompter.build_prompt(
|
||||||
|
assistant_dataset[0]["messages"], False
|
||||||
|
)
|
||||||
|
prompt = llama3_tokenizer.decode(prompt_tokens, skip_special_tokens=False)
|
||||||
|
LOG.debug(f"Generated prompt: {prompt}")
|
||||||
|
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
# fmt: off
|
||||||
|
expected_labels = [
|
||||||
|
IGNORE_TOKEN_ID, # bos
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
||||||
|
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
LOG.debug(f"Expected labels: {expected_labels}")
|
||||||
|
LOG.debug(f"Actual labels: {labels}")
|
||||||
|
assert labels == expected_labels, (
|
||||||
|
f"Labels mismatch:\n"
|
||||||
|
f"Expected: {expected_labels}\n"
|
||||||
|
f"Actual: {labels}\n"
|
||||||
|
f"Input IDs: {input_ids}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSharegptChatTemplateLlama3:
|
class TestSharegptChatTemplateLlama3:
|
||||||
@@ -216,30 +871,160 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
|
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_llama3(self, llama3_tokenizer, sharegpt_dataset):
|
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
|
||||||
# pylint: disable=duplicate-code
|
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
False,
|
train_on_inputs=False,
|
||||||
512,
|
train_on_eos="none",
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["gpt"],
|
||||||
)
|
)
|
||||||
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
|
labels = res["labels"]
|
||||||
# fmt: off
|
# fmt: off
|
||||||
assert input_ids == [
|
expected_input_ids = [
|
||||||
128000, # bos
|
128000, # bos
|
||||||
128006, 882, 128007, # user header
|
128006, 882, 128007, # user header
|
||||||
271, 15339, 128009, # user prompt eot
|
271, 15339, 128009, # user prompt eot
|
||||||
128006, 78191, 128007, # assistant header
|
128006, 78191, 128007, # assistant header
|
||||||
271, 15339, 128009, # assistant response eot
|
271, 15339, 128009, # assistant response eot
|
||||||
128006, 882, 128007,
|
128006, 882, 128007,
|
||||||
271, 19045, 29474, 128009,
|
271, 19045, 29474, 128009,
|
||||||
128006, 78191, 128007,
|
128006, 78191, 128007,
|
||||||
271, 19045, 29474, 128009,
|
271, 19045, 29474, 128009,
|
||||||
]
|
]
|
||||||
|
expected_labels = [
|
||||||
|
IGNORE_TOKEN_ID, # bos
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
||||||
|
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
|
||||||
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||||
|
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||||
|
LOG.debug(f"Expected labels: {expected_labels}")
|
||||||
|
LOG.debug(f"Actual labels: {labels}")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
input_ids == expected_input_ids
|
||||||
|
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||||
|
assert (
|
||||||
|
labels == expected_labels
|
||||||
|
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||||
|
|
||||||
|
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
|
||||||
|
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
train_on_eos="none",
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["human"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
labels = res["labels"]
|
||||||
|
# fmt: off
|
||||||
|
expected_input_ids = [
|
||||||
|
128000, # bos
|
||||||
|
128006, 882, 128007, # user header
|
||||||
|
271, 15339, 128009, # user prompt eot
|
||||||
|
128006, 78191, 128007, # assistant header
|
||||||
|
271, 15339, 128009, # assistant response eot
|
||||||
|
128006, 882, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
128006, 78191, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
]
|
||||||
|
expected_labels = [
|
||||||
|
IGNORE_TOKEN_ID, # bos
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
|
||||||
|
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # user prompt eot
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||||
|
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||||
|
LOG.debug(f"Expected labels: {expected_labels}")
|
||||||
|
LOG.debug(f"Actual labels: {labels}")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
input_ids == expected_input_ids
|
||||||
|
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||||
|
assert (
|
||||||
|
labels == expected_labels
|
||||||
|
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||||
|
|
||||||
|
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
train_on_eos="none",
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["system", "human"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
labels = res["labels"]
|
||||||
|
# fmt: off
|
||||||
|
expected_input_ids = [
|
||||||
|
128000, # bos
|
||||||
|
128006, 9125, 128007,
|
||||||
|
271, 2675, 527, 459, 15592, 18328, 13, 128009,
|
||||||
|
128006, 882, 128007, # user header
|
||||||
|
271, 9906, 128009, # user prompt eot
|
||||||
|
128006, 78191, 128007, # assistant header
|
||||||
|
271, 13347, 1070, 0, 128009, # assistant response eot
|
||||||
|
128006, 882, 128007,
|
||||||
|
271, 4438, 527, 499, 30, 128009,
|
||||||
|
128006, 78191, 128007,
|
||||||
|
271, 40, 2846, 3815, 1664, 11, 9901, 499, 0, 128009,
|
||||||
|
]
|
||||||
|
expected_labels = [
|
||||||
|
IGNORE_TOKEN_ID, # bos
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
|
||||||
|
IGNORE_TOKEN_ID, 2675, 527, 459, 15592, 18328, 13, IGNORE_TOKEN_ID, # system prompt eot
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
|
||||||
|
IGNORE_TOKEN_ID, 9906, IGNORE_TOKEN_ID, # user prompt eot
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, 4438, 527, 499, 30, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||||
|
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||||
|
LOG.debug(f"Expected labels: {expected_labels}")
|
||||||
|
LOG.debug(f"Actual labels: {labels}")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
input_ids == expected_input_ids
|
||||||
|
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||||
|
assert (
|
||||||
|
labels == expected_labels
|
||||||
|
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -192,6 +192,7 @@ class TestSharegptLlama3:
|
|||||||
input_ids = dataset_wrapper[0]["input_ids"]
|
input_ids = dataset_wrapper[0]["input_ids"]
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
assert input_ids == [
|
assert input_ids == [
|
||||||
128000, # bos
|
128000, # bos
|
||||||
128006, 9125, 128007, # system header
|
128006, 9125, 128007, # system header
|
||||||
@@ -228,6 +229,7 @@ class TestSharegptLlama3:
|
|||||||
input_ids = dataset_wrapper[0]["input_ids"]
|
input_ids = dataset_wrapper[0]["input_ids"]
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
assert input_ids == [
|
assert input_ids == [
|
||||||
128000, # bos
|
128000, # bos
|
||||||
128006, 9125, 128007, # system header
|
128006, 9125, 128007, # system header
|
||||||
|
|||||||
Reference in New Issue
Block a user