Compare commits

..

9 Commits

Author SHA1 Message Date
Wing Lian
317761406e add support for NCA 2024-05-06 17:01:14 -04:00
Wing Lian
6a9ac4ad27 consistency w sppo -> sppo_hard 2024-05-06 16:58:58 -04:00
Wing Lian
027f7d54f0 update for sppo 2024-05-06 16:55:46 -04:00
Wing Lian
0554105baa add mistral instruct strategy and fix dpo_loss input 2024-05-06 16:55:18 -04:00
Wing Lian
f58fcd09ec use DPOConfig 2024-05-06 16:55:16 -04:00
Wing Lian
60fecac367 bump trl 2024-05-06 16:54:03 -04:00
Wing Lian
b301068098 remove override 2024-05-06 16:54:02 -04:00
Wing Lian
df645906eb invert check 2024-05-06 16:54:02 -04:00
Wing Lian
7fea5822f0 add support for SPPO 2024-05-06 16:54:02 -04:00
118 changed files with 569 additions and 2184 deletions

View File

@@ -30,7 +30,7 @@ jobs:
- cuda: "121" - cuda: "121"
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.2 pytorch: 2.2.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121" - cuda: "121"
cuda_version: 12.1.0 cuda_version: 12.1.0

View File

@@ -28,7 +28,7 @@ jobs:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.2 pytorch: 2.2.1
axolotl_extras: axolotl_extras:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
@@ -89,7 +89,7 @@ jobs:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.2 pytorch: 2.2.1
axolotl_extras: axolotl_extras:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
@@ -125,45 +125,3 @@ jobs:
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
include:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl-cloud-term
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Build
uses: docker/build-push-action@v5
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-no-tmux
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -27,7 +27,7 @@ jobs:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.2 pytorch: 2.2.1
axolotl_extras: axolotl_extras:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
@@ -89,7 +89,7 @@ jobs:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.2 pytorch: 2.2.1
axolotl_extras: axolotl_extras:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0

View File

@@ -82,12 +82,7 @@ jobs:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.2 pytorch: 2.2.1
num_gpus: 1
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
num_gpus: 1 num_gpus: 1
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -34,7 +34,6 @@ Features:
- [Mac](#mac) - [Mac](#mac)
- [Google Colab](#google-colab) - [Google Colab](#google-colab)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot) - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
- [Dataset](#dataset) - [Dataset](#dataset)
- [Config](#config) - [Config](#config)
- [Train](#train) - [Train](#train)
@@ -124,11 +123,11 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference # inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out" --lora_model_dir="./lora-out"
# gradio # gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out" --gradio --lora_model_dir="./lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL # remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml # Note: the yaml config must directly link to the **raw** yaml
@@ -293,42 +292,6 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
``` ```
#### Launching on public clouds via dstack
To launch on GPU instance (both on-demand and spot instances) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
Write a job description in YAML as below:
```yaml
# dstack.yaml
type: task
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
env:
- HUGGING_FACE_HUB_TOKEN
- WANDB_API_KEY
commands:
- accelerate launch -m axolotl.cli.train config.yaml
ports:
- 6006
resources:
gpu:
memory: 24GB..
count: 2
```
then, simply run the job with `dstack run` command. Append `--spot` option if you want spot instance. `dstack run` command will show you the instance with cheapest price across multi cloud services:
```bash
pip install dstack
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
```
For further and fine-grained use cases, please refer to the official [dstack documents](https://dstack.ai/docs/) and the detailed description of [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) on the official repository.
### Dataset ### Dataset
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field. Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.

View File

@@ -1,5 +1,4 @@
#!/bin/bash #!/bin/bash
set -e
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest /workspace/axolotl/tests/e2e/patched/ pytest /workspace/axolotl/tests/e2e/patched/

View File

@@ -11,7 +11,7 @@ ARG PYTORCH_VERSION="2.1.2"
ENV PYTORCH_VERSION=$PYTORCH_VERSION ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
WORKDIR /workspace WORKDIR /workspace

View File

@@ -1,27 +0,0 @@
ARG BASE_TAG=main
FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint-term.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux sudo && \
pip3 install -U --no-cache-dir grpcio ray[default]==2.9.3 && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

View File

@@ -138,7 +138,7 @@ test_datasets:
data_files: data_files:
- /workspace/data/eval.jsonl - /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto_pair' # use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard', 'nca_pair'
rl: rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing # Saves the desired chat template to the tokenizer_config.json for easier inferencing
@@ -186,11 +186,6 @@ eval_sample_packing:
# The trainer will provide recommended values for these values. # The trainer will provide recommended values for these values.
sample_packing_eff_est: sample_packing_eff_est:
total_num_tokens: total_num_tokens:
# Increasing the following values helps with packing, but usually only slightly (<%1.)
# The number of samples packed at a time.
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
# Passed through to transformers when loading the model when launched without accelerate # Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory # Use `sequential` when training w/ model parallelism to limit memory
@@ -290,7 +285,7 @@ lr_quadratic_warmup:
logging_steps: logging_steps:
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
save_strategy: # Set to `"no"` to skip checkpoint saves save_strategy: # Set to `no` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch save_steps: # Leave empty to save at each epoch
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time save_total_limit: # Checkpoints saved at a time

View File

@@ -38,7 +38,7 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/btlm-out output_dir: btlm-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 1 num_epochs: 1

View File

@@ -25,7 +25,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
batch_size: 4 batch_size: 4
micro_batch_size: 4 micro_batch_size: 4
num_epochs: 2 num_epochs: 2

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:

View File

@@ -1,223 +1,216 @@
{ {
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "AKjdG7tbTb-n" "id": "AKjdG7tbTb-n"
}, },
"source": [ "source": [
"# Example notebook for running Axolotl on google colab" "# Example notebook for running Axolotl on google colab"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RcbNpOgWRcii"
},
"outputs": [],
"source": [
"import torch\n",
"# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n",
"assert (torch.cuda.is_available()==True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h3nLav8oTRA5"
},
"source": [
"## Install Axolotl and dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
}, },
"id": "3c3yGAwnOIdi", {
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01" "cell_type": "code",
}, "execution_count": null,
"outputs": [], "metadata": {
"source": [ "id": "RcbNpOgWRcii"
"!pip install torch==\"2.1.2\"\n", },
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n", "outputs": [],
"!pip install flash-attn==\"2.5.0\"\n", "source": [
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\"" "import torch\n",
] "# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n",
}, "assert (torch.cuda.is_available()==True)"
{ ]
"cell_type": "markdown",
"metadata": {
"id": "BW2MFr7HTjub"
},
"source": [
"## Create an yaml config file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pkF2dSoQEUN"
},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"# Your YAML string\n",
"yaml_string = \"\"\"\n",
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
"model_type: LlamaForCausalLM\n",
"tokenizer_type: LlamaTokenizer\n",
"\n",
"load_in_8bit: false\n",
"load_in_4bit: true\n",
"strict: false\n",
"\n",
"datasets:\n",
" - path: mhenrichsen/alpaca_2k_test\n",
" type: alpaca\n",
"dataset_prepared_path:\n",
"val_set_size: 0.05\n",
"output_dir: ./outputs/qlora-out\n",
"\n",
"adapter: qlora\n",
"lora_model_dir:\n",
"\n",
"sequence_len: 4096\n",
"sample_packing: true\n",
"eval_sample_packing: false\n",
"pad_to_sequence_len: true\n",
"\n",
"lora_r: 32\n",
"lora_alpha: 16\n",
"lora_dropout: 0.05\n",
"lora_target_modules:\n",
"lora_target_linear: true\n",
"lora_fan_in_fan_out:\n",
"\n",
"wandb_project:\n",
"wandb_entity:\n",
"wandb_watch:\n",
"wandb_name:\n",
"wandb_log_model:\n",
"\n",
"gradient_accumulation_steps: 4\n",
"micro_batch_size: 2\n",
"num_epochs: 4\n",
"optimizer: paged_adamw_32bit\n",
"lr_scheduler: cosine\n",
"learning_rate: 0.0002\n",
"\n",
"train_on_inputs: false\n",
"group_by_length: false\n",
"bf16: auto\n",
"fp16:\n",
"tf32: false\n",
"\n",
"gradient_checkpointing: true\n",
"early_stopping_patience:\n",
"resume_from_checkpoint:\n",
"local_rank:\n",
"logging_steps: 1\n",
"xformers_attention:\n",
"flash_attention: true\n",
"\n",
"warmup_steps: 10\n",
"evals_per_epoch: 4\n",
"saves_per_epoch: 1\n",
"debug:\n",
"deepspeed:\n",
"weight_decay: 0.0\n",
"fsdp:\n",
"fsdp_config:\n",
"special_tokens:\n",
"\n",
"\"\"\"\n",
"\n",
"# Convert the YAML string to a Python dictionary\n",
"yaml_dict = yaml.safe_load(yaml_string)\n",
"\n",
"# Specify your file path\n",
"file_path = 'test_axolotl.yaml'\n",
"\n",
"# Write the YAML file\n",
"with open(file_path, 'w') as file:\n",
" yaml.dump(yaml_dict, file)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bidoj8YLTusD"
},
"source": [
"## Launch the training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
}, },
"id": "ydTI2Jk2RStU", {
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b" "cell_type": "markdown",
}, "metadata": {
"outputs": [], "id": "h3nLav8oTRA5"
"source": [ },
"# Buy using the ! the comand will be executed as a bash command\n", "source": [
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml" "## Install Axolotl and dependencies"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3c3yGAwnOIdi",
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
},
"outputs": [],
"source": [
"!pip install torch==\"2.1.2\"\n",
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BW2MFr7HTjub"
},
"source": [
"## Create an yaml config file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pkF2dSoQEUN"
},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"# Your YAML string\n",
"yaml_string = \"\"\"\n",
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
"model_type: LlamaForCausalLM\n",
"tokenizer_type: LlamaTokenizer\n",
"is_llama_derived_model: true\n",
"\n",
"load_in_8bit: false\n",
"load_in_4bit: true\n",
"strict: false\n",
"\n",
"datasets:\n",
" - path: mhenrichsen/alpaca_2k_test\n",
" type: alpaca\n",
"dataset_prepared_path:\n",
"val_set_size: 0.05\n",
"output_dir: ./qlora-out\n",
"\n",
"adapter: qlora\n",
"lora_model_dir:\n",
"\n",
"sequence_len: 1096\n",
"sample_packing: true\n",
"pad_to_sequence_len: true\n",
"\n",
"lora_r: 32\n",
"lora_alpha: 16\n",
"lora_dropout: 0.05\n",
"lora_target_modules:\n",
"lora_target_linear: true\n",
"lora_fan_in_fan_out:\n",
"\n",
"wandb_project:\n",
"wandb_entity:\n",
"wandb_watch:\n",
"wandb_name:\n",
"wandb_log_model:\n",
"\n",
"mlflow_experiment_name: colab-example\n",
"\n",
"gradient_accumulation_steps: 1\n",
"micro_batch_size: 1\n",
"num_epochs: 4\n",
"max_steps: 20\n",
"optimizer: paged_adamw_32bit\n",
"lr_scheduler: cosine\n",
"learning_rate: 0.0002\n",
"\n",
"train_on_inputs: false\n",
"group_by_length: false\n",
"bf16: false\n",
"fp16: true\n",
"tf32: false\n",
"\n",
"gradient_checkpointing: true\n",
"early_stopping_patience:\n",
"resume_from_checkpoint:\n",
"local_rank:\n",
"logging_steps: 1\n",
"xformers_attention:\n",
"flash_attention: false\n",
"\n",
"warmup_steps: 10\n",
"evals_per_epoch:\n",
"saves_per_epoch:\n",
"debug:\n",
"deepspeed:\n",
"weight_decay: 0.0\n",
"fsdp:\n",
"fsdp_config:\n",
"special_tokens:\n",
"\n",
"\"\"\"\n",
"\n",
"# Convert the YAML string to a Python dictionary\n",
"yaml_dict = yaml.safe_load(yaml_string)\n",
"\n",
"# Specify your file path\n",
"file_path = 'test_axolotl.yaml'\n",
"\n",
"# Write the YAML file\n",
"with open(file_path, 'w') as file:\n",
" yaml.dump(yaml_dict, file)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bidoj8YLTusD"
},
"source": [
"## Launch the training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ydTI2Jk2RStU",
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Play with inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
}, },
{ "nbformat": 4,
"cell_type": "markdown", "nbformat_minor": 0
"metadata": {},
"source": [
"## Play with inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
} }

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out output_dir: ./out
sequence_len: 512 sequence_len: 512
sample_packing: false sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out output_dir: ./out
sequence_len: 512 sequence_len: 512
sample_packing: false sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out output_dir: ./out
sequence_len: 512 sequence_len: 512
sample_packing: false sample_packing: false

View File

@@ -28,7 +28,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/falcon-7b output_dir: ./falcon-7b
batch_size: 2 batch_size: 2
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 4

View File

@@ -42,7 +42,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
# QLoRA paper Table 9 # QLoRA paper Table 9
# - 16 for 7b & 13b # - 16 for 7b & 13b

View File

@@ -28,7 +28,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/falcon-7b output_dir: ./falcon-7b
batch_size: 2 batch_size: 2
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 4

View File

@@ -12,7 +12,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
val_set_size: 0.1 val_set_size: 0.1
output_dir: ./outputs/out output_dir: ./out
adapter: qlora adapter: qlora
lora_r: 32 lora_r: 32

View File

@@ -23,7 +23,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 2 num_epochs: 2

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out output_dir: ./out
sequence_len: 4096 sequence_len: 4096
sample_packing: false sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out output_dir: ./out
sequence_len: 4096 sequence_len: 4096
sample_packing: false sample_packing: false

View File

@@ -21,7 +21,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/jeopardy-bot-7b output_dir: ./jeopardy-bot-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 4

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/out output_dir: ./out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -33,7 +33,7 @@ wandb_project:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/model-out output_dir: ./model-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 4

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/lisa-out output_dir: ./lisa-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/relora-out output_dir: ./relora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/out output_dir: ./out
sequence_len: 8192 sequence_len: 8192
sample_packing: true sample_packing: true

View File

@@ -1,76 +0,0 @@
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
chat_template: llama3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: llama3
field_messages: messages
message_field_role: role
message_field_content: content
roles:
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true
@@ -24,9 +24,6 @@ lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project: wandb_project:
wandb_entity: wandb_entity:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/out/qlora-llama3-70b output_dir: ./out/qlora-llama3-70b
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0 val_set_size: 0
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out output_dir: ./out
sequence_len: 2048 sequence_len: 2048
sample_packing: false sample_packing: false

View File

@@ -23,7 +23,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/out output_dir: ./out
sequence_len: 2048 sequence_len: 2048
sample_packing: true sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/out output_dir: ./out
sequence_len: 8192 sequence_len: 8192
sample_packing: true sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0 val_set_size: 0
output_dir: ./outputs/lora-out output_dir: ./lora-out
eval_sample_packing: false eval_sample_packing: false
adapter: lora adapter: lora

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.1 val_set_size: 0.1
output_dir: ./outputs/lora-out output_dir: ./lora-out
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.02 val_set_size: 0.02
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
model_config: model_config:
output_router_logits: true output_router_logits: true

View File

@@ -16,7 +16,7 @@ datasets:
type: chat_template.argilla type: chat_template.argilla
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.1 val_set_size: 0.1
output_dir: ./outputs/mistral-qlora-orpo-out output_dir: ./mistral-qlora-orpo-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.02 val_set_size: 0.02
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
model_config: model_config:
output_router_logits: true output_router_logits: true

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.02 val_set_size: 0.02
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
model_config: model_config:
output_router_logits: true output_router_logits: true

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
## You can optionally freeze the entire model and unfreeze a subset of parameters ## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters: unfrozen_parameters:

View File

@@ -21,7 +21,7 @@ model_config:
datasets: datasets:
- path: yahma/alpaca-cleaned - path: yahma/alpaca-cleaned
type: alpaca type: alpaca
output_dir: ./outputs/out output_dir: ./out
sequence_len: 8000 sequence_len: 8000
sample_packing: true sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.1 val_set_size: 0.1
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:

View File

@@ -23,7 +23,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/mpt-alpaca-7b output_dir: ./mpt-alpaca-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 4

View File

@@ -25,7 +25,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/openllama-out output_dir: ./openllama-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 4

View File

@@ -31,7 +31,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/lora-out output_dir: ./lora-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 4

View File

@@ -25,7 +25,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 4

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/phi-sft-out output_dir: ./phi-sft-out
sequence_len: 2048 sequence_len: 2048
sample_packing: true sample_packing: true

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/phi-sft-out output_dir: ./phi-sft-out
sequence_len: 2048 sequence_len: 2048
sample_packing: true sample_packing: true

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/phi-sft-out output_dir: ./phi-sft-out
sequence_len: 2048 sequence_len: 2048
sample_packing: true sample_packing: true

View File

@@ -26,7 +26,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/pythia-12b output_dir: ./pythia-12b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 5 num_epochs: 5

View File

@@ -20,7 +20,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/lora-alpaca-pythia output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 4 micro_batch_size: 4
num_epochs: 4 num_epochs: 4

View File

@@ -13,7 +13,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/lora-out output_dir: ./lora-out
sequence_len: 2048 # supports up to 8192 sequence_len: 2048 # supports up to 8192
sample_packing: false sample_packing: false

View File

@@ -13,7 +13,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/lora-out output_dir: ./lora-out
sequence_len: 2048 # supports up to 8192 sequence_len: 2048 # supports up to 8192
sample_packing: false sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/out output_dir: ./out
sequence_len: 1024 # supports up to 32k sequence_len: 1024 # supports up to 32k
sample_packing: false sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/out output_dir: ./out
sequence_len: 1024 # supports up to 32k sequence_len: 1024 # supports up to 32k
sample_packing: false sample_packing: false

View File

@@ -24,7 +24,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/redpajama-alpaca-3b output_dir: ./redpajama-alpaca-3b
batch_size: 4 batch_size: 4
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 4

View File

@@ -23,7 +23,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/lora-replit output_dir: ./lora-replit
batch_size: 8 batch_size: 8
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 4

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/out output_dir: ./out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.2 val_set_size: 0.2
output_dir: ./outputs/qlora output_dir: ./qlora
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0 val_set_size: 0
output_dir: ./outputs/lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -14,7 +14,7 @@ pretraining_dataset:
type: pretrain type: pretrain
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/model-out output_dir: ./model-out
sequence_len: 2048 sequence_len: 2048
sample_packing: true sample_packing: true

View File

@@ -11,14 +11,13 @@ datasets:
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true pad_to_sequence_len: true
lora_r: 32 lora_r: 32

View File

@@ -40,7 +40,7 @@ wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
# QLoRA paper Table 9 # QLoRA paper Table 9
# - 16 for 7b & 13b # - 16 for 7b & 13b

View File

@@ -33,7 +33,7 @@ eval_sample_packing: false
eval_batch_size: 1 eval_batch_size: 1
# LoRA # LoRA
output_dir: ./outputs/qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
lora_r: 32 lora_r: 32

View File

@@ -1,22 +1,22 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.11.1 peft==0.10.0
transformers==4.41.1 transformers @ git+https://github.com/huggingface/transformers.git@43d17c18360ac9c3d3491389328e2fe55fe8f9ce
tokenizers==0.19.1 tokenizers==0.15.0
bitsandbytes==0.43.1 bitsandbytes==0.43.0
accelerate==0.30.1 accelerate==0.28.0
deepspeed==0.14.2 deepspeed==0.13.1
pydantic==2.6.3 pydantic==2.6.3
addict addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
requests requests
datasets==2.19.1 datasets==2.15.0
flash-attn==2.5.8 flash-attn==2.5.5
sentencepiece sentencepiece
wandb wandb
einops einops
xformers==0.0.26.post1 xformers==0.0.22
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
colorama colorama
@@ -28,7 +28,7 @@ scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml
art art
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
@@ -39,6 +39,6 @@ s3fs
gcsfs gcsfs
# adlfs # adlfs
trl==0.8.6 trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
zstandard==0.22.0 zstandard==0.22.0
fastcore fastcore

View File

@@ -1,82 +0,0 @@
#!/bin/bash
# Export specific ENV variables to /etc/rp_environment
echo "Exporting environment variables..."
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
conda init
# this needs to come after conda init
echo 'source /etc/rp_environment' >> ~/.bashrc
add_keys_to_authorized() {
local key_value=$1
# Create the ~/.ssh directory and set permissions
mkdir -p ~/.ssh
chmod 700 ~/.ssh
# Create the authorized_keys file if it doesn't exist
touch ~/.ssh/authorized_keys
# Initialize an empty key variable
local key=""
# Read the key variable word by word
for word in $key_value; do
# Check if the word looks like the start of a key
if [[ $word == ssh-* ]]; then
# If there's a key being built, add it to the authorized_keys file
if [[ -n $key ]]; then
echo $key >> ~/.ssh/authorized_keys
fi
# Start a new key
key=$word
else
# Append the word to the current key
key="$key $word"
fi
done
# Add the last key to the authorized_keys file
if [[ -n $key ]]; then
echo $key >> ~/.ssh/authorized_keys
fi
# Set the correct permissions
chmod 600 ~/.ssh/authorized_keys
chmod 700 -R ~/.ssh
}
if [[ $PUBLIC_KEY ]]; then
# runpod
add_keys_to_authorized "$PUBLIC_KEY"
# Start the SSH service in the background
service ssh start
elif [[ $SSH_KEY ]]; then
# latitude.sh
add_keys_to_authorized "$SSH_KEY"
# Start the SSH service in the background
service ssh start
else
echo "No PUBLIC_KEY or SSH_KEY environment variable provided, not starting openSSH daemon"
fi
# Check if JUPYTER_PASSWORD is set and not empty
if [ -n "$JUPYTER_PASSWORD" ]; then
# Set JUPYTER_TOKEN to the value of JUPYTER_PASSWORD
export JUPYTER_TOKEN="$JUPYTER_PASSWORD"
fi
if [ "$JUPYTER_DISABLE" != "1" ]; then
# Run Jupyter Lab in the background
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
fi
if [ ! -d "/workspace/data/axolotl-artifacts" ]; then
mkdir -p /workspace/data/axolotl-artifacts
fi
if [ ! -L "/workspace/axolotl/outputs" ]; then
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
fi
# Execute the passed arguments (CMD)
exec "$@"

View File

@@ -5,53 +5,20 @@ echo "Exporting environment variables..."
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
echo 'source /etc/rp_environment' >> ~/.bashrc echo 'source /etc/rp_environment' >> ~/.bashrc
add_keys_to_authorized() {
local key_value=$1
# Create the ~/.ssh directory and set permissions
mkdir -p ~/.ssh
chmod 700 ~/.ssh
# Create the authorized_keys file if it doesn't exist
touch ~/.ssh/authorized_keys
# Initialize an empty key variable
local key=""
# Read the key variable word by word
for word in $key_value; do
# Check if the word looks like the start of a key
if [[ $word == ssh-* ]]; then
# If there's a key being built, add it to the authorized_keys file
if [[ -n $key ]]; then
echo $key >> ~/.ssh/authorized_keys
fi
# Start a new key
key=$word
else
# Append the word to the current key
key="$key $word"
fi
done
# Add the last key to the authorized_keys file
if [[ -n $key ]]; then
echo $key >> ~/.ssh/authorized_keys
fi
# Set the correct permissions
chmod 600 ~/.ssh/authorized_keys
chmod 700 -R ~/.ssh
}
if [[ $PUBLIC_KEY ]]; then if [[ $PUBLIC_KEY ]]; then
# runpod # runpod
add_keys_to_authorized "$PUBLIC_KEY" mkdir -p ~/.ssh
chmod 700 ~/.ssh
echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
chmod 700 -R ~/.ssh
# Start the SSH service in the background # Start the SSH service in the background
service ssh start service ssh start
elif [[ $SSH_KEY ]]; then elif [ -n "$SSH_KEY" ]; then
# latitude.sh # latitude.sh
add_keys_to_authorized "$SSH_KEY" mkdir -p ~/.ssh
chmod 700 ~/.ssh
echo $SSH_KEY >> ~/.ssh/authorized_keys
chmod 700 -R ~/.ssh
# Start the SSH service in the background # Start the SSH service in the background
service ssh start service ssh start
else else
@@ -69,12 +36,5 @@ if [ "$JUPYTER_DISABLE" != "1" ]; then
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* & jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
fi fi
if [ ! -d "/workspace/data/axolotl-artifacts" ]; then
mkdir -p /workspace/data/axolotl-artifacts
fi
if [ ! -L "/workspace/axolotl/outputs" ]; then
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
fi
# Execute the passed arguments (CMD) # Execute the passed arguments (CMD)
exec "$@" exec "$@"

View File

@@ -30,11 +30,8 @@ def parse_requirements():
try: try:
if "Darwin" in platform.system(): if "Darwin" in platform.system():
# don't install xformers on MacOS _install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
else: else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
torch_version = version("torch") torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}") _install_requires.append(f"torch=={torch_version}")
@@ -48,15 +45,9 @@ def parse_requirements():
else: else:
raise ValueError("Invalid version format") raise ValueError("Invalid version format")
if (major, minor) >= (2, 3): if (major, minor) >= (2, 1):
pass _install_requires.pop(_install_requires.index("xformers==0.0.22"))
elif (major, minor) >= (2, 2): _install_requires.append("xformers>=0.0.23")
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError: except PackageNotFoundError:
pass pass
@@ -68,7 +59,7 @@ install_requires, dependency_links = parse_requirements()
setup( setup(
name="axolotl", name="axolotl",
version="0.4.1", version="0.4.0",
description="LLM Trainer", description="LLM Trainer",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.", long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"}, package_dir={"": "src"},
@@ -77,13 +68,13 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn==2.5.8", "flash-attn==2.5.5",
], ],
"fused-dense-lib": [ "fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib", "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed==0.14.2", "deepspeed==0.13.1",
"deepspeed-kernels", "deepspeed-kernels",
], ],
"mamba-ssm": [ "mamba-ssm": [

View File

@@ -25,8 +25,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
load_in_8bit=False, load_in_8bit=False,
load_in_4bit=False, load_in_4bit=False,
flash_attention=False, flash_attention=False,
deepspeed=None,
fsdp=None,
**kwargs, **kwargs,
) )
@@ -42,7 +40,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cfg.flash_attention = False parsed_cfg.flash_attention = False
parsed_cfg.deepspeed = None parsed_cfg.deepspeed = None
parsed_cfg.fsdp = None parsed_cfg.fsdp = None
parsed_cfg.fsdp_config = None
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -19,10 +19,7 @@ from axolotl.cli import (
) )
from axolotl.common.cli import PreprocessCliArgs from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.sharegpt import ( from axolotl.prompt_strategies.sharegpt import register_chatml_template
register_chatml_template,
register_llama3_template,
)
LOG = logging.getLogger("axolotl.cli.preprocess") LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -39,22 +36,13 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return_remaining_strings=True return_remaining_strings=True
) )
if parsed_cfg.chat_template == "chatml": if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
if parsed_cfg.default_system_message: LOG.info(
LOG.info( f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}" )
) register_chatml_template(parsed_cfg.default_system_message)
register_chatml_template(parsed_cfg.default_system_message) else:
else: register_chatml_template()
register_chatml_template()
elif parsed_cfg.chat_template == "llama3":
if parsed_cfg.default_system_message:
LOG.info(
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
)
register_llama3_template(parsed_cfg.default_system_message)
else:
register_llama3_template()
if not parsed_cfg.dataset_prepared_path: if not parsed_cfg.dataset_prepared_path:
msg = ( msg = (

View File

@@ -19,10 +19,7 @@ from axolotl.cli import (
print_axolotl_text_art, print_axolotl_text_art,
) )
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.prompt_strategies.sharegpt import ( from axolotl.prompt_strategies.sharegpt import register_chatml_template
register_chatml_template,
register_llama3_template,
)
from axolotl.train import train from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train") LOG = logging.getLogger("axolotl.cli.train")
@@ -50,14 +47,6 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
else: else:
register_chatml_template() register_chatml_template()
if cfg.chat_template == "llama3" and cfg.default_system_message:
LOG.info(
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
)
register_llama3_template(cfg.default_system_message)
else:
register_llama3_template()
if cfg.rl: # and cfg.rl != "orpo": if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else: else:

199
src/axolotl/core/trainer_builder.py Executable file → Normal file
View File

@@ -30,7 +30,7 @@ from transformers import (
) )
from transformers.trainer_utils import seed_worker from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer from axolotl.loraplus import create_loraplus_optimizer
@@ -43,7 +43,7 @@ from axolotl.utils.callbacks import (
LossWatchDogCallback, LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SaveModelCallback, SaveModelOnTrainEndCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory, causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory, log_prediction_callback_factory,
@@ -91,12 +91,11 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
@dataclass @dataclass
class AxolotlTrainingMixins: class AxolotlTrainingArguments(TrainingArguments):
""" """
Mixin class for the Axolotl training args. Extend the base TrainingArguments for axolotl helpers
""" """
# pylint: disable=duplicate-code
model_type: Optional[str] = field( model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."} default=None, metadata={"help": "HF model configuration model_type."}
) )
@@ -126,22 +125,14 @@ class AxolotlTrainingMixins:
default=1.0, default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."}, metadata={"help": "Sample packing efficiency for calculating batch length."},
) )
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field( max_seq_length: int = field(
default=2048, default=2048,
metadata={"help": "The maximum sequence length the model can handle"}, metadata={"help": "The maximum sequence length the model can handle"},
) )
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
relora_steps: Optional[int] = field( relora_steps: Optional[int] = field(
default=None, default=None,
metadata={"help": "how often to reset for ReLoRA"}, metadata={"help": "how often to reset for ReLoRA"},
@@ -228,30 +219,6 @@ class AxolotlTrainingMixins:
) )
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
so it can't be used as a mixin.
"""
@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
"""
ORPO config for ORPO training
"""
@dataclass
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
"""
KTO config for KTO training
"""
class AxolotlTrainer(Trainer): class AxolotlTrainer(Trainer):
""" """
Extend the base Trainer for axolotl helpers Extend the base Trainer for axolotl helpers
@@ -379,12 +346,11 @@ class AxolotlTrainer(Trainer):
) )
return MultipackBatchSampler( return MultipackBatchSampler(
RandomSampler(self.train_dataset), RandomSampler(self.train_dataset),
lengths=get_dataset_lengths(self.train_dataset),
batch_max_len=batch_max_len,
batch_size=batch_size, batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True, drop_last=True,
batch_max_len=batch_max_len,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
) )
if self.args.curriculum_sampling: if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset) return SequentialSampler(self.train_dataset)
@@ -404,12 +370,11 @@ class AxolotlTrainer(Trainer):
) )
return MultipackBatchSampler( return MultipackBatchSampler(
SequentialSampler(eval_dataset), SequentialSampler(eval_dataset),
lengths=get_dataset_lengths(self.eval_dataset),
batch_max_len=batch_max_len,
batch_size=batch_size, batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True, drop_last=True,
batch_max_len=batch_max_len,
lengths=get_dataset_lengths(eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
) )
return super()._get_eval_sampler(eval_dataset) return super()._get_eval_sampler(eval_dataset)
@@ -833,40 +798,6 @@ class AxolotlDPOTrainer(DPOTrainer):
tag_names = ["axolotl", "dpo"] tag_names = ["axolotl", "dpo"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.optimizer = None
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub) @wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str: def push_to_hub(self, *args, **kwargs) -> str:
""" """
@@ -895,14 +826,6 @@ class AxolotlORPOTrainer(ORPOTrainer):
tag_names = ["axolotl", "orpo"] tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """
Base class for trainer builder Base class for trainer builder
@@ -1022,7 +945,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None: if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg)) callbacks.append(LossWatchDogCallback(self.cfg))
callbacks.append(SaveModelCallback()) callbacks.append(SaveModelOnTrainEndCallback())
return callbacks return callbacks
@@ -1148,6 +1071,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.save_safetensors is not None: if self.cfg.save_safetensors is not None:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est
if self.cfg.dataloader_pin_memory is not None: if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs[ training_arguments_kwargs[
"dataloader_pin_memory" "dataloader_pin_memory"
@@ -1195,8 +1123,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined # default to saving each epoch if not defined
training_arguments_kwargs["save_strategy"] = "epoch" training_arguments_kwargs["save_strategy"] = "epoch"
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
if self.cfg.do_bench_eval: if self.cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset: if self.cfg.bench_dataset:
@@ -1276,14 +1202,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
) )
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
report_to = [] report_to = None
if self.cfg.use_wandb: if self.cfg.use_wandb:
report_to.append("wandb") report_to = "wandb"
if self.cfg.use_mlflow: if self.cfg.use_mlflow:
report_to.append("mlflow") report_to = "mlflow"
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
training_arguments_kwargs["report_to"] = report_to training_arguments_kwargs["report_to"] = report_to
training_arguments_kwargs["run_name"] = ( training_arguments_kwargs["run_name"] = (
self.cfg.wandb_name if self.cfg.use_wandb else None self.cfg.wandb_name if self.cfg.use_wandb else None
@@ -1323,27 +1246,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["weight_decay"] = ( training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
) )
training_arguments_kwargs["sample_packing"] = (
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) self.cfg.sample_packing if self.cfg.sample_packing else False
training_arguments_kwargs[
"multipack_real_batches"
] = not self.cfg.flash_attention
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
) )
if self.cfg.sample_packing_bin_size is not None: training_arguments_kwargs["multipack_real_batches"] = (
training_arguments_kwargs[ self.cfg.flash_attention is not True
"sample_packing_bin_size" )
] = self.cfg.sample_packing_bin_size training_arguments_kwargs["eval_sample_packing"] = (
if self.cfg.sample_packing_group_size is not None: self.cfg.sample_packing
training_arguments_kwargs[ if self.cfg.eval_sample_packing is not False
"sample_packing_group_size" else False
] = self.cfg.sample_packing_group_size )
if self.cfg.sample_packing_eff_est: training_arguments_kwargs[
training_arguments_kwargs[ "sample_packing_seq_len_multiplier"
"sample_packing_efficiency" ] = self.cfg.micro_batch_size
] = self.cfg.sample_packing_eff_est
if self.cfg.relora_steps: if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs[ training_arguments_kwargs[
@@ -1513,7 +1429,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self): def get_callbacks(self):
callbacks = super().get_callbacks() callbacks = super().get_callbacks()
callbacks.append(SaveModelCallback()) callbacks.append(SaveModelOnTrainEndCallback())
return callbacks return callbacks
@@ -1554,8 +1470,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.bf16 or self.cfg.bfloat16: if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True training_args_kwargs["bf16"] = True
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
training_args_kwargs["lr_scheduler_type"] = ( training_args_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
) )
@@ -1608,36 +1522,20 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# trl does some odd mapping of alpha to beta to reuse the beta parameter ??? # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_cls = AxolotlTrainingArguments training_args_cls = TrainingArguments
if self.cfg.rl == "orpo": if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig training_args_cls = ORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_kwargs["max_length"] = self.cfg.sequence_len elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
if self.cfg.max_prompt_len: training_args_cls = DPOConfig
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
if self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
)
training_args_kwargs["undesirable_weight"] = (
self.cfg.kto_undesirable_weight or 1.0
)
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg training_args = training_args_cls(
output_dir=self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size, per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps, max_steps=self.cfg.max_steps or total_num_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate, learning_rate=self.cfg.learning_rate,
output_dir=self.cfg.output_dir,
warmup_steps=self.cfg.warmup_steps, warmup_steps=self.cfg.warmup_steps,
logging_first_step=True, logging_first_step=True,
logging_steps=1, logging_steps=1,
@@ -1655,8 +1553,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["loss_type"] = "ipo" dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing: if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
elif self.cfg.rl == "kto_pair": elif self.cfg.rl in ["kto_pair", "sppo_hard", "nca_pair"]:
dpo_trainer_kwargs["loss_type"] = "kto_pair" dpo_trainer_kwargs["loss_type"] = self.cfg.rl
if self.eval_dataset: if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config: if self.cfg.adapter and self.peft_config:
@@ -1665,9 +1563,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[ dpo_trainer_kwargs[
"precompute_ref_log_probs" "precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs ] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]: if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
trainer_cls = AxolotlDPOTrainer trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1 dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
trainer_cls_args = [self.model, self.model_ref] trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer # these aren't used for the ORPO trainer
@@ -1680,9 +1578,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl == "orpo": elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
elif self.cfg.rl == "kto":
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
else: else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}") raise ValueError(f"Unsupported RL: {self.cfg.rl}")
dpo_trainer = trainer_cls( dpo_trainer = trainer_cls(

View File

@@ -123,17 +123,6 @@ def get_turns( # pylint: disable=too-many-return-statements
else: else:
yield role, "" yield role, ""
return return
if self.sep_style == SeparatorStyle.LLAMA3:
if self.system_message:
# For llama3, the system message is NOT incorporated into the first human instruction
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
else:
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
return
if self.sep_style == SeparatorStyle.GEMMA: if self.sep_style == SeparatorStyle.GEMMA:
if self.system_message: if self.system_message:
raise ValueError("Gemma chat template does not support system messages") raise ValueError("Gemma chat template does not support system messages")

View File

@@ -42,9 +42,9 @@ def patch_mixtral_moe_forward_zero3() -> None:
return final_hidden_states, router_logits return final_hidden_states, router_logits
from transformers.models.mixtral.modeling_mixtral import ( from transformers.models.mixtral.modeling_mixtral import (
MixtralBlockSparseTop2MLP, MixtralBLockSparseTop2MLP,
MixtralSparseMoeBlock, MixtralSparseMoeBlock,
) )
MixtralBlockSparseTop2MLP.forward = mlp_forward MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward MixtralSparseMoeBlock.forward = moe_forward

View File

@@ -1,267 +0,0 @@
"""module for patching with unsloth optimizations"""
import inspect
import logging
import re
import types
from typing import Tuple
from peft import PeftModelForCausalLM
from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2,
LlamaForCausalLM,
)
LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """ if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
"""
PATCHED_CEL_CODE = """ if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
"""
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
""".lstrip(
"\n"
)
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
""".lstrip(
"\n"
)
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
"\n"
)
PATCHED_O_CODE = """
attn_output = self.apply_o(self, attn_output)
""".lstrip(
"\n"
)
def original_apply_qkv(self, hidden_states):
    """Stock (unpatched) QKV projection: run the three linear projections
    and return (query, key, value) states."""
    return (
        self.q_proj(hidden_states),
        self.k_proj(hidden_states),
        self.v_proj(hidden_states),
    )
def original_apply_o(self, hidden_states):
    """Stock (unpatched) attention output projection."""
    return self.o_proj(hidden_states)
def get_forward_code() -> str:
    """Return the current source text of ``LlamaForCausalLM.forward`` so the
    patch routines can verify and rewrite it."""
    return inspect.getsource(LlamaForCausalLM.forward)
def test_cel_is_patchable() -> bool:
    """True when the upstream forward still contains the cross-entropy-loss
    block this module replaces."""
    return ORIGINAL_CEL_CODE in get_forward_code()
def get_self_attn_code() -> str:
    """Return the current source text of ``LlamaFlashAttention2.forward``."""
    return inspect.getsource(LlamaFlashAttention2.forward)
def test_self_attn_is_patchable() -> bool:
    """True when the attention forward still contains BOTH spans we patch.

    Bug fix: the original tested ``ORIGINAL_QKV_CODE`` twice and never
    checked ``ORIGINAL_O_CODE``, so an upstream change to the o_proj call
    site would pass this check and then fail the assert inside
    ``patch_self_attn_lora``.
    """
    code = get_self_attn_code()
    return ORIGINAL_QKV_CODE in code and ORIGINAL_O_CODE in code
def integrate_cross_entropy_loss_patch():
    """Rewrite ``LlamaForCausalLM.forward`` at runtime so its cross-entropy
    loss computation is replaced with unsloth's fused
    ``fast_cross_entropy_loss`` kernel.

    Works by fetching the forward's source, text-replacing the loss block,
    exec-ing the result into this module's globals, and rebinding the class
    attribute. No-ops are not supported: asserts if upstream source drifted.
    """
    forward = get_forward_code()
    # Keep the unmodified source around for inspection/debugging.
    LlamaForCausalLM._original_forward = forward  # pylint: disable=protected-access
    forward, _ = detab_code(forward)
    assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
    # Strip decorators that reference names not available once exec'd here.
    forward = forward.replace(
        "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
    )
    forward = forward.replace(
        "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
        "",
    )
    # Swap the loss block and rename so the exec'd def doesn't shadow anything.
    forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
    forward = forward.replace(
        "def forward(",
        "def fast_cross_entropy_loss_forward(",
        1,
    )
    # load imports necessary: pull in every modeling_llama symbol the
    # rewritten source references so exec() below can resolve them.
    import transformers.models.llama.modeling_llama
    items_to_import = []
    for item in dir(transformers.models.llama.modeling_llama):
        if item in forward:
            items_to_import.append(item)
    exec(  # pylint: disable=exec-used  # nosec B102
        "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
        globals(),
    )
    exec(  # pylint: disable=exec-used  # nosec B102
        "from transformers.models.llama.modeling_llama import ("
        + ", ".join(x for x in items_to_import)
        + ")",
        globals(),
    )
    # Define fast_cross_entropy_loss_forward in this module's globals.
    exec(forward, globals())  # pylint: disable=exec-used  # nosec B102
    print("patching unsloth fast_cross_entropy_loss")
    LlamaForCausalLM.forward = fast_cross_entropy_loss_forward  # pylint: disable=undefined-variable  # noqa: F821
def detab_code(code: str) -> Tuple[str, str]:
    """Strip the first line's leading indentation from every line of *code*.

    Returns ``(dedented_code, removed_indent)``.

    Robustness fix: the original called ``.group(0)`` directly on
    ``re.match(...)``, which raised ``AttributeError`` when *code* had no
    leading whitespace; now such input is returned unchanged with an empty
    indent string.
    """
    match = re.match(r"([\s\t]{1,})", code)
    if match is None:
        # Already unindented (e.g. module-level source) — nothing to strip.
        return code, ""
    spaces = match.group(0)
    # Remove exactly that indent at the start of every line.
    code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
    return code, spaces
def patch_self_attn_lora():
    """Rewrite ``LlamaFlashAttention2.forward`` so the direct q/k/v and
    o_proj calls go through pluggable ``self.apply_qkv`` / ``self.apply_o``
    hooks, which ``integrate_lora_patch`` can later point at unsloth's fused
    LoRA kernels per layer."""
    self_attn_forward = get_self_attn_code()
    # Keep the unmodified source around for inspection/debugging.
    LlamaFlashAttention2._original_forward = (  # pylint: disable=protected-access
        self_attn_forward
    )
    self_attn_forward, _ = detab_code(self_attn_forward)
    assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
    assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"
    # Swap direct projection calls for the hook indirection, and rename the
    # function so the exec'd def doesn't shadow the real forward.
    self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
    self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
    self_attn_forward = self_attn_forward.replace(
        "def forward(",
        "def unsloth_attn_forward(",
        1,
    )
    # load imports necessary: every modeling_llama symbol the rewritten
    # source references must be importable when exec'd into our globals.
    import transformers.models.llama.modeling_llama
    items_to_import = []
    for item in dir(transformers.models.llama.modeling_llama):
        if item in self_attn_forward:
            items_to_import.append(item)
    exec(  # pylint: disable=exec-used  # nosec B102
        "from transformers.models.llama.modeling_llama import ("
        + ", ".join(x for x in items_to_import)
        + ")",
        globals(),
    )
    # Define unsloth_attn_forward in this module's globals, then rebind.
    exec(self_attn_forward, globals())  # pylint: disable=exec-used  # nosec B102
    print("patching unsloth attn lora")
    LlamaFlashAttention2.forward = (
        unsloth_attn_forward  # pylint: disable=undefined-variable  # noqa: F821
    )
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
    """Replace each decoder layer's MLP ``forward`` with unsloth's fused LoRA
    kernel (SwiGLU for llama/mistral, approximate GeGLU for gemma).

    A layer is only patched when all three MLP projections carry plain LoRA
    adapters with no bias and no DoRA magnitude vectors; otherwise a warning
    is logged and the layer keeps its original forward.
    """
    if peft_model.base_model.config.model_type in ["llama", "mistral"]:
        from unsloth.kernels import apply_lora_mlp_swiglu
        apply_lora_mlp = apply_lora_mlp_swiglu
    elif peft_model.base_model.config.model_type == "gemma":
        from unsloth.kernels import apply_lora_mlp_geglu_approx
        apply_lora_mlp = apply_lora_mlp_geglu_approx
    else:
        raise NotImplementedError(
            f"Model type {peft_model.base_model.config.model_type} not supported"
        )
    for idx, layer in enumerate(peft_model.model.model.layers):
        layer_modules = [
            getattr(layer.mlp, linear_proj)
            for linear_proj in ["gate_proj", "up_proj", "down_proj"]
        ]
        # All three projections must have LoRA adapters attached...
        is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
        # ...with no bias on the underlying Linear...
        mlp_no_bias = all(
            getattr(module, "base_layer", module).bias is None
            for module in layer_modules
        )
        # ...and no DoRA magnitude vector (the fused kernel ignores it).
        mlp_not_dora = all(
            getattr(module, "lora_magnitude_vector", None) is None
            for module in layer_modules
        )
        if is_mlp_lora and mlp_no_bias and mlp_not_dora:
            # Bind the kernel as this layer's forward method.
            layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
        else:
            logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
    """Point each layer's ``apply_qkv`` / ``apply_o`` hooks (installed by
    ``patch_self_attn_lora``) at unsloth's fused LoRA kernels where possible.

    Per layer and per cfg flag (``unsloth_lora_qkv`` / ``unsloth_lora_o``):
    use the fused kernel when the relevant projections are plain bias-free
    LoRA without DoRA; otherwise fall back to the original projection
    functions and log a warning.
    """
    from unsloth.kernels import apply_lora_o, apply_lora_qkv
    for idx, layer in enumerate(peft_model.model.model.layers):
        if cfg.unsloth_lora_qkv:
            layer_modules = [
                getattr(layer.self_attn, linear_proj)
                for linear_proj in ["q_proj", "k_proj", "v_proj"]
            ]
            # q/k/v must all be LoRA, bias-free, and not DoRA.
            is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
            qkv_no_bias = all(
                getattr(module, "base_layer", module).bias is None
                for module in layer_modules
            )
            qkv_not_dora = all(
                getattr(module, "lora_magnitude_vector", None) is None
                for module in layer_modules
            )
            if is_qkv_lora and qkv_no_bias and qkv_not_dora:
                layer.self_attn.apply_qkv = apply_lora_qkv
            else:
                # Fall back to the stock projection path for this layer.
                layer.self_attn.apply_qkv = original_apply_qkv
                logging.warning(
                    "unable to apply unsloth lora qkv patch to layer %d", idx
                )
        if cfg.unsloth_lora_o:
            layer_modules = [
                getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
            ]
            # Same eligibility rules for the output projection.
            is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
            o_no_bias = all(
                getattr(module, "base_layer", module).bias is None
                for module in layer_modules
            )
            o_not_dora = all(
                getattr(module, "lora_magnitude_vector", None) is None
                for module in layer_modules
            )
            if is_o_lora and o_no_bias and o_not_dora:
                layer.self_attn.apply_o = apply_lora_o
            else:
                layer.self_attn.apply_o = original_apply_o
                logging.warning(
                    "unable to apply unsloth lora o_proj patch to layer %d", idx
                )

View File

@@ -1,56 +1,24 @@
""" """
HF Chat Templates prompt strategy HF Chat Templates prompt strategy
""" """
from typing import Any, Dict, Optional
import logging
from typing import Any, Dict, List, Optional
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates from axolotl.utils.chat_templates import chat_templates
LOG = logging.getLogger("axolotl")
class ChatTemplatePrompter(Prompter): class ChatTemplatePrompter(Prompter):
"""prompter for HF chat templates""" """prompter for HF chat templates"""
def __init__( def __init__(self, tokenizer, chat_template=None, max_length=2048):
self,
tokenizer,
chat_template=None,
max_length=2048,
message_field_role: str = "from",
message_field_content: str = "value",
roles: Optional[Dict[str, List[str]]] = None,
):
if roles:
self.roles = {s: t for t, sources in roles.items() for s in sources}
else:
self.roles = {
"human": "user",
"user": "user",
"assistant": "assistant",
"gpt": "assistant",
"system": "system",
}
self.message_field_role = message_field_role
self.message_field_content = message_field_content
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.chat_template = chat_template self.chat_template = chat_template
self.max_length = max_length self.max_length = max_length
def build_prompt(self, conversation, add_generation_prompt=False): def build_prompt(self, conversation, add_generation_prompt=False):
turns = [
{
"role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content],
}
for t in conversation
]
return self.tokenizer.apply_chat_template( return self.tokenizer.apply_chat_template(
turns, conversation,
truncation=True, truncation=True,
max_length=self.max_length, max_length=self.max_length,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
@@ -63,19 +31,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
Tokenizing strategy for instruction-based prompts. Tokenizing strategy for instruction-based prompts.
""" """
_messages = "conversations"
@property
def messages(self):
return self._messages
@messages.setter
def messages(self, messages):
self._messages = messages
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
turns = self.get_conversation_thread(prompt) turns = self.get_conversation_thread(prompt)
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True) prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
input_ids = self.prompter.build_prompt(turns) input_ids = self.prompter.build_prompt(turns)
if not self.train_on_inputs: if not self.train_on_inputs:
@@ -93,37 +51,28 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt return tokenized_prompt
def get_conversation_thread(self, prompt): def get_conversation_thread(self, prompt):
return prompt[self.messages] conversations = prompt["conversations"]
# remap roles - allow for assistant turn
role_map = {
"human": "user",
"user": "user",
"assistant": "assistant",
"gpt": "assistant",
}
turns = [
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
]
return turns
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
chat_template = ( chat_template = (
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml" ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
) )
message_field_role = (
ds_cfg["message_field_role"]
if ds_cfg and "message_field_role" in ds_cfg
else "from"
)
message_field_content = (
ds_cfg["message_field_content"]
if ds_cfg and "message_field_content" in ds_cfg
else "value"
)
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
tokenizer,
chat_templates(chat_template),
message_field_role=message_field_role,
message_field_content=message_field_content,
roles=roles,
),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy return strategy

View File

@@ -1,133 +0,0 @@
"""
DPO strategies for llama-3 chat template
"""
def argilla(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """llama-3 DPO transform for argilla-style rows
    (``instruction`` / ``chosen_response`` / ``rejected_response``)."""

    def transform_fn(sample):
        turn = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + sample["instruction"]
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        system = sample.get("system")
        if system:
            turn = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                + system
                + "<|eot_id|>"
                + turn
            )
        sample["prompt"] = turn
        sample["chosen"] = sample["chosen_response"] + "<|eot_id|>"
        sample["rejected"] = sample["rejected_response"] + "<|eot_id|>"
        return sample

    return transform_fn
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/dpo-mix-7k conversations (single user turn; the chosen and
    rejected answers are the second message of each conversation)
    """

    def transform_fn(sample):
        opening = sample["chosen"][0]["content"]
        sample["prompt"] = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + opening
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        sample["chosen"] = sample["chosen"][1]["content"] + "<|eot_id|>"
        sample["rejected"] = sample["rejected"][1]["content"] + "<|eot_id|>"
        return sample

    return transform_fn
def icr(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    llama-3 transforms for datasets with system, input, chosen, rejected
    ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
    """

    def transform_fn(sample):
        turn = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + sample["input"]
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        system = sample.get("system")
        if system:
            turn = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                + system
                + "<|eot_id|>"
                + turn
            )
        sample["prompt"] = turn
        sample["chosen"] = sample["chosen"] + "<|eot_id|>"
        sample["rejected"] = sample["rejected"] + "<|eot_id|>"
        return sample

    return transform_fn
def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    For Intel Orca DPO Pairs (``question`` / ``chosen`` / ``rejected``)
    """

    def transform_fn(sample):
        turn = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + sample["question"]
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        system = sample.get("system")
        if system:
            turn = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                + system
                + "<|eot_id|>"
                + turn
            )
        sample["prompt"] = turn
        sample["chosen"] = sample["chosen"] + "<|eot_id|>"
        sample["rejected"] = sample["rejected"] + "<|eot_id|>"
        return sample

    return transform_fn
def prompt_pairs(
    cfg, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """llama-3 transform for rows already shaped as prompt/chosen/rejected."""

    def transform_fn(sample):
        turn = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + sample["prompt"]
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        system = sample.get("system")
        if system:
            turn = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                + system
                + "<|eot_id|>"
                + turn
            )
        sample["prompt"] = turn
        sample["chosen"] = sample["chosen"] + "<|eot_id|>"
        sample["rejected"] = sample["rejected"] + "<|eot_id|>"
        return sample

    return transform_fn
def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for ultrafeedback binarized conversations (chosen/rejected are
    conversations; the answers are each conversation's second message)
    """

    def transform_fn(sample):
        turn = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + sample["prompt"]
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        system = sample.get("system")
        if system:
            turn = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                + system
                + "<|eot_id|>"
                + turn
            )
        sample["prompt"] = turn
        sample["chosen"] = sample["chosen"][1]["content"] + "<|eot_id|>"
        sample["rejected"] = sample["rejected"][1]["content"] + "<|eot_id|>"
        return sample

    return transform_fn

View File

@@ -0,0 +1,30 @@
"""
DPO strategies for mistral instruct
"""
def prompt_pairs(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
    """Mistral-instruct DPO transform: wrap the prompt in [INST] tags;
    chosen/rejected pass through as strings."""

    def transform_fn(sample):
        sample["prompt"] = "[INST]" + sample["prompt"] + "[/INST]"
        sample["chosen"] = str(sample["chosen"])
        sample["rejected"] = str(sample["rejected"])
        return sample

    return transform_fn
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/dpo-mix-7k conversations, mistral-instruct format
    """

    def transform_fn(sample):
        opening = sample["chosen"][0]["content"]
        sample["prompt"] = f"[INST] {opening} [/INST]"
        sample["chosen"] = sample["chosen"][1]["content"] + "</s>"
        sample["rejected"] = sample["rejected"][1]["content"] + "</s>"
        return sample

    return transform_fn

View File

@@ -1,9 +0,0 @@
"""
module for KTO style dataset transform strategies
"""
from functools import partial
from ..base import load as load_base
load = partial(load_base, module_base="axolotl.prompt_strategies.kto")

View File

@@ -1,105 +0,0 @@
"""
KTO strategies for chatml
"""
# pylint: disable=duplicate-code
def argilla(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """chatml KTO transform for argilla rows (``instruction``/``completion``)."""

    def transform_fn(sample):
        turn = (
            "<|im_start|>user\n"
            + sample["instruction"]
            + "<|im_end|>\n<|im_start|>assistant\n"
        )
        system = sample.get("system")
        if system:
            turn = "<|im_start|>system\n" + system + "<|im_end|>\n" + turn
        sample["prompt"] = turn
        sample["completion"] = sample["completion"] + "<|im_end|>"
        return sample

    return transform_fn
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/kto-mix-15k conversations (chatml)
    """

    def transform_fn(sample):
        opening = sample["chosen"][0]["content"]
        sample["prompt"] = (
            "<|im_start|>user\n" + opening + "<|im_end|>\n<|im_start|>assistant\n"
        )
        sample["completion"] = sample["completion"][1]["content"] + "<|im_end|>"
        return sample

    return transform_fn
def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    For Intel Orca KTO
    ex: argilla/distilabel-intel-orca-kto
    """

    def transform_fn(sample):
        turn = (
            "<|im_start|>user\n"
            + sample["question"]
            + "<|im_end|>\n<|im_start|>assistant\n"
        )
        system = sample.get("system")
        if system:
            turn = "<|im_start|>system\n" + system + "<|im_end|>\n" + turn
        sample["prompt"] = turn
        sample["completion"] = sample["completion"] + "<|im_end|>"
        return sample

    return transform_fn
def prompt_pairs(
    cfg, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """chatml KTO transform for rows shaped as prompt/completion."""

    def transform_fn(sample):
        turn = (
            "<|im_start|>user\n"
            + sample["prompt"]
            + "<|im_end|>\n<|im_start|>assistant\n"
        )
        system = sample.get("system")
        if system:
            turn = "<|im_start|>system\n" + system + "<|im_end|>\n" + turn
        sample["prompt"] = turn
        sample["completion"] = sample["completion"] + "<|im_end|>"
        return sample

    return transform_fn
def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for ultrafeedback binarized conversations
    ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    """

    def transform_fn(sample):
        turn = (
            "<|im_start|>user\n"
            + sample["prompt"]
            + "<|im_end|>\n<|im_start|>assistant\n"
        )
        system = sample.get("system")
        if system:
            turn = "<|im_start|>system\n" + system + "<|im_end|>\n" + turn
        sample["prompt"] = turn
        sample["completion"] = sample["completion"] + "<|im_end|>"
        return sample

    return transform_fn

View File

@@ -1,105 +0,0 @@
"""
KTO strategies for llama-3 chat template
"""
# pylint: disable=duplicate-code
def argilla(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """llama-3 KTO transform for argilla rows (``instruction``/``completion``)."""

    def transform_fn(sample):
        turn = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + sample["instruction"]
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        system = sample.get("system")
        if system:
            turn = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                + system
                + "<|eot_id|>"
                + turn
            )
        sample["prompt"] = turn
        sample["completion"] = sample["completion"] + "<|eot_id|>"
        return sample

    return transform_fn
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/kto-mix-15k conversations (llama-3); the prompt is the
    first message of the completion conversation
    """

    def transform_fn(sample):
        opening = sample["completion"][0]["content"]
        sample["prompt"] = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + opening
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        sample["completion"] = sample["completion"][1]["content"] + "<|eot_id|>"
        return sample

    return transform_fn
def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    For Intel Orca KTO
    ex: argilla/distilabel-intel-orca-kto
    """

    def transform_fn(sample):
        turn = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + sample["question"]
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        system = sample.get("system")
        if system:
            turn = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                + system
                + "<|eot_id|>"
                + turn
            )
        sample["prompt"] = turn
        sample["completion"] = sample["completion"] + "<|eot_id|>"
        return sample

    return transform_fn
def prompt_pairs(
    cfg, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """llama-3 KTO transform for rows shaped as prompt/completion."""

    def transform_fn(sample):
        turn = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + sample["prompt"]
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        system = sample.get("system")
        if system:
            turn = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                + system
                + "<|eot_id|>"
                + turn
            )
        sample["prompt"] = turn
        sample["completion"] = sample["completion"] + "<|eot_id|>"
        return sample

    return transform_fn
def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for ultrafeedback binarized conversations
    ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    """

    def transform_fn(sample):
        turn = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + sample["prompt"]
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        system = sample.get("system")
        if system:
            turn = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                + system
                + "<|eot_id|>"
                + turn
            )
        sample["prompt"] = turn
        sample["completion"] = sample["completion"] + "<|eot_id|>"
        return sample

    return transform_fn

View File

@@ -1,39 +0,0 @@
"""
User-defined KTO strategies
"""
# pylint: disable=duplicate-code
def default(cfg, dataset_idx=0, **kwargs):  # pylint: disable=unused-argument
    """Build a KTO transform from a user-defined dataset ``type`` mapping.

    Reads field names and optional format strings from
    ``cfg["datasets"][dataset_idx]["type"]`` and returns a ``transform_fn``
    producing ``prompt`` / ``completion`` / ``label``.

    Raises:
        ValueError: when the dataset ``type`` entry is not a dict.

    Fixes vs the original:
    - a user-supplied ``completion_format`` was ignored and left
      ``chosen_format`` undefined (NameError at transform time);
    - the default completion template used the field name as placeholder
      but was filled via the ``chosen=`` keyword (KeyError);
    - the non-system branch read ``sample["prompt"]`` instead of the
      configured ``field_prompt``;
    - default templates now use the ``prompt``/``completion`` placeholder
      names that ``str.format`` is actually called with, so custom field
      names work.
    """
    ds_cfg = cfg["datasets"][dataset_idx]["type"]
    if not isinstance(ds_cfg, dict):
        raise ValueError(
            f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
        )
    field_prompt = ds_cfg.get("field_prompt", "prompt")
    field_system = ds_cfg.get("field_system", "system")
    field_completion = ds_cfg.get("field_completion", "completion")
    field_label = ds_cfg.get("field_label", "label")
    prompt_format = ds_cfg.get("prompt_format")
    if not prompt_format:
        # Placeholder names match the keywords passed to .format() below.
        prompt_format = "{prompt}"
    completion_format = ds_cfg.get("completion_format")
    if not completion_format:
        completion_format = "{completion}"

    def transform_fn(sample):
        # Only substitute the system message when the template asks for it
        # and the sample actually carries a non-empty one.
        if (
            "{system}" in prompt_format
            and field_system in sample
            and sample[field_system]
        ):
            sample["prompt"] = prompt_format.format(
                system=sample[field_system], prompt=sample[field_prompt]
            )
        else:
            sample["prompt"] = prompt_format.format(prompt=sample[field_prompt])
        sample["completion"] = completion_format.format(
            completion=sample[field_completion]
        )
        sample["label"] = sample[field_label]
        return sample

    return transform_fn

View File

@@ -1,7 +1,7 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging import logging
from typing import Any, Dict, Optional, Type from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
name="chatml", name="chatml",
system_template="<|im_start|>system\n{system_message}", system_template="<|im_start|>system\n{system_message}",
system_message=system_message, system_message=system_message,
roles=("<|im_start|>user", "<|im_start|>assistant"), roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML, sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>", sep="<|im_end|>",
) )
@@ -32,65 +32,83 @@ def register_chatml_template(system_message=None):
name="chatml_glaive", name="chatml_glaive",
system_template="<|im_start|>system\n{system_message}", system_template="<|im_start|>system\n{system_message}",
system_message=system_message, system_message=system_message,
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"), roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
sep_style=SeparatorStyle.CHATML, sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>", sep="<|im_end|>",
) )
) )
def register_llama3_template(system_message=None): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
system_message = system_message or "You are a helpful assistant." conversation = (
register_conv_template( ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
Conversation( )
name="llama3", field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
system_message=system_message, roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
roles=("user", "assistant"), strategy = SimpleShareGPTPromptTokenizingStrategy(
sep_style=SeparatorStyle.LLAMA3, ShareGPTPrompterV2(
sep="", conversation=conversation,
stop_str="<|eot_id|>", role_key_model=field_model,
stop_token_ids=[128001, 128009], role_key_human=field_human,
) roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = UltrachatShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
def build_loader( def load_guanaco(tokenizer, cfg):
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"], return GuanacoShareGPTPromptTokenizingStrategy(
prompter_cls: Type["ShareGPTPrompterV2"], ShareGPTPrompterV2(),
default_conversation: Optional[str] = None, tokenizer,
): cfg.train_on_inputs,
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): cfg.sequence_len,
conversation = ( )
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else default_conversation
)
field_human = (
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
)
field_model = (
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
)
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = tokenization_strategy_cls(
prompter_cls(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
strategy.strict = ds_cfg["strict"]
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy
return _load
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else "chatml_glaive"
)
return GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation=conversation),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -99,7 +117,6 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
""" """
_strict = False _strict = False
_messages = "conversations"
@property @property
def strict(self): def strict(self):
@@ -109,16 +126,8 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
def strict(self, strict): def strict(self, strict):
self._strict = strict self._strict = strict
@property
def messages(self):
return self._messages
@messages.setter
def messages(self, messages):
self._messages = messages
def get_conversation_thread(self, prompt): def get_conversation_thread(self, prompt):
conversations = prompt[self.messages] conversations = prompt["conversations"]
if self.strict: if self.strict:
return conversations return conversations
role_key = "from" role_key = "from"
@@ -149,9 +158,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
return turns return turns
class SimpleRoleShareGPTPromptTokenizingStrategy( class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
SimpleShareGPTPromptTokenizingStrategy
):
""" """
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
""" """
@@ -202,16 +209,3 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
conversation = merge_consecutive_messages(conversation) conversation = merge_consecutive_messages(conversation)
return conversation return conversation
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_ultrachat = build_loader(
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
)
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_glaive = build_loader(
GlaiveShareGPTPromptTokenizingStrategy,
ShareGPTPrompterV2,
default_conversation="chatml_glaive",
)

View File

@@ -263,7 +263,6 @@ CONVERSATION_ROLE_FORMAT = {
"chatml": "<|im_start|>{ROLE}", "chatml": "<|im_start|>{ROLE}",
"zephyr": "<|{ROLE}|>", "zephyr": "<|{ROLE}|>",
"vicuna_v1.1": "{ROLE}", "vicuna_v1.1": "{ROLE}",
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
} }
@@ -349,10 +348,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
) )
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
if ( LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
role != "assistant"
): # back to back assistant calls may be okay for tool calls
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])

View File

@@ -197,13 +197,6 @@ def train(
trainer.accelerator.wait_for_everyone() trainer.accelerator.wait_for_everyone()
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped) unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
# the trainer saved a model.safetensors file in the output directory,
# but it is a proxy model and should be deleted
if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
LOG.info("This is a proxy model and should be deleted")
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin. # `zero3_save_16bit_model` is True in DeepSpeed Plugin.
@@ -219,10 +212,6 @@ def train(
if cfg.flash_optimum and BetterTransformer: if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id: if not cfg.hub_model_id:

View File

@@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import math
import os import os
from shutil import copyfile from shutil import copyfile
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
@@ -776,27 +775,9 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
return control return control
class SaveModelCallback(TrainerCallback): class SaveModelOnTrainEndCallback(TrainerCallback):
"""Callback to save model on train end""" """Callback to save model on train end"""
def on_step_end( # pylint: disable=unused-argument
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
elif (
args.save_strategy == IntervalStrategy.STEPS
and state.save_steps < 1.0
and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
):
# workaround to save model on fractional save_steps
control.should_save = True
def on_train_end( # pylint: disable=unused-argument def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs self, args, state, control, **kwargs
): ):

Some files were not shown because too many files have changed in this diff Show More