Merge branch 'main' into cj_tokenizer_default_prompt_template

This commit is contained in:
Chirag Jain
2024-07-30 23:48:43 +05:30
committed by GitHub
25 changed files with 1511 additions and 309 deletions

View File

@@ -12,36 +12,24 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: "118" - cuda: "121"
cuda_version: 11.8.0 cuda_version: 12.1.1
cudnn_version: 8
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121" - cuda: "121"
cuda_version: 12.1.0 cuda_version: 12.1.1
python_version: "3.10" cudnn_version: 8
pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.1 pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3
@@ -67,6 +55,7 @@ jobs:
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
build-args: | build-args: |
CUDA_VERSION=${{ matrix.cuda_version }} CUDA_VERSION=${{ matrix.cuda_version }}
CUDNN_VERSION=${{ matrix.cudnn_version }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
PYTHON_VERSION=${{ matrix.python_version }} PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }} PYTORCH_VERSION=${{ matrix.pytorch }}

View File

@@ -13,28 +13,22 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 118 - cuda: 121
cuda_version: 11.8.0 cuda_version: 12.1.1
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.3.1
axolotl_extras: axolotl_extras: mamba-ssm
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.1 pytorch: 2.3.1
axolotl_extras: axolotl_extras: mamba-ssm
is_latest: true is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -65,6 +59,7 @@ jobs:
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: | tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
@@ -75,27 +70,22 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 118 - cuda: 121
cuda_version: 11.8.0 cuda_version: 12.1.1
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.3.1
axolotl_extras: axolotl_extras:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.1 pytorch: 2.3.1
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -134,7 +124,7 @@ jobs:
matrix: matrix:
include: include:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.1 pytorch: 2.3.1
axolotl_extras: axolotl_extras:

View File

@@ -12,28 +12,22 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 118 - cuda: 121
cuda_version: 11.8.0 cuda_version: 12.1.1
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.3.1
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras: axolotl_extras:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.1 pytorch: 2.3.1
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -75,27 +69,22 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 118 - cuda: 121
cuda_version: 11.8.0 cuda_version: 12.1.1
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.3.1
axolotl_extras: axolotl_extras:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.1 pytorch: 2.3.1
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -72,27 +72,24 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 118 - cuda: 121
cuda_version: 11.8.0 cuda_version: 12.1.1
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.3.1
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
num_gpus: 1 num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.1.2
num_gpus: 1
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
num_gpus: 1
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.1 pytorch: 2.3.1
num_gpus: 1 num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
num_gpus: 1
axolotl_extras:
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -109,6 +106,7 @@ jobs:
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal

View File

@@ -334,7 +334,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field. Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats. See [the documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
### Config ### Config

View File

@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi fi
# So we can test the Docker image # So we can test the Docker image

View File

@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi fi
# So we can test the Docker image # So we can test the Docker image

View File

@@ -3,7 +3,7 @@ ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04" ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4 ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}" ENV PATH="/root/miniconda3/bin:${PATH}"

View File

@@ -0,0 +1,62 @@
# Example config: QLoRA adapter training of Llama 3.1 405B with the base model
# loaded in 4-bit and sharded across devices via FSDP (full shard + auto wrap).
base_model: meta-llama/Meta-Llama-3.1-405B
tokenizer_type: AutoTokenizer

load_in_4bit: true
strict: false

# Each dataset entry is one mapping: `type` belongs to the same list item as `path`.
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b

adapter: qlora

sequence_len: 1024
sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
# Intentionally left empty (parses as null); `lora_target_linear` is set instead.
lora_target_modules:
lora_target_linear: true

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

train_on_inputs: false
group_by_length: false
bf16: true
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
logging_steps: 1
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0

# `fsdp` is a list of strategy flags; `fsdp_config` is a mapping of its options.
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD

special_tokens:
  # Quoted so the angle brackets and pipes are always read as a plain string.
  pad_token: "<|finetune_right_pad_id|>"

View File

@@ -1,18 +1,18 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.11.1 peft==0.11.1
transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf transformers==4.43.3
tokenizers==0.19.1 tokenizers==0.19.1
bitsandbytes==0.43.1 bitsandbytes==0.43.1
accelerate==0.32.0 accelerate==0.32.0
deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b deepspeed==0.14.4
pydantic==2.6.3 pydantic==2.6.3
addict addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
requests requests
datasets==2.19.1 datasets==2.19.1
flash-attn==2.6.1 flash-attn==2.6.2
sentencepiece sentencepiece
wandb wandb
einops einops
@@ -32,6 +32,7 @@ fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e59
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
python-dotenv==1.0.1 python-dotenv==1.0.1
autoawq>=0.2.5
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1

View File

@@ -80,13 +80,13 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn==2.6.1", "flash-attn==2.6.2",
], ],
"fused-dense-lib": [ "fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib", "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b", "deepspeed==0.14.4",
"deepspeed-kernels", "deepspeed-kernels",
], ],
"mamba-ssm": [ "mamba-ssm": [

View File

@@ -2,6 +2,7 @@
CLI to run training on a model CLI to run training on a model
""" """
import logging import logging
import warnings
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -76,8 +77,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
if parsed_cli_args.download: if parsed_cli_args.download:
model_name = parsed_cfg.base_model model_name = parsed_cfg.base_model
with init_empty_weights(): with warnings.catch_warnings():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) # there are a bunch of useless UserWarnings about
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
warnings.simplefilter("ignore")
with init_empty_weights(include_buffers=True):
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
LOG.info( LOG.info(
Fore.GREEN Fore.GREEN

View File

@@ -0,0 +1,14 @@
"""
Common architecture specific constants
"""
MOE_ARCH_BLOCK = {
"dbrx": "DbrxFFN",
"jamba": "JambaSparseMoeBlock",
"jetmoe": [
"JetMoeMoA",
"JetMoeMoE",
],
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
}

View File

@@ -8,6 +8,7 @@ import importlib
import importlib.util import importlib.util
import logging import logging
import math import math
import os
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from collections import defaultdict from collections import defaultdict
@@ -28,9 +29,18 @@ from transformers import (
TrainerCallback, TrainerCallback,
TrainingArguments, TrainingArguments,
) )
from transformers.trainer_utils import seed_worker from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer from trl import (
CPOConfig,
CPOTrainer,
DPOConfig,
DPOTrainer,
KTOConfig,
KTOTrainer,
ORPOConfig,
ORPOTrainer,
)
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer from axolotl.loraplus import create_loraplus_optimizer
@@ -265,7 +275,89 @@ class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
""" """
class AxolotlTrainer(Trainer): @dataclass
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
"""
CPO config for CPO training
"""
simpo_gamma: Optional[float] = field(
default=None,
metadata={"help": "simpo gamma parameter"},
)
class SchedulerMixin(Trainer):
    """
    Mixin class for scheduler setup in CausalTrainer.

    Overrides ``create_scheduler`` so that, when ``lr_scheduler_type`` is ``cosine``,
    one of ``get_cosine_schedule_with_quadratic_warmup``,
    ``get_cosine_schedule_with_warmup_decay_constant`` or
    ``get_cosine_schedule_with_min_lr`` can be selected from the training args.
    Every other configuration is delegated to the parent class.
    """

    args = None  # type: AxolotlTrainingArguments

    @staticmethod
    def _require_unit_interval(name: str, value: float) -> None:
        """
        Raise ``ValueError`` unless ``0 <= value <= 1.0``.

        An explicit raise is used instead of ``assert`` because assertions are
        removed when Python runs with ``-O``, which would let bad ratios through.
        """
        if not 0 <= value <= 1.0:
            raise ValueError(f"{name} must be between 0.0 and 1.0")

    def create_scheduler(
        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
    ):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (torch.optim.Optimizer): The training optimizer

        Returns:
            The parent's ``create_scheduler`` result when none of the axolotl cosine
            variants applies, otherwise ``self.lr_scheduler``.

        Raises:
            ValueError: if ``cosine_min_lr_ratio`` or ``cosine_constant_lr_ratio`` is
                used and lies outside ``[0.0, 1.0]``.
        """
        use_cosine_quadratic = (
            self.args.lr_scheduler_type == "cosine"
            and self.args.lr_quadratic_warmup is True
        )

        use_cosine_min_lr = (
            self.args.lr_scheduler_type == "cosine"
            and self.args.cosine_min_lr_ratio is not None
        )

        # fmt: off
        if self.lr_scheduler is None:  # type: ignore  # pylint: disable=access-member-before-definition
            # fmt: on
            if use_cosine_quadratic:
                # Quadratic warmup takes precedence when both options are configured.
                if use_cosine_min_lr:
                    LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                )
            elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
                self._require_unit_interval("cosine_min_lr_ratio", self.args.cosine_min_lr_ratio)
                self._require_unit_interval("cosine_constant_lr_ratio", self.args.cosine_constant_lr_ratio)
                self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                    constant_lr_ratio=self.args.cosine_constant_lr_ratio,
                )
            elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
                self._require_unit_interval("cosine_min_lr_ratio", self.args.cosine_min_lr_ratio)
                self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                )
            else:
                # Also reached for a falsy (e.g. 0.0) cosine_min_lr_ratio.
                return super().create_scheduler(num_training_steps, optimizer)
        else:
            # A scheduler already exists, so the axolotl variants are not applied.
            if use_cosine_quadratic:
                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")

            if use_cosine_min_lr:
                LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

        return self.lr_scheduler
class AxolotlTrainer(SchedulerMixin, Trainer):
""" """
Extend the base Trainer for axolotl helpers Extend the base Trainer for axolotl helpers
""" """
@@ -383,68 +475,6 @@ class AxolotlTrainer(Trainer):
return self.optimizer return self.optimizer
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining: if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches: if self.args.multipack_real_batches:
@@ -809,6 +839,14 @@ class AxolotlTrainer(Trainer):
for key, value in metrics.items(): for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value) self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, metrics=None):
    """
    Create the step's checkpoint directory up front, then defer to the parent.

    The directory is ``<run output dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``;
    it is created with ``exist_ok=True`` so an existing directory is not an error.
    """
    # make sure the checkpoint dir exists, since trainer is flakey
    step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
    target_dir = os.path.join(self._get_output_dir(trial=trial), step_dir_name)
    os.makedirs(target_dir, exist_ok=True)

    return super()._save_checkpoint(model, trial, metrics=metrics)
class AxolotlMambaTrainer(AxolotlTrainer): class AxolotlMambaTrainer(AxolotlTrainer):
""" """
@@ -908,7 +946,7 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler return self.lr_scheduler
class AxolotlDPOTrainer(DPOTrainer): class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
""" """
Extend the base DPOTrainer for axolotl helpers Extend the base DPOTrainer for axolotl helpers
""" """
@@ -969,7 +1007,7 @@ class AxolotlDPOTrainer(DPOTrainer):
return res return res
class AxolotlORPOTrainer(ORPOTrainer): class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
""" """
Extend the base ORPOTrainer for axolotl helpers Extend the base ORPOTrainer for axolotl helpers
""" """
@@ -977,7 +1015,7 @@ class AxolotlORPOTrainer(ORPOTrainer):
tag_names = ["axolotl", "orpo"] tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(KTOTrainer): class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
""" """
Extend the base KTOTrainer for axolotl helpers Extend the base KTOTrainer for axolotl helpers
""" """
@@ -985,6 +1023,14 @@ class AxolotlKTOTrainer(KTOTrainer):
tag_names = ["axolotl", "kto"] tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
    """
    Extend the base CPOTrainer for axolotl helpers

    SchedulerMixin is listed first so its ``create_scheduler`` override is used.
    HFRLTrainerBuilder selects this trainer when ``cfg.rl`` is ``"simpo"``.
    """

    # NOTE(review): presumably consumed the same way as the sibling trainers'
    # tag_names (e.g. model-card / hub tags) -- confirm where it is read.
    tag_names = ["axolotl", "cpo"]
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """
Base class for trainer builder Base class for trainer builder
@@ -1707,6 +1753,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined # default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch" training_args_kwargs["save_strategy"] = "epoch"
if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta
if self.cfg.orpo_alpha: if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ??? # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha training_args_kwargs["beta"] = self.cfg.orpo_alpha
@@ -1715,9 +1763,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_cls = AxolotlDPOConfig training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None: if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
if self.cfg.rl == "orpo": if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig training_args_cls = AxolotlORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len: if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
@@ -1725,7 +1781,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl == "kto": if self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig training_args_cls = AxolotlKTOConfig
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
training_args_kwargs["desirable_weight"] = ( training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0 self.cfg.kto_desirable_weight or 1.0
) )
@@ -1771,7 +1826,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
] = self.cfg.precompute_ref_log_probs ] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]: if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
trainer_cls_args = [self.model, self.model_ref] trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer # these aren't used for the ORPO trainer
@@ -1785,6 +1839,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl in ["kto"]: elif self.cfg.rl in ["kto"]:
trainer_cls = AxolotlKTOTrainer trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
elif self.cfg.rl in ["simpo"]:
trainer_cls = AxolotlCPOTrainer
trainer_cls_args = [self.model]
else: else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}") raise ValueError(f"Unsupported RL: {self.cfg.rl}")
dpo_trainer = trainer_cls( dpo_trainer = trainer_cls(

View File

@@ -6,14 +6,16 @@ import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import chat_templates from axolotl.utils.chat_templates import chat_templates
# Logging handlers and levels are left to the application entry point. This
# module must not call logging.basicConfig(level=logging.DEBUG): doing so at
# import time reconfigures the root logger for every program that imports it.
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
class ChatTemplatePrompter(Prompter): class ChatTemplatePrompter(Prompter):
"""prompter for HF chat templates""" """Prompter for HF chat templates"""
def __init__( def __init__(
self, self,
@@ -22,6 +24,8 @@ class ChatTemplatePrompter(Prompter):
max_length=2048, max_length=2048,
message_field_role: str = "from", message_field_role: str = "from",
message_field_content: str = "value", message_field_content: str = "value",
message_field_training: str = "train",
message_field_training_detail: str = "train_detail",
roles: Optional[Dict[str, List[str]]] = None, roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False, drop_system_message: bool = False,
): ):
@@ -37,6 +41,8 @@ class ChatTemplatePrompter(Prompter):
} }
self.message_field_role = message_field_role self.message_field_role = message_field_role
self.message_field_content = message_field_content self.message_field_content = message_field_content
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.chat_template = chat_template self.chat_template = chat_template
self.max_length = max_length self.max_length = max_length
@@ -47,6 +53,7 @@ class ChatTemplatePrompter(Prompter):
{ {
"role": self.roles[t[self.message_field_role]], "role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content], "content": t[self.message_field_content],
"training": t.get(self.message_field_training, None),
} }
for t in conversation for t in conversation
] ]
@@ -62,6 +69,108 @@ class ChatTemplatePrompter(Prompter):
chat_template=self.chat_template, chat_template=self.chat_template,
) )
def get_offsets_for_train_detail(
    self, text: str, train_details: List[Dict], mask_untrainable: bool = True
) -> List[int]:
    """
    Build a per-token list marking which tokens of ``text`` fall in trainable spans.

    Args:
        text: the string to tokenize (tokenized without special tokens).
        train_details: dicts with ``begin_offset``, ``end_offset`` and ``train``
            keys describing character spans of ``text``.
        mask_untrainable: when True, only spans whose ``train`` value is truthy
            mark tokens; when False, every span marks the tokens it contains.

    Returns:
        One entry per token: the token's start character offset when the token
        lies completely inside a qualifying span, otherwise ``IGNORE_TOKEN_ID``.
        An empty list when ``text`` produces no tokens.
    """
    tokenized_output = self.tokenizer(
        text, return_offsets_mapping=True, add_special_tokens=False
    )
    tokens = tokenized_output.tokens()
    token_offsets = tokenized_output["offset_mapping"]

    # Lazy %-style arguments: nothing is formatted unless DEBUG is enabled.
    LOG.debug("Tokenizing text: %s", text)
    LOG.debug("Tokens: %s", tokens)

    # Nothing to mark; also avoids indexing token_offsets[-1] on an empty list.
    if len(token_offsets) == 0:
        LOG.debug("Final result: %s", [])
        return []

    # Rewrite every end offset as an inclusive end: one character before the
    # next token's start, so characters between two tokens go to the earlier one.
    for i in range(len(token_offsets) - 1):
        token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1)
    # Ensure the last token's end offset is set correctly
    token_offsets[-1] = (token_offsets[-1][0], len(text) - 1)

    LOG.debug("Token offsets: %s", token_offsets)

    # Initialize all offsets as IGNORE_TOKEN_ID (not trained)
    result = [IGNORE_TOKEN_ID] * len(token_offsets)

    # Adjust train_details to align with token boundaries
    adjusted_train_details = self.adjust_train_details(train_details, token_offsets)

    for idx, (start, end) in enumerate(token_offsets):
        for detail in adjusted_train_details:
            # Check if the token is completely within the detail's range
            if start >= detail["begin_offset"] and end <= detail["end_offset"]:
                if detail["train"] or not mask_untrainable:
                    result[idx] = start
                    LOG.debug("Token %s (%s) marked for training", idx, tokens[idx])
                else:
                    LOG.debug(
                        "Token %s (%s) marked as non-trainable", idx, tokens[idx]
                    )
            elif start < detail["end_offset"] and end > detail["begin_offset"]:
                # Token partially overlaps with detail, always mark as non-trainable
                LOG.debug(
                    "Token %s (%s) partially overlaps detail, marked as non-trainable",
                    idx,
                    tokens[idx],
                )

    LOG.debug("Final result: %s", result)
    return result
def adjust_train_details(
    self, train_details: List[Dict], token_offsets: List[tuple]
) -> List[Dict]:
    """
    Snap each detail's character range onto token boundaries.

    Args:
        train_details: Dicts with ``begin_offset``, ``end_offset`` and ``train``.
        token_offsets: ``(start, end)`` character offsets, one pair per token.

    Returns:
        New dicts with the same three keys whose offsets coincide with the
        start of one token and the end of another. Details for which no
        token span can be found are dropped with a warning.
    """
    token_count = len(token_offsets)
    snapped = []

    for detail in train_details:
        begin_offset = detail["begin_offset"]
        end_offset = detail["end_offset"]

        # First token starting at or after begin_offset (token_count if none)
        begin_token = token_count
        for position, (t_start, _t_end) in enumerate(token_offsets):
            if t_start >= begin_offset:
                begin_token = position
                break
        # Step back one token when the previous one reaches past begin_offset
        if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset:
            begin_token -= 1

        # Last token ending at or before end_offset (-1 if none)
        end_token = -1
        for position in reversed(range(token_count)):
            if token_offsets[position][1] <= end_offset:
                end_token = position
                break
        # Step forward one token when the next one starts before end_offset
        if (
            end_token < token_count - 1
            and token_offsets[end_token + 1][0] < end_offset
        ):
            end_token += 1

        if begin_token > end_token:
            LOG.warning(
                f"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail."
            )
            continue

        adjusted_begin = token_offsets[begin_token][0]
        adjusted_end = token_offsets[end_token][1]

        if adjusted_begin != begin_offset or adjusted_end != end_offset:
            LOG.warning(
                f"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})"
            )

        snapped.append(
            {
                "begin_offset": adjusted_begin,
                "end_offset": adjusted_end,
                "train": detail["train"],
            }
        )

    return snapped
class ChatTemplateStrategy(PromptTokenizingStrategy): class ChatTemplateStrategy(PromptTokenizingStrategy):
""" """
@@ -70,6 +179,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
_messages = "conversations" _messages = "conversations"
def __init__(
    self,
    prompter,
    tokenizer,
    train_on_inputs,
    sequence_len,
    roles_to_train=None,
    train_on_eos="last",
):
    """
    Set up the strategy and remember its label-selection options.

    Args:
        prompter, tokenizer, train_on_inputs, sequence_len: Passed through
            unchanged to the parent initializer.
        roles_to_train: Roles whose turns are eligible for training; ``None``
            is stored as an empty list.
        train_on_eos: EOS labelling mode; ``tokenize_prompt`` compares it
            against "all", "turn" and "last".
    """
    super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
    if roles_to_train is None:
        roles_to_train = []
    self.roles_to_train = roles_to_train
    self.train_on_eos = train_on_eos
@property @property
def messages(self): def messages(self):
return self._messages return self._messages
@@ -79,65 +201,172 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self._messages = messages self._messages = messages
def tokenize_prompt(self, prompt):
    """
    Tokenize a whole conversation and decide, turn by turn, which tokens are trained.

    For each turn the decision is, in order: the turn's explicit training flag
    if it is set; otherwise True if the turn carries a training detail;
    otherwise ``train_on_inputs`` or membership of the role in ``roles_to_train``.
    A turn with a training detail is labelled token by token, any other trained
    turn is labelled as a whole.

    EOS handling follows ``train_on_eos``: "all" labels the EOS right after
    every located turn, "turn" only after trained turns, and "last" only the
    final such EOS. Any other value labels no EOS token.

    Args:
        prompt: Mapping that holds the list of turns under ``self.messages``.

    Returns:
        Dict with ``input_ids``, ``labels`` (``IGNORE_TOKEN_ID`` where not
        trained) and an all-ones ``attention_mask``.
    """
    turns = prompt[self.messages]
    input_ids = self.prompter.build_prompt(turns)
    labels = [IGNORE_TOKEN_ID] * len(input_ids)

    last_eos_idx = -1
    for index, turn in enumerate(turns):
        role = turn.get(self.prompter.message_field_role)
        content = turn.get(self.prompter.message_field_content)
        train_turn = turn.get(self.prompter.message_field_training)
        train_detail = turn.get(self.prompter.message_field_training_detail)

        LOG.debug(
            f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
        )

        if train_turn is not None:
            should_train = train_turn
        elif train_detail is not None:
            should_train = True
        else:
            should_train = self.train_on_inputs or role in self.roles_to_train

        LOG.debug(f"Should train: {should_train}")

        turn_start_idx, turn_end_idx = self.find_turn(
            conversation_ids=input_ids, turn=index, turn_content=turn
        )

        LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")

        if turn_start_idx == -1 or turn_end_idx == -1:
            # The turn could not be located. Searching for an EOS from index -1
            # can return -1, which would equal turn_end_idx: that clobbers
            # last_eos_idx and writes a label onto the final token.
            LOG.debug(f"Turn {index} not found in input_ids, skipping")
            continue

        if should_train:
            if train_detail:
                token_offsets = self.prompter.get_offsets_for_train_detail(
                    content, train_detail
                )
                LOG.debug(f"Token offsets: {token_offsets}")
                for i, offset in enumerate(token_offsets):
                    if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len(
                        input_ids
                    ):
                        labels[turn_start_idx + i] = input_ids[turn_start_idx + i]
                        LOG.debug(
                            f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
                        )
            else:
                labels[turn_start_idx:turn_end_idx] = input_ids[
                    turn_start_idx:turn_end_idx
                ]
                LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")

            LOG.debug(f"Labels after processing turn {index}: {labels}")

        # Handle EOS token: it only counts when it sits directly at the turn's end
        eos_idx = self.find_eos_token(input_ids, turn_end_idx)
        if eos_idx == turn_end_idx:
            last_eos_idx = eos_idx
            if self.train_on_eos == "all" or (
                self.train_on_eos == "turn" and should_train
            ):
                labels[eos_idx] = input_ids[eos_idx]
                LOG.debug(f"EOS token set for training at index {eos_idx}")
        else:
            LOG.debug(
                f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
            )

    # Handle 'last' option for train_on_eos
    if self.train_on_eos == "last" and last_eos_idx != -1:
        labels[last_eos_idx] = input_ids[last_eos_idx]
        LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")

    LOG.debug(f"Final labels: {labels}")

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": [1] * len(input_ids),
    }
def find_eos_token(self, input_ids, start_idx):
    """
    Return the first index in ``range(start_idx, len(input_ids))`` whose token
    equals the tokenizer's EOS id, or -1 when no such index exists.
    """
    eos_token_id = self.tokenizer.eos_token_id
    return next(
        (
            position
            for position in range(start_idx, len(input_ids))
            if input_ids[position] == eos_token_id
        ),
        -1,
    )
def find_turn(self, conversation_ids, turn, turn_content):
    """
    Locate the token span of one turn's content inside the tokenized conversation.

    Args:
        conversation_ids (list[int]): Token IDs of the whole conversation.
        turn (int): Number of EOS tokens to skip before searching; the search
            starts right after the ``turn``-th EOS, or at 0 if that EOS is
            never reached.
        turn_content: Mapping for the turn; its content is read via
            ``self.prompter.message_field_content`` (empty string if absent).

    Returns:
        tuple: ``(start_idx, end_idx)`` with ``end_idx`` exclusive, or
        ``(-1, -1)`` if the encoded content does not occur from the search
        start onwards.
    """
    content = turn_content.get(self.prompter.message_field_content, "")
    content_ids = self.tokenizer.encode(content, add_special_tokens=False)
    eos_token_id = self.tokenizer.eos_token_id

    # Skip past the requested number of EOS tokens
    search_from = 0
    seen_eos = 0
    for position, token_id in enumerate(conversation_ids):
        if token_id != eos_token_id:
            continue
        seen_eos += 1
        if seen_eos == turn:
            search_from = position + 1
            break

    # Slide a window of the content's length over the remaining tokens
    width = len(content_ids)
    for candidate in range(search_from, len(conversation_ids) - width + 1):
        if conversation_ids[candidate : candidate + width] == content_ids:
            return candidate, candidate + width

    return -1, -1
def get_conversation_thread(self, prompt):
    """Return the value stored in ``prompt`` under the configured messages field."""
    field = self.messages
    return prompt[field]
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build a ``ChatTemplateStrategy`` from the global config and a dataset config.

    Args:
        tokenizer: Tokenizer shared by the prompter and the strategy.
        cfg: Global config; ``train_on_inputs`` and ``sequence_len`` are read.
        ds_cfg: Optional per-dataset options. Missing keys fall back to the
            defaults written below (e.g. ``chat_template="chatml"``,
            ``message_field_role="from"``, ``message_field_content="value"``,
            ``train_on_eos="last"``).

    Returns:
        The configured strategy. If ``ds_cfg`` has ``field_messages``, it
        overrides the strategy's messages field.
    """
    ds_cfg = ds_cfg or {}

    chat_template = ds_cfg.get("chat_template", "chatml")
    chat_template_str = chat_templates(chat_template, tokenizer=tokenizer)
    LOG.info(f"Using chat template:\n---\n{chat_template_str!s}\n---")

    prompter_params = {
        "tokenizer": tokenizer,
        # Reuse the template resolved (with the tokenizer) and logged above, so
        # the prompter gets exactly that template instead of a second lookup
        # made without the tokenizer.
        "chat_template": chat_template_str,
        "message_field_role": ds_cfg.get("message_field_role", "from"),
        "message_field_content": ds_cfg.get("message_field_content", "value"),
        "message_field_training": ds_cfg.get("message_field_training", "training"),
        "message_field_training_detail": ds_cfg.get(
            "message_field_training_detail", "train_detail"
        ),
        "roles": ds_cfg.get("roles"),
        "drop_system_message": ds_cfg.get("drop_system_message", False),
    }

    strategy_params = {
        "train_on_inputs": cfg.train_on_inputs,
        "sequence_len": cfg.sequence_len,
        "roles_to_train": ds_cfg.get("roles_to_train"),
        "train_on_eos": ds_cfg.get("train_on_eos", "last"),
    }

    strategy = ChatTemplateStrategy(
        ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
    )

    if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
        strategy.messages = ds_cfg["field_messages"]

    return strategy

View File

@@ -62,7 +62,7 @@ def default(
tokenize=False, tokenize=False,
) )
chosen_strip_index = result["chosen"].find(chosen["content"]) chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:] result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template( result["rejected"] = tokenizer.apply_chat_template(
[rejected], [rejected],
@@ -71,7 +71,7 @@ def default(
tokenize=False, tokenize=False,
) )
rejected_strip_index = result["rejected"].find(rejected["content"]) rejected_strip_index = result["rejected"].find(rejected["content"])
result["rejected"] = result["rejected"][rejected_strip_index:] result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
return result return result

View File

@@ -212,26 +212,23 @@ def train(
elif cfg.deepspeed and is_deepspeed_zero3_enabled(): elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone() trainer.accelerator.wait_for_everyone()
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped) trainer.save_model(cfg.output_dir)
# the trainer saved a model.safetensors file in the output directory, # the trainer saved a model.safetensors file in the output directory,
# but it is a proxy model and should be deleted # but it is most likely a proxy model and if so, should be deleted
if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")): maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors"))
maybe_sharded = os.path.exists(
os.path.join(cfg.output_dir, "model.safetensors.index.json")
)
if maybe_proxy and maybe_sharded:
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}") LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
LOG.info("This is a proxy model and should be deleted") LOG.info("This is a proxy model and should be deleted")
os.remove(os.path.join(cfg.output_dir, "model.safetensors")) try:
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer: if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)

View File

@@ -123,6 +123,10 @@ class SFTDataset(BaseModel):
field_messages: Optional[str] = None field_messages: Optional[str] = None
message_field_role: Optional[str] = None message_field_role: Optional[str] = None
message_field_content: Optional[str] = None message_field_content: Optional[str] = None
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None drop_system_message: Optional[bool] = None
@@ -179,6 +183,7 @@ class RLType(str, Enum):
ipo = "ipo" # pylint: disable=invalid-name ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum): class ChatTemplate(str, Enum):
@@ -653,6 +658,8 @@ class AxolotlInputConfig(
orpo_alpha: Optional[float] = None orpo_alpha: Optional[float] = None
rpo_alpha: Optional[float] = None rpo_alpha: Optional[float] = None
simpo_gamma: Optional[float] = None
cpo_alpha: Optional[float] = None
kto_desirable_weight: Optional[float] = None kto_desirable_weight: Optional[float] = None
kto_undesirable_weight: Optional[float] = None kto_undesirable_weight: Optional[float] = None

View File

@@ -42,7 +42,7 @@ from axolotl.prompters import (
from axolotl.utils.data.pretraining import wrap_pretraining_dataset from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import md5 from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
calculate_total_num_steps, calculate_total_num_steps,
process_datasets_for_packing, process_datasets_for_packing,
@@ -54,7 +54,7 @@ LOG = logging.getLogger("axolotl")
def prepare_dataset(cfg, tokenizer): def prepare_dataset(cfg, tokenizer):
prompters = [] prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
with zero_first(is_main_process()): with zero_first(is_local_main_process()):
if cfg.test_datasets: if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets( train_dataset, _, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train" tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
@@ -170,6 +170,7 @@ def load_tokenized_prepared_datasets(
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
if dataset: if dataset:
# This is for the case where we already loaded a pretokenized dataset from the hub
... ...
elif ( elif (
cfg.dataset_prepared_path cfg.dataset_prepared_path
@@ -198,6 +199,8 @@ def load_tokenized_prepared_datasets(
def for_d_in_datasets(dataset_configs): def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs: for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list): if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name: for name in dataset.name:
yield DictDefault({**dataset, "name": name}) yield DictDefault({**dataset, "name": name})
else: else:
@@ -208,6 +211,8 @@ def load_tokenized_prepared_datasets(
ds: Optional[Union[Dataset, DatasetDict]] = None ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False ds_from_hub = False
try: try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset( load_dataset(
config_dataset.path, config_dataset.path,
name=config_dataset.name, name=config_dataset.name,

View File

@@ -44,6 +44,10 @@ def is_main_process():
return dist.get_rank() == 0 return dist.get_rank() == 0
def is_local_main_process():
    """
    Return whether this process is the main process on its own node.

    Uses ``PartialState().is_local_main_process`` rather than
    ``is_main_process``, so the result is per-node as the function name says
    and is not limited to global rank 0.
    """
    return PartialState().is_local_main_process
def get_world_size():
    """Return the ``WORLD_SIZE`` environment variable as an int, or 1 when it is unset."""
    return int(os.environ.get("WORLD_SIZE", "1"))

View File

@@ -29,6 +29,7 @@ from transformers import ( # noqa: F401
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
AwqConfig,
BitsAndBytesConfig, BitsAndBytesConfig,
GPTQConfig, GPTQConfig,
PreTrainedModel, PreTrainedModel,
@@ -36,6 +37,7 @@ from transformers import ( # noqa: F401
) )
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import ( from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES, SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -510,7 +512,25 @@ def load_model(
model_kwargs["quantization_config"] = GPTQConfig( model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config **model_config.quantization_config
) )
if cfg.adapter == "qlora" and cfg.load_in_4bit: if (
cfg.adapter in ["qlora", "lora"]
and hasattr(model_config, "quantization_config")
and model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
):
if model_config.quantization_config["quant_method"] == "gptq":
model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config
)
elif model_config.quantization_config["quant_method"] == "awq":
model_kwargs["quantization_config"] = AwqConfig(
**model_config.quantization_config
)
elif model_config.quantization_config["quant_method"] == "bitsandbytes":
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**model_config.quantization_config
)
elif cfg.adapter == "qlora" and cfg.load_in_4bit:
bnb_config = { bnb_config = {
"load_in_4bit": True, "load_in_4bit": True,
"llm_int8_threshold": 6.0, "llm_int8_threshold": 6.0,
@@ -619,7 +639,7 @@ def load_model(
and not cfg.trust_remote_code and not cfg.trust_remote_code
and not cfg.gptq and not cfg.gptq
): ):
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True skip_move_to_device = True
if "device_map" in model_kwargs: if "device_map" in model_kwargs:
del model_kwargs["device_map"] del model_kwargs["device_map"]
@@ -701,7 +721,7 @@ def load_model(
**model_kwargs, **model_kwargs,
) )
else: else:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
# disabling either of these two still leads to VRAM spike before setting back down # disabling either of these two still leads to VRAM spike before setting back down
skip_move_to_device = True skip_move_to_device = True
if "device_map" in model_kwargs: if "device_map" in model_kwargs:
@@ -785,12 +805,14 @@ def load_model(
set_z3_leaf_modules, set_z3_leaf_modules,
) )
if cfg.model_config_type == "mixtral": if cfg.model_config_type in MOE_ARCH_BLOCK:
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock") set_z3_leaf_modules(
set_z3_leaf_modules(model, [moe_block]) model,
elif cfg.model_config_type == "dbrx": [
moe_block = get_module_class_from_name(model, "DbrxFFN") get_module_class_from_name(model, module_name)
set_z3_leaf_modules(model, [moe_block]) for module_name in MOE_ARCH_BLOCK[cfg.model_config_type]
],
)
if cfg.model_config_type == "qwen" and cfg.adapter == "lora": if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled # Qwen doesn't play nicely with LoRA if this is enabled
@@ -804,6 +826,9 @@ def load_model(
# make sure everything is in the same dtype # make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True skip_prepare_model_for_kbit_training = True
if is_deepspeed_zero3_enabled():
skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]: if cfg.adapter in ["lora", "qlora"]:
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable( model.gradient_checkpointing_enable(
@@ -838,6 +863,9 @@ def load_model(
else: else:
model, lora_config = load_adapter(model, cfg, cfg.adapter) model, lora_config = load_adapter(model, cfg, cfg.adapter)
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
if ( if (
cfg.ddp cfg.ddp
and not load_in_8bit and not load_in_8bit

View File

@@ -1,4 +1,5 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import json
import math import math
import os import os
import random import random
@@ -389,6 +390,19 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
return total_num_steps return total_num_steps
def setup_deepspeed_env(cfg, stage=None):
    """
    Export the ``ACCELERATE_*`` environment variables for a DeepSpeed run.

    Args:
        cfg: Config object; ``deepspeed`` (config file path), ``bf16`` and
            ``fp16`` are read.
        stage: Optional ZeRO stage. When truthy it is exported as a string,
            and stage 3 additionally turns on the ZeRO-3 init flag.
    """
    env = os.environ
    env["ACCELERATE_USE_DEEPSPEED"] = "true"
    env["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed

    # bf16 takes precedence over fp16; neither set leaves the variable untouched
    precision = "bf16" if cfg.bf16 else ("fp16" if cfg.fp16 else None)
    if precision is not None:
        env["ACCELERATE_MIXED_PRECISION"] = precision

    if not stage:
        return
    env["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
    if stage == 3:
        env["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
def setup_fsdp_envs(cfg): def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true" os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_activation_checkpointing: if cfg.fsdp_config.fsdp_activation_checkpointing:
@@ -415,8 +429,14 @@ def prepare_optim_env(cfg):
if cfg.fsdp: if cfg.fsdp:
setup_fsdp_envs(cfg) setup_fsdp_envs(cfg)
elif cfg.deepspeed: elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" stage = None
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed # check if the cfg.deepspeed is a file
if os.path.isfile(cfg.deepspeed):
# parse with json
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
deepspeed_config = json.load(fin)
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage)
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True: if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16" os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
@@ -425,7 +445,7 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "orpo", "kto"]: if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1] trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2] trainer_builder.peft_config = model[2]

20
tests/e2e/test_imports.py Normal file
View File

@@ -0,0 +1,20 @@
"""
test module to import various submodules that have historically broken due to dependency issues
"""
import unittest
class TestImports(unittest.TestCase):
    """
    Smoke tests that import submodules which have historically broken due to
    dependency issues; each test passes as long as the import does not raise.
    """

    def test_import_causal_trainer(self):
        """Importing the causal trainer builder must succeed."""
        from axolotl.core.trainer_builder import (  # pylint: disable=unused-import # noqa: F401
            HFCausalTrainerBuilder,
        )

    def test_import_rl_trainer(self):
        """Importing the RL trainer builder must succeed."""
        from axolotl.core.trainer_builder import (  # pylint: disable=unused-import # noqa: F401
            HFRLTrainerBuilder,
        )

View File

@@ -2,6 +2,7 @@
tests for chat_template prompt strategy tests for chat_template prompt strategy
""" """
import logging
import unittest import unittest
import pytest import pytest
@@ -13,33 +14,24 @@ from axolotl.prompt_strategies.chat_template import (
ChatTemplateStrategy, ChatTemplateStrategy,
load, load,
) )
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import chat_templates from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
@pytest.fixture(name="assistant_dataset") @pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset(): def fixture_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list( return Dataset.from_list(
[ [
{ {
"messages": [ "messages": [
{ {"role": "user", "content": "hello"},
"role": "user", {"role": "assistant", "content": "hello"},
"content": "hello", {"role": "user", "content": "goodbye"},
}, {"role": "assistant", "content": "goodbye"},
{
"role": "assistant",
"content": "hello",
},
{
"role": "user",
"content": "goodbye",
},
{
"role": "assistant",
"content": "goodbye",
},
] ]
} }
] ]
@@ -53,22 +45,28 @@ def fixture_sharegpt_dataset():
[ [
{ {
"conversations": [ "conversations": [
{ {"from": "human", "value": "hello"},
"from": "human", {"from": "gpt", "value": "hello"},
"value": "hello", {"from": "human", "value": "goodbye"},
}, {"from": "gpt", "value": "goodbye"},
{ ]
"from": "gpt", }
"value": "hello", ]
}, )
{
"from": "human",
"value": "goodbye", @pytest.fixture(name="basic_dataset")
}, def fixture_basic_dataset():
{ # pylint: disable=duplicate-code
"from": "gpt", return Dataset.from_list(
"value": "goodbye", [
}, {
"conversations": [
{"from": "system", "value": "You are an AI assistant."},
{"from": "human", "value": "Hello"},
{"from": "assistant", "value": "Hi there!"},
{"from": "human", "value": "How are you?"},
{"from": "assistant", "value": "I'm doing well, thank you!"},
] ]
} }
] ]
@@ -77,8 +75,7 @@ def fixture_sharegpt_dataset():
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
    """Provide the tokenizer loaded from ``NousResearch/Meta-Llama-3-8B-Instruct``."""
    return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
@@ -130,13 +127,607 @@ class TestChatTemplates:
assert chat_template_str == "test_template" assert chat_template_str == "test_template"
class TestChatTemplateConfigurations:
"""
Test class for various configurations of ChatTemplateStrategy.
"""
@staticmethod
def find_sublist(full_list, sub_list):
token_count = len(sub_list)
for index in range(len(full_list) - token_count + 1):
if full_list[index : index + token_count] == sub_list:
return index
return -1
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
    """
    With ``train_on_inputs=True`` the assistant responses must carry labels.

    Human inputs are only located and logged here; nothing is asserted about
    them.
    """
    LOG.info("Testing with train_on_inputs=True")
    strategy = ChatTemplateStrategy(
        ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
        tokenizer=llama3_tokenizer,
        train_on_inputs=True,
        sequence_len=512,
        roles_to_train=["assistant"],
    )
    res = strategy.tokenize_prompt(basic_dataset[0])
    labels = res["labels"]
    input_ids = res["input_ids"]

    # Verify that assistant responses are labeled
    assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
    for response in assistant_responses:
        response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
        start_idx = self.find_sublist(input_ids, response_ids)
        LOG.debug(
            f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
        )
        assert start_idx != -1, f"Could not find '{response}' in input_ids"
        assert all(
            label != IGNORE_TOKEN_ID
            for label in labels[start_idx : start_idx + len(response_ids)]
        ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"

    # Check the behavior of human inputs
    human_inputs = ["Hello", "How are you?"]
    for input_text in human_inputs:
        # Separate name: rebinding input_ids here would make the search below
        # look for the list inside itself (always index 0).
        human_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
        start_idx = self.find_sublist(input_ids, human_ids)
        labeled = all(
            label != IGNORE_TOKEN_ID
            for label in labels[start_idx : start_idx + len(human_ids)]
        )
        LOG.debug(
            f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {human_ids}, found at: {start_idx}"
        )

    LOG.debug("Full labels: %s", labels)
    LOG.debug("Full input_ids: %s", input_ids)
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
    """
    With ``train_on_inputs=False`` and ``roles_to_train=["assistant"]`` the
    assistant responses must be labelled and the human inputs must stay masked.
    """
    LOG.info("Testing with train_on_inputs=False")
    strategy = ChatTemplateStrategy(
        ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=["assistant"],
    )
    res = strategy.tokenize_prompt(basic_dataset[0])
    labels = res["labels"]
    input_ids = res["input_ids"]

    # Verify that only assistant responses are labeled
    assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
    for response in assistant_responses:
        response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
        start_idx = self.find_sublist(input_ids, response_ids)
        LOG.debug(
            f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
        )
        assert start_idx != -1, f"Could not find '{response}' in input_ids"
        assert all(
            label != IGNORE_TOKEN_ID
            for label in labels[start_idx : start_idx + len(response_ids)]
        ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"

    # Verify that human inputs are not labeled
    human_inputs = ["Hello", "How are you?"]
    for input_text in human_inputs:
        # Separate name: rebinding input_ids here would make the search below
        # look for the list inside itself, so the assertions would only ever
        # inspect the first few labels of the conversation.
        human_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
        start_idx = self.find_sublist(input_ids, human_ids)
        LOG.debug(
            f"Human input '{input_text}' expected IDs: {human_ids}, found at: {start_idx}"
        )
        assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
        assert all(
            label == IGNORE_TOKEN_ID
            for label in labels[start_idx : start_idx + len(human_ids)]
        ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(human_ids)]}"
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
    """With ``roles_to_train=["assistant"]`` every assistant response must be labelled."""
    LOG.info("Testing roles_to_train with assistant only")
    strategy = ChatTemplateStrategy(
        ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=["assistant"],
    )
    res = strategy.tokenize_prompt(basic_dataset[0])
    labels = res["labels"]
    input_ids = res["input_ids"]

    # Verify that only assistant responses are labeled
    assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
    for response in assistant_responses:
        response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
        start_idx = self.find_sublist(input_ids, response_ids)
        LOG.debug(
            f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
        )
        # Without this check a response that is not found gives start_idx == -1,
        # the slice below is empty and all() passes vacuously.
        assert start_idx != -1, f"Could not find '{response}' in input_ids"
        assert all(
            label != IGNORE_TOKEN_ID
            for label in labels[start_idx : start_idx + len(response_ids)]
        ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
    """Both human and assistant turns must carry labels when both roles are trained.

    Uses ``train_on_inputs=True`` and ``roles_to_train=["human", "assistant"]``
    and checks that every token of each message text has a label other than
    ``IGNORE_TOKEN_ID``.
    """
    LOG.info("Testing roles_to_train with all roles")
    strategy = ChatTemplateStrategy(
        ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
        tokenizer=llama3_tokenizer,
        train_on_inputs=True,
        sequence_len=512,
        roles_to_train=["human", "assistant"],
    )
    res = strategy.tokenize_prompt(basic_dataset[0])
    labels = res["labels"]
    input_ids = res["input_ids"]

    # Verify that all responses are labeled (except for special tokens)
    all_responses = [
        "Hello",
        "Hi there!",
        "How are you?",
        "I'm doing well, thank you!",
    ]
    for response in all_responses:
        response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
        start_idx = self.find_sublist(input_ids, response_ids)
        LOG.debug(
            f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
        )
        # Fail loudly on a miss; otherwise the slice below is empty and the
        # all() check succeeds without inspecting any label.
        assert start_idx != -1, f"Could not find '{response}' in input_ids"
        assert all(
            label != IGNORE_TOKEN_ID
            for label in labels[start_idx : start_idx + len(response_ids)]
        ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
    """With no trainable roles and train_on_eos="none", every label is masked."""
    LOG.info("Testing with empty roles_to_train")
    prompter = ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3"))
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=[],
        # EOS labelling is disabled too, so nothing at all should be trainable
        train_on_eos="none",
    )
    labels = strategy.tokenize_prompt(basic_dataset[0])["labels"]

    # No position may carry a real label when roles_to_train is empty
    LOG.debug("Full labels: %s", labels)
    has_trainable_label = any(label != IGNORE_TOKEN_ID for label in labels)
    assert (
        not has_trainable_label
    ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
    """train_on_eos="all": every EOS token in the sequence must be labeled."""
    LOG.info("Testing with train_on_eos='all'")
    prompter = ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3"))
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=["assistant"],
        train_on_eos="all",
    )
    tokenized = strategy.tokenize_prompt(basic_dataset[0])
    labels = tokenized["labels"]
    input_ids = tokenized["input_ids"]

    # Collect the position of every EOS token in the tokenized conversation
    eos_token_id = llama3_tokenizer.eos_token_id
    eos_indices = []
    for position, token_id in enumerate(input_ids):
        if token_id == eos_token_id:
            eos_indices.append(position)

    assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
    for eos_idx in eos_indices:
        eos_label = labels[eos_idx]
        assert (
            eos_label != IGNORE_TOKEN_ID
        ), f"Expected EOS token at index {eos_idx} to be labeled"
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
    """train_on_eos="turn": only EOS tokens that close trained turns are labeled.

    With ``roles_to_train=["assistant"]`` the EOS following each assistant
    reply must be labeled, while the EOS following each human message must be
    ``IGNORE_TOKEN_ID``.
    """
    LOG.info("Testing with train_on_eos='turn'")
    strategy = ChatTemplateStrategy(
        ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=["assistant"],
        train_on_eos="turn",
    )
    res = strategy.tokenize_prompt(basic_dataset[0])
    labels = res["labels"]
    input_ids = res["input_ids"]
    eos_token_id = llama3_tokenizer.eos_token_id

    def next_eos_index(position):
        """Return the first index >= position holding EOS, or len(input_ids) if none."""
        while position < len(input_ids) and input_ids[position] != eos_token_id:
            position += 1
        return position

    assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
    for response in assistant_responses:
        response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
        start_idx = self.find_sublist(input_ids, response_ids)
        assert start_idx != -1, f"Could not find '{response}' in input_ids"

        eos_idx = next_eos_index(start_idx + len(response_ids))
        assert eos_idx < len(
            input_ids
        ), f"Could not find EOS token after '{response}'"
        assert (
            labels[eos_idx] != IGNORE_TOKEN_ID
        ), f"Expected EOS token after assistant response '{response}' to be labeled"

    # Check that EOS tokens after human inputs are not labeled
    human_inputs = ["Hello", "How are you?"]
    for input_text in human_inputs:
        # Use a separate name for the per-message ids: rebinding input_ids here
        # would make the search compare the message with itself (always index 0)
        # and the EOS scan would never look at the real sequence.
        human_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
        start_idx = self.find_sublist(input_ids, human_ids)
        assert start_idx != -1, f"Could not find '{input_text}' in input_ids"

        eos_idx = next_eos_index(start_idx + len(human_ids))
        # Guard the index so a missing EOS reports clearly instead of IndexError
        assert eos_idx < len(
            input_ids
        ), f"Could not find EOS token after '{input_text}'"
        assert (
            labels[eos_idx] == IGNORE_TOKEN_ID
        ), f"Expected EOS token after human input '{input_text}' to not be labeled"
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
    """train_on_eos="last": only the final EOS token in the sequence is labeled."""
    LOG.info("Testing with train_on_eos='last'")
    prompter = ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3"))
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=["assistant"],
        train_on_eos="last",
    )
    tokenized = strategy.tokenize_prompt(basic_dataset[0])
    labels = tokenized["labels"]
    input_ids = tokenized["input_ids"]

    eos_token_id = llama3_tokenizer.eos_token_id
    eos_indices = []
    for position, token_id in enumerate(input_ids):
        if token_id == eos_token_id:
            eos_indices.append(position)
    assert len(eos_indices) > 0, "Expected at least one EOS token in the input"

    # Split into "all but the last" and "the last" EOS positions
    *earlier_eos, last_eos_idx = eos_indices

    # Every earlier EOS must stay masked
    for idx in earlier_eos:
        assert (
            labels[idx] == IGNORE_TOKEN_ID
        ), f"Expected EOS token at index {idx} to not be labeled"

    # The closing EOS is the only one that is trained on
    assert (
        labels[last_eos_idx] != IGNORE_TOKEN_ID
    ), f"Expected last EOS token at index {last_eos_idx} to be labeled"
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
    """train_on_eos="none": no EOS token in the sequence may be labeled."""
    LOG.info("Testing with train_on_eos='none'")
    prompter = ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3"))
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=["assistant"],
        train_on_eos="none",
    )
    tokenized = strategy.tokenize_prompt(basic_dataset[0])
    labels = tokenized["labels"]
    input_ids = tokenized["input_ids"]

    # Locate every EOS token, then require each one to be masked
    eos_token_id = llama3_tokenizer.eos_token_id
    eos_indices = []
    for position, token_id in enumerate(input_ids):
        if token_id == eos_token_id:
            eos_indices.append(position)

    assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
    for eos_idx in eos_indices:
        eos_label = labels[eos_idx]
        assert (
            eos_label == IGNORE_TOKEN_ID
        ), f"Expected EOS token at index {eos_idx} to not be labeled"
def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
    """drop_system_message=True: the system text must not appear in input_ids."""
    LOG.info("Testing with drop_system_message=True")
    prompter = ChatTemplatePrompter(
        llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
    )
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=["assistant"],
    )
    input_ids = strategy.tokenize_prompt(basic_dataset[0])["input_ids"]

    # The tokenized system message should be absent from the sequence
    system_message = "You are an AI assistant."
    system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
    system_position = self.find_sublist(input_ids, system_ids)
    assert system_position == -1, "Expected system message to be dropped"
def test_custom_roles(self, llama3_tokenizer):
    """Custom role aliases: "ai" turns are trained, "human" turns are masked.

    The prompter is given a roles mapping where "ai" is an alias of the
    assistant role and "context" an alias of the system role; only "ai" is
    listed in roles_to_train.
    """
    LOG.info("Testing with custom roles mapping")
    custom_roles = {
        "user": ["human", "user"],
        "assistant": ["ai", "assistant"],
        "system": ["context"],
    }
    prompter = ChatTemplatePrompter(
        llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
    )
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=["ai"],
    )

    # Conversation written with the aliased role names
    turns = [
        ("context", "You are an AI assistant."),
        ("human", "Hello"),
        ("ai", "Hi there!"),
        ("human", "How are you?"),
        ("ai", "I'm doing well, thank you!"),
    ]
    modified_conversations = [
        {"from": role, "value": text} for role, text in turns
    ]
    modified_dataset = Dataset.from_dict(
        {"conversations": [modified_conversations]}
    )

    tokenized = strategy.tokenize_prompt(modified_dataset[0])
    labels = tokenized["labels"]
    input_ids = tokenized["input_ids"]

    # AI replies: every token must have a real label
    for response in ["Hi there!", "I'm doing well, thank you!"]:
        response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
        start_idx = self.find_sublist(input_ids, response_ids)
        assert start_idx != -1, f"Could not find response '{response}' in input_ids"
        response_labels = labels[start_idx : start_idx + len(response_ids)]
        assert not any(
            label == IGNORE_TOKEN_ID for label in response_labels
        ), f"Expected labels for AI response '{response}' to be set"

    # Human messages: every token must be masked
    for message in ["Hello", "How are you?"]:
        message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
        start_idx = self.find_sublist(input_ids, message_ids)
        assert start_idx != -1, f"Could not find message '{message}' in input_ids"
        message_labels = labels[start_idx : start_idx + len(message_ids)]
        assert not any(
            label != IGNORE_TOKEN_ID for label in message_labels
        ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
def test_message_field_training(self, llama3_tokenizer):
    """Per-message ``train`` flags and per-span ``train_detail`` entries drive labels.

    ``roles_to_train`` is empty, so labelling is decided only by each turn's
    ``train`` boolean or its ``train_detail`` list. Each detail carries
    ``begin_offset``/``end_offset`` character offsets into the message text
    (``end_offset`` inclusive, as the slice below uses ``end_offset + 1``) and
    a ``train`` flag.

    Several turns share the same text, so the n-th turn with a given text is
    matched to the n-th occurrence of its tokens in ``input_ids``.
    """
    LOG.info("Testing with message_field_training")
    strategy = ChatTemplateStrategy(
        ChatTemplatePrompter(
            llama3_tokenizer,
            chat_templates("llama3"),
            message_field_training="train",
            message_field_training_detail="train_detail",
        ),
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=[],
    )

    # Create a new dataset with the train and train_detail fields
    modified_conversation = [
        {"from": "system", "value": "You are an AI assistant.", "train": False},
        {"from": "human", "value": "Hello", "train": False},
        {"from": "assistant", "value": "Hello", "train": True},
        {"from": "human", "value": "How are you?", "train": True},
        {
            "from": "assistant",
            "value": "I'm doing very well, thank you!",
            "train_detail": [
                {"begin_offset": 0, "end_offset": 8, "train": False},
                {"begin_offset": 9, "end_offset": 18, "train": True},
                {"begin_offset": 19, "end_offset": 30, "train": False},
            ],
        },
        {
            "from": "human",
            "value": "I'm doing very well, thank you!",
            "train": False,
        },
        {"from": "assistant", "value": "Hi there!", "train": True},
    ]

    modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})

    res = strategy.tokenize_prompt(modified_dataset[0])
    labels = res["labels"]
    input_ids = res["input_ids"]

    # Function to find all occurrences of a sublist
    def find_all_sublists(full_list, sub_list):
        indices = []
        for index in range(len(full_list) - len(sub_list) + 1):
            if full_list[index : index + len(sub_list)] == sub_list:
                indices.append(index)
        return indices

    # Keep track of which occurrences we've processed
    processed_occurrences = {}

    # Check if messages are labeled correctly based on train or train_detail
    for i, turn in enumerate(modified_conversation):
        turn_tokens = llama3_tokenizer.encode(
            turn["value"], add_special_tokens=False
        )
        occurrences = find_all_sublists(input_ids, turn_tokens)
        turn_key = turn["value"]
        if turn_key not in processed_occurrences:
            processed_occurrences[turn_key] = 0
        current_occurrence = processed_occurrences[turn_key]

        if current_occurrence >= len(occurrences):
            # Raise explicitly: `assert False` disappears under `python -O`,
            # which would turn this into an opaque IndexError on the next line.
            raise AssertionError(
                f"Not enough occurrences found for message: {turn['value']}"
            )

        start_idx = occurrences[current_occurrence]
        processed_occurrences[turn_key] += 1
        end_idx = start_idx + len(turn_tokens)

        LOG.debug(
            f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
        )

        if "train_detail" in turn:
            # Get token offsets
            tokenized_output = llama3_tokenizer(
                turn["value"], return_offsets_mapping=True, add_special_tokens=False
            )
            token_offsets = tokenized_output["offset_mapping"]

            # Adjust token offsets as done in the implementation: each token's
            # end becomes the character just before the next token's start.
            # Inner loops use their own index names so the outer `i` survives.
            for tok_idx in range(len(token_offsets) - 1):
                token_offsets[tok_idx] = (
                    token_offsets[tok_idx][0],
                    token_offsets[tok_idx + 1][0] - 1,
                )
            token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)

            # Adjust train_details
            adjusted_train_details = strategy.prompter.adjust_train_details(
                turn["train_detail"], token_offsets
            )

            LOG.debug(f"Original train_details: {turn['train_detail']}")
            LOG.debug(f"Adjusted train_details: {adjusted_train_details}")

            # Handle train_detail
            # NOTE(review): presumably one entry per token; the masked variant
            # appears to hold IGNORE_TOKEN_ID for untrainable tokens - confirm
            # against the prompter implementation.
            token_offsets = strategy.prompter.get_offsets_for_train_detail(
                text=turn["value"],
                train_details=adjusted_train_details,
                mask_untrainable=False,
            )
            token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
                text=turn["value"],
                train_details=adjusted_train_details,
                mask_untrainable=True,
            )
            LOG.debug(f"Token offsets: {token_offsets_masked}")

            # Expected labels: the token id where unmasked, IGNORE elsewhere
            expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
            for tok_idx, offset in enumerate(token_offsets_masked):
                if offset != IGNORE_TOKEN_ID:
                    expected_labels[tok_idx] = turn_tokens[tok_idx]
            actual_labels = labels[
                start_idx : start_idx + len(token_offsets_masked)
            ]
            assert (
                actual_labels == expected_labels
            ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"

            for detail in adjusted_train_details:
                # Find the token indices that correspond to the character offsets
                detail_start = start_idx + next(
                    pos
                    for pos, offset in enumerate(token_offsets)
                    if offset >= detail["begin_offset"]
                )
                # Falls back to the end of the turn when no later token exists
                detail_end = start_idx + next(
                    (
                        pos
                        for pos, offset in enumerate(token_offsets)
                        if offset > detail["end_offset"]
                    ),
                    len(token_offsets),
                )

                detail_text = turn["value"][
                    detail["begin_offset"] : detail["end_offset"] + 1
                ]
                detail_labels = labels[detail_start:detail_end]
                detail_input_ids = input_ids[detail_start:detail_end]

                LOG.debug(
                    f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
                )
                LOG.debug(f"Detail input_ids: {detail_input_ids}")
                LOG.debug(f"Detail labels: {detail_labels}")
                LOG.debug(
                    f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
                )
                LOG.debug(
                    f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
                )

                if detail["train"]:
                    assert all(
                        label != IGNORE_TOKEN_ID for label in detail_labels
                    ), (
                        f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
                        f"Labels({detail_start}:{detail_end}): {detail_labels}, "
                        f"InputIDs: {detail_input_ids}, "
                        f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
                    )
                else:
                    assert all(
                        label == IGNORE_TOKEN_ID for label in detail_labels
                    ), (
                        f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
                        f"Labels({detail_start}:{detail_end}): {detail_labels}, "
                        f"InputIDs: {detail_input_ids}, "
                        f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
                    )
        else:
            # Whole-message flag; a missing "train" key is treated as False
            should_train = turn.get("train", False)
            turn_labels = labels[start_idx:end_idx]

            LOG.debug(f"Should train: {should_train}")
            LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
            LOG.debug(f"Turn labels: {turn_labels}")
            LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
            LOG.debug(
                f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
            )

            if should_train:
                assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
                    f"Expected all labels for '{turn['value']}' to be set\n"
                    f"Labels({start_idx}:{end_idx}): {turn_labels}, "
                    f"InputIDs: {input_ids[start_idx:end_idx]}, "
                    f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
                )
            else:
                assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
                    f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
                    f"Labels({start_idx}:{end_idx}): {turn_labels}, "
                    f"InputIDs: {input_ids[start_idx:end_idx]}, "
                    f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
                )

        LOG.debug(
            f"Processed turn: {turn['from']}, content: '{turn['value']}', "
            f"start_idx: {start_idx}, end_idx: {end_idx}, "
            f"labels: {labels[start_idx:end_idx]}"
        )

    LOG.debug(f"Final labels: {labels}")
    LOG.debug(f"Final input_ids: {input_ids}")
class TestAssistantChatTemplateLlama3: class TestAssistantChatTemplateLlama3:
""" """
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
""" """
def test_llama3_load(self, llama3_tokenizer, assistant_dataset): def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code LOG.info("Loading llama-3 tokenizer with assistant dataset")
strategy = load( strategy = load(
llama3_tokenizer, llama3_tokenizer,
DictDefault( DictDefault(
@@ -162,21 +753,26 @@ class TestAssistantChatTemplateLlama3:
res = strategy.tokenize_prompt(assistant_dataset[0]) res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"] input_ids = res["input_ids"]
# fmt: off # fmt: off
assert input_ids == [ expected_input_ids = [
128000, # bos 128000, # bos
128006, 882, 128007, # user header 128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot 271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header 128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot 271, 15339, 128009, # assistant response eot
128006, 882, 128007, 128006, 882, 128007,
271, 19045, 29474, 128009, 271, 19045, 29474, 128009,
128006, 78191, 128007, 128006, 78191, 128007,
271, 19045, 29474, 128009, 271, 19045, 29474, 128009,
] ]
# fmt: on # fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
def test_llama3(self, llama3_tokenizer, assistant_dataset): def test_llama3(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code LOG.info("Testing llama-3 with assistant dataset")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
llama3_tokenizer, llama3_tokenizer,
@@ -189,15 +785,16 @@ class TestAssistantChatTemplateLlama3:
"system": ["system"], "system": ["system"],
}, },
), ),
llama3_tokenizer, tokenizer=llama3_tokenizer,
False, train_on_inputs=False,
512, sequence_len=512,
roles_to_train=["assistant"],
) )
strategy.messages = "messages" strategy.messages = "messages"
res = strategy.tokenize_prompt(assistant_dataset[0]) res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"] input_ids = res["input_ids"]
# fmt: off # fmt: off
assert input_ids == [ expected_input_ids = [
128000, # bos 128000, # bos
128006, 882, 128007, # user header 128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot 271, 15339, 128009, # user prompt eot
@@ -209,6 +806,64 @@ class TestAssistantChatTemplateLlama3:
271, 19045, 29474, 128009, 271, 19045, 29474, 128009,
] ]
# fmt: on # fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
    """Assistant-style dataset: only assistant content tokens are labeled.

    Runs with ``train_on_eos="none"`` and ``roles_to_train=["assistant"]`` and
    compares the full label sequence against a hand-written expectation.
    """
    LOG.info("Testing llama-3 with assistant dataset including training data")
    role_aliases = {
        "user": ["user"],
        "assistant": ["assistant"],
        "system": ["system"],
    }
    prompter = ChatTemplatePrompter(
        llama3_tokenizer,
        chat_templates("llama3"),
        message_field_role="role",
        message_field_content="content",
        message_field_training="training",
        roles=role_aliases,
    )
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        train_on_eos="none",
        sequence_len=512,
        roles_to_train=["assistant"],
    )
    strategy.messages = "messages"

    # Log the rendered prompt to make label mismatches easier to diagnose
    prompt_tokens = strategy.prompter.build_prompt(
        assistant_dataset[0]["messages"], False
    )
    prompt = llama3_tokenizer.decode(prompt_tokens, skip_special_tokens=False)
    LOG.debug(f"Generated prompt: {prompt}")

    tokenized = strategy.tokenize_prompt(assistant_dataset[0])
    labels = tokenized["labels"]
    input_ids = tokenized["input_ids"]

    ign = IGNORE_TOKEN_ID  # short alias keeps the table below readable
    # fmt: off
    expected_labels = [
        ign,  # bos
        ign, ign, ign,  # user header
        ign, ign, ign,  # user prompt eot
        ign, ign, ign,  # assistant header
        ign, 15339, ign,  # assistant response eot
        ign, ign, ign,
        ign, ign, ign, ign,
        ign, ign, ign,
        ign, 19045, 29474, ign,
    ]
    # fmt: on

    LOG.debug(f"Expected labels: {expected_labels}")
    LOG.debug(f"Actual labels: {labels}")
    assert labels == expected_labels, (
        f"Labels mismatch:\n"
        f"Expected: {expected_labels}\n"
        f"Actual: {labels}\n"
        f"Input IDs: {input_ids}\n"
    )
class TestSharegptChatTemplateLlama3: class TestSharegptChatTemplateLlama3:
@@ -216,30 +871,160 @@ class TestSharegptChatTemplateLlama3:
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy. Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
""" """
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
    """ShareGPT-style dataset: only "gpt" (assistant) content tokens are labeled.

    Runs with ``train_on_eos="none"`` and ``roles_to_train=["gpt"]`` and
    compares both the token ids and the labels against hand-written
    expectations.
    """
    LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
    strategy = ChatTemplateStrategy(
        ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        train_on_eos="none",
        sequence_len=512,
        roles_to_train=["gpt"],
    )
    res = strategy.tokenize_prompt(sharegpt_dataset[0])
    input_ids = res["input_ids"]
    labels = res["labels"]
    # fmt: off
    expected_input_ids = [
        128000,  # bos
        128006, 882, 128007,  # user header
        271, 15339, 128009,  # user prompt eot
        128006, 78191, 128007,  # assistant header
        271, 15339, 128009,  # assistant response eot
        128006, 882, 128007,
        271, 19045, 29474, 128009,
        128006, 78191, 128007,
        271, 19045, 29474, 128009,
    ]
    expected_labels = [
        IGNORE_TOKEN_ID,  # bos
        IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user header
        IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # user prompt eot
        IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,  # assistant header
        IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID,  # assistant response eot
        IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
        IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
        IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
        IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
    ]
    # fmt: on

    LOG.debug(f"Expected input_ids: {expected_input_ids}")
    LOG.debug(f"Actual input_ids: {input_ids}")
    LOG.debug(f"Expected labels: {expected_labels}")
    LOG.debug(f"Actual labels: {labels}")

    assert (
        input_ids == expected_input_ids
    ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
    assert (
        labels == expected_labels
    ), f"Labels mismatch: {labels} != {expected_labels}"
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
    """ShareGPT-style dataset: only "human" content tokens are labeled.

    Runs with ``train_on_eos="none"`` and ``roles_to_train=["human"]`` and
    compares both the token ids and the labels against hand-written
    expectations.
    """
    LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
    prompter = ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3"))
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        train_on_eos="none",
        sequence_len=512,
        roles_to_train=["human"],
    )
    tokenized = strategy.tokenize_prompt(sharegpt_dataset[0])
    input_ids = tokenized["input_ids"]
    labels = tokenized["labels"]

    ign = IGNORE_TOKEN_ID  # short alias keeps the label table readable
    # fmt: off
    expected_input_ids = [
        128000,  # bos
        128006, 882, 128007,  # user header
        271, 15339, 128009,  # user prompt eot
        128006, 78191, 128007,  # assistant header
        271, 15339, 128009,  # assistant response eot
        128006, 882, 128007,
        271, 19045, 29474, 128009,
        128006, 78191, 128007,
        271, 19045, 29474, 128009,
    ]
    expected_labels = [
        ign,  # bos
        ign, ign, ign,  # user header
        ign, 15339, ign,  # user prompt eot
        ign, ign, ign,  # assistant header
        ign, ign, ign,  # assistant response eot
        ign, ign, ign,
        ign, 19045, 29474, ign,
        ign, ign, ign,
        ign, ign, ign, ign,
    ]
    # fmt: on

    LOG.debug(f"Expected input_ids: {expected_input_ids}")
    LOG.debug(f"Actual input_ids: {input_ids}")
    LOG.debug(f"Expected labels: {expected_labels}")
    LOG.debug(f"Actual labels: {labels}")

    # Token ids first, then labels, so a template change is reported before
    # a masking change
    assert (
        input_ids == expected_input_ids
    ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
    assert (
        labels == expected_labels
    ), f"Labels mismatch: {labels} != {expected_labels}"
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
    """Basic dataset: "system" and "human" content tokens are labeled, assistant is not.

    Runs with ``train_on_eos="none"`` and ``roles_to_train=["system", "human"]``
    and compares both the token ids and the labels against hand-written
    expectations.
    """
    LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
    prompter = ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3"))
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
        train_on_eos="none",
        sequence_len=512,
        roles_to_train=["system", "human"],
    )
    tokenized = strategy.tokenize_prompt(basic_dataset[0])
    input_ids = tokenized["input_ids"]
    labels = tokenized["labels"]

    ign = IGNORE_TOKEN_ID  # short alias keeps the label table readable
    # fmt: off
    expected_input_ids = [
        128000,  # bos
        128006, 9125, 128007,
        271, 2675, 527, 459, 15592, 18328, 13, 128009,
        128006, 882, 128007,  # user header
        271, 9906, 128009,  # user prompt eot
        128006, 78191, 128007,  # assistant header
        271, 13347, 1070, 0, 128009,  # assistant response eot
        128006, 882, 128007,
        271, 4438, 527, 499, 30, 128009,
        128006, 78191, 128007,
        271, 40, 2846, 3815, 1664, 11, 9901, 499, 0, 128009,
    ]
    expected_labels = [
        ign,  # bos
        ign, ign, ign,  # system header
        ign, 2675, 527, 459, 15592, 18328, 13, ign,  # system prompt eot
        ign, ign, ign,  # user header
        ign, 9906, ign,  # user prompt eot
        ign, ign, ign,  # assistant header
        ign, ign, ign, ign, ign,  # assistant response eot
        ign, ign, ign,
        ign, 4438, 527, 499, 30, ign,
        ign, ign, ign,
        ign, ign, ign, ign, ign, ign, ign, ign, ign, ign,
    ]
    # fmt: on

    LOG.debug(f"Expected input_ids: {expected_input_ids}")
    LOG.debug(f"Actual input_ids: {input_ids}")
    LOG.debug(f"Expected labels: {expected_labels}")
    LOG.debug(f"Actual labels: {labels}")

    # Token ids first, then labels, so a template change is reported before
    # a masking change
    assert (
        input_ids == expected_input_ids
    ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
    assert (
        labels == expected_labels
    ), f"Labels mismatch: {labels} != {expected_labels}"
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -192,6 +192,7 @@ class TestSharegptLlama3:
input_ids = dataset_wrapper[0]["input_ids"] input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off # fmt: off
# pylint: disable=duplicate-code
assert input_ids == [ assert input_ids == [
128000, # bos 128000, # bos
128006, 9125, 128007, # system header 128006, 9125, 128007, # system header
@@ -228,6 +229,7 @@ class TestSharegptLlama3:
input_ids = dataset_wrapper[0]["input_ids"] input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off # fmt: off
# pylint: disable=duplicate-code
assert input_ids == [ assert input_ids == [
128000, # bos 128000, # bos
128006, 9125, 128007, # system header 128006, 9125, 128007, # system header