Compare commits
39 Commits
custom-tra...mora
| Author | SHA1 | Date |
|---|---|---|
| | d7ec10e337 | |
| | 05b0bd08d2 | |
| | d4f6c65e4c | |
| | a944f7b32b | |
| | 9d4225a058 | |
| | f7332ac449 | |
| | 16d46b74e4 | |
| | a6b37bdeb4 | |
| | b7520801a3 | |
| | fe650dd326 | |
| | 49b967b62f | |
| | 65db903714 | |
| | 6a5a725f10 | |
| | f5febc729a | |
| | 230e0ac363 | |
| | cc11c6bce2 | |
| | 5f91064040 | |
| | ef223519c9 | |
| | 8a20a7b711 | |
| | 367b2e879b | |
| | bbfed318bc | |
| | 84bb8061ba | |
| | a27d5e1f4e | |
| | 6299eb5919 | |
| | 7c2bf3091f | |
| | 22ae21a6c2 | |
| | ba45531802 | |
| | 8a1572a831 | |
| | 702a669cad | |
| | 891ae8aa13 | |
| | 0c49ecc429 | |
| | 60113437e4 | |
| | 419b2a6a98 | |
| | 2501a371c6 | |
| | e6937e884b | |
| | 039e2a0370 | |
| | 4fde300e5f | |
| | 3319780300 | |
| | 81da7d2531 | |
.github/workflows/base.yml (vendored, 2 changes)

@@ -30,7 +30,7 @@ jobs:
        - cuda: "121"
          cuda_version: 12.1.0
          python_version: "3.11"
-         pytorch: 2.2.1
+         pytorch: 2.2.2
          torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
        - cuda: "121"
          cuda_version: 12.1.0
.github/workflows/main.yml (vendored, 46 changes)

@@ -28,7 +28,7 @@ jobs:
        - cuda: 121
          cuda_version: 12.1.0
          python_version: "3.11"
-         pytorch: 2.2.1
+         pytorch: 2.2.2
          axolotl_extras:
        - cuda: 121
          cuda_version: 12.1.0

@@ -89,7 +89,7 @@ jobs:
        - cuda: 121
          cuda_version: 12.1.0
          python_version: "3.11"
-         pytorch: 2.2.1
+         pytorch: 2.2.2
          axolotl_extras:
        - cuda: 121
          cuda_version: 12.1.0

@@ -125,3 +125,45 @@ jobs:
            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
          labels: ${{ steps.metadata.outputs.labels }}
+
+  build-axolotl-cloud-no-tmux:
+    needs: build-axolotl
+    if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
+    # this job needs to be run on self-hosted GPU runners...
+    strategy:
+      matrix:
+        include:
+          - cuda: 121
+            cuda_version: 12.1.0
+            python_version: "3.11"
+            pytorch: 2.3.0
+            axolotl_extras:
+    runs-on: axolotl-gpu-runner
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Docker metadata
+        id: metadata
+        uses: docker/metadata-action@v5
+        with:
+          images: winglian/axolotl-cloud-term
+      - name: Login to Docker Hub
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v2
+      - name: Build
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          build-args: |
+            BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            CUDA=${{ matrix.cuda }}
+          file: ./docker/Dockerfile-cloud-no-tmux
+          push: ${{ github.event_name != 'pull_request' }}
+          tags: |
+            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
+          labels: ${{ steps.metadata.outputs.labels }}
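Once this job pushes, the image should be pullable by tag. A rough sketch, assuming a build from the `main` branch; the exact tag depends on the metadata action's output and is an assumption here:

```bash
# Hypothetical tag following the workflow's <branch>-py<python>-cu<cuda>-<pytorch> pattern
docker pull winglian/axolotl-cloud-term:main-py3.11-cu121-2.3.0
```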
.github/workflows/nightlies.yml (vendored, 4 changes)

@@ -27,7 +27,7 @@ jobs:
        - cuda: 121
          cuda_version: 12.1.0
          python_version: "3.11"
-         pytorch: 2.2.1
+         pytorch: 2.2.2
          axolotl_extras:
        - cuda: 121
          cuda_version: 12.1.0

@@ -89,7 +89,7 @@ jobs:
        - cuda: 121
          cuda_version: 12.1.0
          python_version: "3.11"
-         pytorch: 2.2.1
+         pytorch: 2.2.2
          axolotl_extras:
        - cuda: 121
          cuda_version: 12.1.0
.github/workflows/tests.yml (vendored, 7 changes)

@@ -82,7 +82,12 @@ jobs:
        - cuda: 121
          cuda_version: 12.1.0
          python_version: "3.11"
-         pytorch: 2.2.1
+         pytorch: 2.2.2
+         num_gpus: 1
+       - cuda: 121
+         cuda_version: 12.1.0
+         python_version: "3.11"
+         pytorch: 2.3.0
          num_gpus: 1
     steps:
       - name: Checkout
@@ -124,11 +124,11 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml

 # inference
 accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./lora-out"
+    --lora_model_dir="./outputs/lora-out"

 # gradio
 accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./lora-out" --gradio
+    --lora_model_dir="./outputs/lora-out" --gradio

 # remote yaml files - the yaml config can be hosted on a public URL
 # Note: the yaml config must directly link to the **raw** yaml

@@ -302,7 +302,7 @@ Write a job description in YAML as below:
 # dstack.yaml
 type: task

-image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.1
+image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2

 env:
   - HUGGING_FACE_HUB_TOKEN
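The "remote yaml" comments above describe training straight from a hosted config. A minimal sketch, assuming the URL points at the raw file rather than the GitHub HTML page; the path below is hypothetical:

```bash
# Must link to the **raw** yaml (hypothetical path for illustration)
accelerate launch -m axolotl.cli.train \
    https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
```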
@@ -1,4 +1,5 @@
 #!/bin/bash
+set -e

 pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
 pytest /workspace/axolotl/tests/e2e/patched/

@@ -11,7 +11,7 @@ ARG PYTORCH_VERSION="2.1.2"
 ENV PYTORCH_VERSION=$PYTORCH_VERSION

 RUN apt-get update && \
-    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
+    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs

 WORKDIR /workspace
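The newly installed rsync and s3fs binaries make it possible to sync or mount remote storage inside the container. A hypothetical s3fs mount, assuming AWS credentials are provided as environment variables and the bucket name is a placeholder:

```bash
# Sketch: mount an S3 bucket with the newly installed s3fs (bucket name is hypothetical)
echo "$AWS_ACCESS_KEY_ID:$AWS_SECRET_ACCESS_KEY" > ~/.passwd-s3fs
chmod 600 ~/.passwd-s3fs
s3fs my-training-bucket /workspace/data -o passwd_file=~/.passwd-s3fs
```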
docker/Dockerfile-cloud-no-tmux (new file, 27 lines)

@@ -0,0 +1,27 @@
+ARG BASE_TAG=main
+FROM winglian/axolotl:$BASE_TAG
+
+ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
+ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
+ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
+ENV HF_HOME="/workspace/data/huggingface-cache/hub"
+ENV HF_HUB_ENABLE_HF_TRANSFER="1"
+
+EXPOSE 8888
+EXPOSE 22
+
+COPY scripts/cloud-entrypoint-term.sh /root/cloud-entrypoint.sh
+COPY scripts/motd /etc/motd
+
+RUN pip install jupyterlab notebook ipywidgets && \
+    jupyter lab clean
+RUN apt install --yes --no-install-recommends openssh-server tmux sudo && \
+    pip3 install -U --no-cache-dir grpcio ray[default]==2.9.3 && \
+    mkdir -p ~/.ssh && \
+    chmod 700 ~/.ssh && \
+    printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
+    chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
+    chmod +x /root/cloud-entrypoint.sh
+
+ENTRYPOINT ["/root/cloud-entrypoint.sh"]
+CMD ["sleep", "infinity"]
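A local build of this image would look roughly like the following, mirroring the build-args the workflow passes; the BASE_TAG value is an assumption and must name an existing `winglian/axolotl` tag:

```bash
# Sketch of a local build (BASE_TAG is a hypothetical existing base-image tag)
docker build \
  --build-arg BASE_TAG=main-py3.11-cu121-2.3.0 \
  -f docker/Dockerfile-cloud-no-tmux \
  -t axolotl-cloud-no-tmux .
```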
@@ -186,6 +186,11 @@ eval_sample_packing:
 # The trainer will provide recommended values for these values.
 sample_packing_eff_est:
 total_num_tokens:
+# Increasing the following values helps with packing, but usually only slightly (<%1.)
+# The number of samples packed at a time.
+sample_packing_group_size: 100000
+# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
+sample_packing_bin_size: 200

 # Passed through to transformers when loading the model when launched without accelerate
 # Use `sequential` when training w/ model parallelism to limit memory

@@ -285,7 +290,7 @@ lr_quadratic_warmup:
 logging_steps:
 eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
 evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
-save_strategy: # Set to `no` to skip checkpoint saves
+save_strategy: # Set to `"no"` to skip checkpoint saves
 save_steps: # Leave empty to save at each epoch
 saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
 save_total_limit: # Checkpoints saved at a time
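Assuming the axolotl CLI still forwards extra `--key=value` pairs as config overrides (an assumption; otherwise set these keys in the YAML itself), the new packing knobs could be tried without editing a config file:

```bash
# Sketch: override the new packing knobs at launch time (CLI overrides are an assumption)
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml \
    --sample_packing_group_size=100000 --sample_packing_bin_size=200
```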
@@ -38,7 +38,7 @@ wandb_watch:
 wandb_name:
 wandb_log_model:

-output_dir: btlm-out
+output_dir: ./outputs/btlm-out
 gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 1

@@ -25,7 +25,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
 batch_size: 4
 micro_batch_size: 4
 num_epochs: 2

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 sequence_len: 4096
 sample_packing: true

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 adapter: qlora
 lora_model_dir:

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 sequence_len: 4096
 sample_packing: true

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 adapter: qlora
 lora_model_dir:

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 sequence_len: 4096
 sample_packing: true

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 adapter: qlora
 lora_model_dir:
@@ -1,216 +1,223 @@
 {
  "cells": [
   {
    "cell_type": "markdown",
    "metadata": {
     "id": "AKjdG7tbTb-n"
    },
    "source": [
     "# Example notebook for running Axolotl on google colab"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "RcbNpOgWRcii"
-   },
-   "outputs": [],
-   "source": [
-    "import torch\n",
-    "# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n",
-    "assert (torch.cuda.is_available()==True)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "h3nLav8oTRA5"
-   },
-   "source": [
-    "## Install Axolotl and dependencies"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "3c3yGAwnOIdi",
-    "outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
-   },
-   "outputs": [],
-   "source": [
-    "!pip install torch==\"2.1.2\"\n",
-    "!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
-    "!pip install flash-attn==\"2.5.0\"\n",
-    "!pip install deepspeed==\"0.13.1\""
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "BW2MFr7HTjub"
-   },
-   "source": [
-    "## Create an yaml config file"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "9pkF2dSoQEUN"
-   },
-   "outputs": [],
-   "source": [
-    "import yaml\n",
-    "\n",
-    "# Your YAML string\n",
-    "yaml_string = \"\"\"\n",
-    "base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
-    "model_type: LlamaForCausalLM\n",
-    "tokenizer_type: LlamaTokenizer\n",
-    "is_llama_derived_model: true\n",
-    "\n",
-    "load_in_8bit: false\n",
-    "load_in_4bit: true\n",
-    "strict: false\n",
-    "\n",
-    "datasets:\n",
-    "  - path: mhenrichsen/alpaca_2k_test\n",
-    "    type: alpaca\n",
-    "dataset_prepared_path:\n",
-    "val_set_size: 0.05\n",
-    "output_dir: ./qlora-out\n",
-    "\n",
-    "adapter: qlora\n",
-    "lora_model_dir:\n",
-    "\n",
-    "sequence_len: 1096\n",
-    "sample_packing: true\n",
-    "pad_to_sequence_len: true\n",
-    "\n",
-    "lora_r: 32\n",
-    "lora_alpha: 16\n",
-    "lora_dropout: 0.05\n",
-    "lora_target_modules:\n",
-    "lora_target_linear: true\n",
-    "lora_fan_in_fan_out:\n",
-    "\n",
-    "wandb_project:\n",
-    "wandb_entity:\n",
-    "wandb_watch:\n",
-    "wandb_name:\n",
-    "wandb_log_model:\n",
-    "\n",
-    "mlflow_experiment_name: colab-example\n",
-    "\n",
-    "gradient_accumulation_steps: 1\n",
-    "micro_batch_size: 1\n",
-    "num_epochs: 4\n",
-    "max_steps: 20\n",
-    "optimizer: paged_adamw_32bit\n",
-    "lr_scheduler: cosine\n",
-    "learning_rate: 0.0002\n",
-    "\n",
-    "train_on_inputs: false\n",
-    "group_by_length: false\n",
-    "bf16: false\n",
-    "fp16: true\n",
-    "tf32: false\n",
-    "\n",
-    "gradient_checkpointing: true\n",
-    "early_stopping_patience:\n",
-    "resume_from_checkpoint:\n",
-    "local_rank:\n",
-    "logging_steps: 1\n",
-    "xformers_attention:\n",
-    "flash_attention: false\n",
-    "\n",
-    "warmup_steps: 10\n",
-    "evals_per_epoch:\n",
-    "saves_per_epoch:\n",
-    "debug:\n",
-    "deepspeed:\n",
-    "weight_decay: 0.0\n",
-    "fsdp:\n",
-    "fsdp_config:\n",
-    "special_tokens:\n",
-    "\n",
-    "\"\"\"\n",
-    "\n",
-    "# Convert the YAML string to a Python dictionary\n",
-    "yaml_dict = yaml.safe_load(yaml_string)\n",
-    "\n",
-    "# Specify your file path\n",
-    "file_path = 'test_axolotl.yaml'\n",
-    "\n",
-    "# Write the YAML file\n",
-    "with open(file_path, 'w') as file:\n",
-    "    yaml.dump(yaml_dict, file)\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "bidoj8YLTusD"
-   },
-   "source": [
-    "## Launch the training"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "ydTI2Jk2RStU",
-    "outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
-   },
-   "outputs": [],
-   "source": [
-    "# Buy using the ! the comand will be executed as a bash command\n",
-    "!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Play with inference"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Buy using the ! the comand will be executed as a bash command\n",
-    "!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
-    "  --qlora_model_dir=\"./qlora-out\" --gradio"
-   ]
-  }
- ],
- "metadata": {
-  "accelerator": "GPU",
-  "colab": {
-   "gpuType": "T4",
-   "provenance": []
-  },
-  "kernelspec": {
-   "display_name": "Python 3",
-   "name": "python3"
-  },
-  "language_info": {
-   "name": "python"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 0
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "RcbNpOgWRcii"
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n",
+    "assert (torch.cuda.is_available()==True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "h3nLav8oTRA5"
+   },
+   "source": [
+    "## Install Axolotl and dependencies"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "3c3yGAwnOIdi",
+    "outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install torch==\"2.1.2\"\n",
+    "!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
+    "!pip install flash-attn==\"2.5.0\"\n",
+    "!pip install deepspeed==\"0.13.1\"\n",
+    "!pip install mlflow==\"2.13.0\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "BW2MFr7HTjub"
+   },
+   "source": [
+    "## Create an yaml config file"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "9pkF2dSoQEUN"
+   },
+   "outputs": [],
+   "source": [
+    "import yaml\n",
+    "\n",
+    "# Your YAML string\n",
+    "yaml_string = \"\"\"\n",
+    "base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
+    "model_type: LlamaForCausalLM\n",
+    "tokenizer_type: LlamaTokenizer\n",
+    "\n",
+    "load_in_8bit: false\n",
+    "load_in_4bit: true\n",
+    "strict: false\n",
+    "\n",
+    "datasets:\n",
+    "  - path: mhenrichsen/alpaca_2k_test\n",
+    "    type: alpaca\n",
+    "dataset_prepared_path:\n",
+    "val_set_size: 0.05\n",
+    "output_dir: ./outputs/qlora-out\n",
+    "\n",
+    "adapter: qlora\n",
+    "lora_model_dir:\n",
+    "\n",
+    "sequence_len: 4096\n",
+    "sample_packing: true\n",
+    "eval_sample_packing: false\n",
+    "pad_to_sequence_len: true\n",
+    "\n",
+    "lora_r: 32\n",
+    "lora_alpha: 16\n",
+    "lora_dropout: 0.05\n",
+    "lora_target_modules:\n",
+    "lora_target_linear: true\n",
+    "lora_fan_in_fan_out:\n",
+    "\n",
+    "wandb_project:\n",
+    "wandb_entity:\n",
+    "wandb_watch:\n",
+    "wandb_name:\n",
+    "wandb_log_model:\n",
+    "\n",
+    "gradient_accumulation_steps: 4\n",
+    "micro_batch_size: 2\n",
+    "num_epochs: 4\n",
+    "optimizer: paged_adamw_32bit\n",
+    "lr_scheduler: cosine\n",
+    "learning_rate: 0.0002\n",
+    "\n",
+    "train_on_inputs: false\n",
+    "group_by_length: false\n",
+    "bf16: auto\n",
+    "fp16:\n",
+    "tf32: false\n",
+    "\n",
+    "gradient_checkpointing: true\n",
+    "early_stopping_patience:\n",
+    "resume_from_checkpoint:\n",
+    "local_rank:\n",
+    "logging_steps: 1\n",
+    "xformers_attention:\n",
+    "flash_attention: true\n",
+    "\n",
+    "warmup_steps: 10\n",
+    "evals_per_epoch: 4\n",
+    "saves_per_epoch: 1\n",
+    "debug:\n",
+    "deepspeed:\n",
+    "weight_decay: 0.0\n",
+    "fsdp:\n",
+    "fsdp_config:\n",
+    "special_tokens:\n",
+    "\n",
+    "\"\"\"\n",
+    "\n",
+    "# Convert the YAML string to a Python dictionary\n",
+    "yaml_dict = yaml.safe_load(yaml_string)\n",
+    "\n",
+    "# Specify your file path\n",
+    "file_path = 'test_axolotl.yaml'\n",
+    "\n",
+    "# Write the YAML file\n",
+    "with open(file_path, 'w') as file:\n",
+    "    yaml.dump(yaml_dict, file)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "bidoj8YLTusD"
+   },
+   "source": [
+    "## Launch the training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "ydTI2Jk2RStU",
+    "outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
+   },
+   "outputs": [],
+   "source": [
+    "# Buy using the ! the comand will be executed as a bash command\n",
+    "!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Play with inference"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Buy using the ! the comand will be executed as a bash command\n",
+    "!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
+    "  --qlora_model_dir=\"./qlora-out\" --gradio"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "T4",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
 }
@@ -10,7 +10,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 512
 sample_packing: false

@@ -10,7 +10,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 512
 sample_packing: false

@@ -10,7 +10,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 512
 sample_packing: false

@@ -28,7 +28,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./falcon-7b
+output_dir: ./outputs/falcon-7b
 batch_size: 2
 micro_batch_size: 1
 num_epochs: 4

@@ -42,7 +42,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 # QLoRA paper Table 9
 # - 16 for 7b & 13b

@@ -28,7 +28,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./falcon-7b
+output_dir: ./outputs/falcon-7b
 batch_size: 2
 micro_batch_size: 1
 num_epochs: 4

@@ -12,7 +12,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 val_set_size: 0.1
-output_dir: ./out
+output_dir: ./outputs/out

 adapter: qlora
 lora_r: 32

@@ -23,7 +23,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
 gradient_accumulation_steps: 2
 micro_batch_size: 2
 num_epochs: 2

@@ -10,7 +10,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 4096
 sample_packing: false

@@ -10,7 +10,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 4096
 sample_packing: false

@@ -21,7 +21,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./jeopardy-bot-7b
+output_dir: ./outputs/jeopardy-bot-7b
 gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 4

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 4096
 sample_packing: true

@@ -33,7 +33,7 @@ wandb_project:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./model-out
+output_dir: ./outputs/model-out
 gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 4

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
-output_dir: ./lisa-out
+output_dir: ./outputs/lisa-out

 sequence_len: 4096
 sample_packing: true

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 sequence_len: 4096
 sample_packing: true

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 sequence_len: 4096
 sample_packing: true

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 adapter: qlora
 lora_model_dir:

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 adapter: qlora
 lora_model_dir:

@@ -12,7 +12,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./relora-out
+output_dir: ./outputs/relora-out

 adapter: qlora
 lora_model_dir:

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 8192
 sample_packing: true
examples/llama-3/instruct-lora-8b.yml (new file, 76 lines)

@@ -0,0 +1,76 @@
+base_model: meta-llama/Meta-Llama-3-8B-Instruct
+model_type: LlamaForCausalLM
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+chat_template: llama3
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+    chat_template: llama3
+    field_messages: messages
+    message_field_role: role
+    message_field_content: content
+    roles:
+      user:
+        - user
+      assistant:
+        - assistant
+
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./outputs/lora-out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
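A plausible way to exercise the new config, following the launch pattern used elsewhere in this changeset; note the base model is gated, so a Hugging Face token is assumed to be configured:

```bash
# Sketch: train with the new Llama-3 instruct LoRA config (HF access token assumed)
accelerate launch -m axolotl.cli.train examples/llama-3/instruct-lora-8b.yml
```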
@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 sequence_len: 4096
 sample_packing: true

@@ -24,6 +24,9 @@ lora_alpha: 16
 lora_dropout: 0.05
 lora_target_linear: true
 lora_fan_in_fan_out:
+lora_modules_to_save:
+  - embed_tokens
+  - lm_head

 wandb_project:
 wandb_entity:
@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
-output_dir: ./out/qlora-llama3-70b
+output_dir: ./outputs/out/qlora-llama3-70b

 adapter: qlora
 lora_model_dir:

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 adapter: qlora
 lora_model_dir:

@@ -12,7 +12,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 2048
 sample_packing: false

@@ -23,7 +23,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 2048
 sample_packing: true

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 8192
 sample_packing: true

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
 eval_sample_packing: false

 adapter: lora

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.1
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 adapter: lora
 lora_model_dir:

@@ -12,7 +12,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.02
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 model_config:
   output_router_logits: true

@@ -16,7 +16,7 @@ datasets:
     type: chat_template.argilla
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.1
-output_dir: ./mistral-qlora-orpo-out
+output_dir: ./outputs/mistral-qlora-orpo-out

 adapter: qlora
 lora_model_dir:

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.02
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 model_config:
   output_router_logits: true

@@ -12,7 +12,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.02
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 model_config:
   output_router_logits: true

@@ -12,7 +12,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 ## You can optionally freeze the entire model and unfreeze a subset of parameters
 unfrozen_parameters:

@@ -21,7 +21,7 @@ model_config:
 datasets:
   - path: yahma/alpaca-cleaned
     type: alpaca
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 8000
 sample_packing: true

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.1
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 adapter: qlora
 lora_model_dir:

@@ -23,7 +23,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./mpt-alpaca-7b
+output_dir: ./outputs/mpt-alpaca-7b
 gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 4

@@ -25,7 +25,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./openllama-out
+output_dir: ./outputs/openllama-out
 gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 4

@@ -31,7 +31,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
 gradient_accumulation_steps: 1
 micro_batch_size: 2
 num_epochs: 4

@@ -25,7 +25,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
 gradient_accumulation_steps: 1
 micro_batch_size: 2
 num_epochs: 4

@@ -12,7 +12,7 @@ datasets:

 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./phi-sft-out
+output_dir: ./outputs/phi-sft-out

 sequence_len: 2048
 sample_packing: true

@@ -12,7 +12,7 @@ datasets:

 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./phi-sft-out
+output_dir: ./outputs/phi-sft-out

 sequence_len: 2048
 sample_packing: true

@@ -12,7 +12,7 @@ datasets:

 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./phi-sft-out
+output_dir: ./outputs/phi-sft-out

 sequence_len: 2048
 sample_packing: true

@@ -26,7 +26,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./pythia-12b
+output_dir: ./outputs/pythia-12b
 gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 5

@@ -20,7 +20,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./lora-alpaca-pythia
+output_dir: ./outputs/lora-alpaca-pythia
 gradient_accumulation_steps: 1
 micro_batch_size: 4
 num_epochs: 4

@@ -13,7 +13,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 sequence_len: 2048 # supports up to 8192
 sample_packing: false

@@ -13,7 +13,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 sequence_len: 2048 # supports up to 8192
 sample_packing: false

@@ -10,7 +10,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 1024 # supports up to 32k
 sample_packing: false

@@ -10,7 +10,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 1024 # supports up to 32k
 sample_packing: false

@@ -24,7 +24,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./redpajama-alpaca-3b
+output_dir: ./outputs/redpajama-alpaca-3b
 batch_size: 4
 micro_batch_size: 1
 num_epochs: 4

@@ -23,7 +23,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./lora-replit
+output_dir: ./outputs/lora-replit
 batch_size: 8
 micro_batch_size: 1
 num_epochs: 4

@@ -12,7 +12,7 @@ datasets:
     type: alpaca
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out

 sequence_len: 4096
 sample_packing: true

@@ -12,7 +12,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 sequence_len: 4096
 sample_packing: true

@@ -11,7 +11,7 @@ datasets:

 dataset_prepared_path:
 val_set_size: 0.2
-output_dir: ./qlora
+output_dir: ./outputs/qlora

 adapter: qlora
 lora_model_dir:

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 sequence_len: 4096
 sample_packing: true

@@ -11,7 +11,7 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out

 sequence_len: 4096
 sample_packing: true

@@ -14,7 +14,7 @@ pretraining_dataset:
   type: pretrain
 dataset_prepared_path:
 val_set_size: 0.0
-output_dir: ./model-out
+output_dir: ./outputs/model-out

 sequence_len: 2048
 sample_packing: true

@@ -11,13 +11,14 @@ datasets:
     type: alpaca
 dataset_prepared_path:
 val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 adapter: qlora
 lora_model_dir:

 sequence_len: 4096
 sample_packing: true
+eval_sample_packing: false
 pad_to_sequence_len: true

 lora_r: 32

@@ -40,7 +40,7 @@ wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out

 # QLoRA paper Table 9
 # - 16 for 7b & 13b

@@ -33,7 +33,7 @@ eval_sample_packing: false
 eval_batch_size: 1

 # LoRA
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
 adapter: qlora
 lora_model_dir:
 lora_r: 32
@@ -1,22 +1,22 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
-peft==0.10.0
+peft==0.11.1
-transformers @ git+https://github.com/huggingface/transformers.git@43d17c18360ac9c3d3491389328e2fe55fe8f9ce
+transformers==4.41.1
-tokenizers==0.15.0
+tokenizers==0.19.1
-bitsandbytes==0.43.0
+bitsandbytes==0.43.1
-accelerate==0.28.0
+accelerate==0.30.1
-deepspeed==0.13.1
+deepspeed==0.14.2
 pydantic==2.6.3
 addict
 fire
 PyYAML>=6.0
 requests
-datasets==2.15.0
+datasets==2.19.1
-flash-attn==2.5.5
+flash-attn==2.5.8
 sentencepiece
 wandb
 einops
-xformers==0.0.22
+xformers==0.0.26.post1
 optimum==1.16.2
 hf_transfer
 colorama

@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs

-trl==0.8.5
+trl==0.8.6
 zstandard==0.22.0
 fastcore
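After pulling these version bumps, a quick sanity check of the resolved stack might look like the sketch below; the import-and-print check is just an illustration:

```bash
# Sketch: install the updated pins and confirm the key versions resolved
pip install -r requirements.txt
python -c "import transformers, peft, datasets; print(transformers.__version__, peft.__version__, datasets.__version__)"
```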
82
scripts/cloud-entrypoint-term.sh
Executable file
82
scripts/cloud-entrypoint-term.sh
Executable file
@@ -0,0 +1,82 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Export specific ENV variables to /etc/rp_environment
|
||||||
|
echo "Exporting environment variables..."
|
||||||
|
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
||||||
|
conda init
|
||||||
|
# this needs to come after conda init
|
||||||
|
echo 'source /etc/rp_environment' >> ~/.bashrc
|
||||||
|
|
||||||
|
add_keys_to_authorized() {
|
||||||
|
local key_value=$1
|
||||||
|
|
||||||
|
# Create the ~/.ssh directory and set permissions
|
||||||
|
mkdir -p ~/.ssh
|
||||||
|
chmod 700 ~/.ssh
|
||||||
|
|
||||||
|
# Create the authorized_keys file if it doesn't exist
|
||||||
|
touch ~/.ssh/authorized_keys
|
||||||
|
|
||||||
|
# Initialize an empty key variable
|
||||||
|
local key=""
|
||||||
|
|
||||||
|
# Read the key variable word by word
|
||||||
|
for word in $key_value; do
|
||||||
|
# Check if the word looks like the start of a key
|
||||||
|
if [[ $word == ssh-* ]]; then
|
||||||
|
# If there's a key being built, add it to the authorized_keys file
|
||||||
|
if [[ -n $key ]]; then
|
||||||
|
echo $key >> ~/.ssh/authorized_keys
|
||||||
|
fi
|
||||||
|
# Start a new key
|
||||||
|
key=$word
|
||||||
|
else
|
||||||
|
# Append the word to the current key
|
||||||
|
key="$key $word"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Add the last key to the authorized_keys file
|
||||||
|
if [[ -n $key ]]; then
|
||||||
|
echo $key >> ~/.ssh/authorized_keys
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Set the correct permissions
|
||||||
|
chmod 600 ~/.ssh/authorized_keys
|
||||||
|
chmod 700 -R ~/.ssh
|
||||||
|
}
|
||||||
|
|
||||||
|
if [[ $PUBLIC_KEY ]]; then
|
||||||
|
# runpod
|
||||||
|
add_keys_to_authorized "$PUBLIC_KEY"
|
||||||
|
# Start the SSH service in the background
|
||||||
|
service ssh start
|
||||||
|
elif [[ $SSH_KEY ]]; then
|
||||||
|
# latitude.sh
|
||||||
|
add_keys_to_authorized "$SSH_KEY"
|
||||||
|
# Start the SSH service in the background
|
||||||
|
service ssh start
|
||||||
|
else
|
||||||
|
echo "No PUBLIC_KEY or SSH_KEY environment variable provided, not starting openSSH daemon"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if JUPYTER_PASSWORD is set and not empty
|
||||||
|
if [ -n "$JUPYTER_PASSWORD" ]; then
|
||||||
|
# Set JUPYTER_TOKEN to the value of JUPYTER_PASSWORD
|
||||||
|
export JUPYTER_TOKEN="$JUPYTER_PASSWORD"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
||||||
|
# Run Jupyter Lab in the background
|
||||||
|
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -d "/workspace/data/axolotl-artifacts" ]; then
|
||||||
|
mkdir -p /workspace/data/axolotl-artifacts
|
||||||
|
fi
|
||||||
|
if [ ! -L "/workspace/axolotl/outputs" ]; then
|
||||||
|
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Execute the passed arguments (CMD)
|
||||||
|
exec "$@"
|
||||||
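The `add_keys_to_authorized` helper splits a single environment variable that may hold several concatenated public keys: any token beginning with `ssh-` starts a new key, and every other token is treated as the continuation (base64 body or comment) of the current one. A minimal Python restatement of the same scan, for clarity only:

```python
# Re-statement of the word-by-word key splitting done by add_keys_to_authorized.
def split_authorized_keys(key_value: str) -> list[str]:
    keys: list[str] = []
    current = ""
    for word in key_value.split():
        if word.startswith("ssh-"):
            if current:
                keys.append(current)
            current = word  # a new key starts here
        else:
            current = f"{current} {word}"  # continuation of the current key
    if current:
        keys.append(current)
    return keys

# Example: two keys passed in one PUBLIC_KEY variable.
blob = "ssh-ed25519 AAAAC3Nza... user@a ssh-rsa AAAAB3Nza... user@b"
assert len(split_authorized_keys(blob)) == 2
```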
@@ -5,20 +5,53 @@ echo "Exporting environment variables..."
 printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
 echo 'source /etc/rp_environment' >> ~/.bashrc
 
+add_keys_to_authorized() {
+    local key_value=$1
+
+    # Create the ~/.ssh directory and set permissions
+    mkdir -p ~/.ssh
+    chmod 700 ~/.ssh
+
+    # Create the authorized_keys file if it doesn't exist
+    touch ~/.ssh/authorized_keys
+
+    # Initialize an empty key variable
+    local key=""
+
+    # Read the key variable word by word
+    for word in $key_value; do
+        # Check if the word looks like the start of a key
+        if [[ $word == ssh-* ]]; then
+            # If there's a key being built, add it to the authorized_keys file
+            if [[ -n $key ]]; then
+                echo $key >> ~/.ssh/authorized_keys
+            fi
+            # Start a new key
+            key=$word
+        else
+            # Append the word to the current key
+            key="$key $word"
+        fi
+    done
+
+    # Add the last key to the authorized_keys file
+    if [[ -n $key ]]; then
+        echo $key >> ~/.ssh/authorized_keys
+    fi
+
+    # Set the correct permissions
+    chmod 600 ~/.ssh/authorized_keys
+    chmod 700 -R ~/.ssh
+}
+
 if [[ $PUBLIC_KEY ]]; then
     # runpod
-    mkdir -p ~/.ssh
-    chmod 700 ~/.ssh
-    echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
-    chmod 700 -R ~/.ssh
+    add_keys_to_authorized "$PUBLIC_KEY"
     # Start the SSH service in the background
     service ssh start
-elif [ -n "$SSH_KEY" ]; then
+elif [[ $SSH_KEY ]]; then
     # latitude.sh
-    mkdir -p ~/.ssh
-    chmod 700 ~/.ssh
-    echo $SSH_KEY >> ~/.ssh/authorized_keys
-    chmod 700 -R ~/.ssh
+    add_keys_to_authorized "$SSH_KEY"
     # Start the SSH service in the background
     service ssh start
 else
@@ -36,5 +69,12 @@ if [ "$JUPYTER_DISABLE" != "1" ]; then
     jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
 fi
 
+if [ ! -d "/workspace/data/axolotl-artifacts" ]; then
+    mkdir -p /workspace/data/axolotl-artifacts
+fi
+if [ ! -L "/workspace/axolotl/outputs" ]; then
+    ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
+fi
+
 # Execute the passed arguments (CMD)
 exec "$@"
setup.py (25 changes)
@@ -30,8 +30,11 @@ def parse_requirements():
 
     try:
         if "Darwin" in platform.system():
-            _install_requires.pop(_install_requires.index("xformers==0.0.22"))
+            # don't install xformers on MacOS
+            _install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
         else:
+            # detect the version of torch already installed
+            # and set it so dependencies don't clobber the torch version
             torch_version = version("torch")
             _install_requires.append(f"torch=={torch_version}")
 
@@ -45,9 +48,15 @@ def parse_requirements():
         else:
             raise ValueError("Invalid version format")
 
-        if (major, minor) >= (2, 1):
-            _install_requires.pop(_install_requires.index("xformers==0.0.22"))
-            _install_requires.append("xformers>=0.0.23")
+        if (major, minor) >= (2, 3):
+            pass
+        elif (major, minor) >= (2, 2):
+            _install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
+            _install_requires.append("xformers>=0.0.25.post1")
+        else:
+            _install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
+            _install_requires.append("xformers>=0.0.23.post1")
     except PackageNotFoundError:
         pass
 
@@ -59,7 +68,7 @@ install_requires, dependency_links = parse_requirements()
 
 setup(
     name="axolotl",
-    version="0.4.0",
+    version="0.4.1",
     description="LLM Trainer",
     long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
     package_dir={"": "src"},
@@ -68,13 +77,13 @@ setup(
     dependency_links=dependency_links,
     extras_require={
        "flash-attn": [
-            "flash-attn==2.5.5",
+            "flash-attn==2.5.8",
        ],
        "fused-dense-lib": [
-            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
+            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
        ],
        "deepspeed": [
-            "deepspeed==0.13.1",
+            "deepspeed==0.14.2",
             "deepspeed-kernels",
        ],
        "mamba-ssm": [
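The new branching maps the installed torch version to an xformers requirement: torch >= 2.3 keeps the default `xformers==0.0.26.post1` pin, torch 2.2.x relaxes it to `>=0.0.25.post1`, and anything older to `>=0.0.23.post1`. A hedged sketch of that mapping as a standalone function (illustrative, not the actual setup.py code):

```python
def xformers_requirement(torch_version: str) -> str:
    """Mirror of the version-gating logic in the diff above."""
    major, minor = (int(part) for part in torch_version.split(".")[:2])
    if (major, minor) >= (2, 3):
        return "xformers==0.0.26.post1"  # keep the default pin untouched
    if (major, minor) >= (2, 2):
        return "xformers>=0.0.25.post1"
    return "xformers>=0.0.23.post1"


assert xformers_requirement("2.3.0") == "xformers==0.0.26.post1"
assert xformers_requirement("2.2.2") == "xformers>=0.0.25.post1"
assert xformers_requirement("2.1.2") == "xformers>=0.0.23.post1"
```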
src/axolotl/core/trainer_builder.py (193 changes, Normal file → Executable file)
@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer, ORPOConfig, ORPOTrainer
+from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length
 
 from axolotl.loraplus import create_loraplus_optimizer
@@ -43,7 +43,7 @@ from axolotl.utils.callbacks import (
     LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
-    SaveModelOnTrainEndCallback,
+    SaveModelCallback,
     bench_eval_callback_factory,
     causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
@@ -91,11 +91,12 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
 
 
 @dataclass
-class AxolotlTrainingArguments(TrainingArguments):
+class AxolotlTrainingMixins:
     """
-    Extend the base TrainingArguments for axolotl helpers
+    Mixin class for the Axolotl training args.
     """
 
+    # pylint: disable=duplicate-code
     model_type: Optional[str] = field(
         default=None, metadata={"help": "HF model configuration model_type."}
     )
@@ -125,14 +126,22 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=1.0,
         metadata={"help": "Sample packing efficiency for calculating batch length."},
     )
+    sample_packing_bin_size: int = field(
+        default=200,
+        metadata={
+            "help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
+        },
+    )
+    sample_packing_group_size: int = field(
+        default=100000,
+        metadata={
+            "help": "The number of samples to group together for packing. Increase for better packing."
+        },
+    )
     max_seq_length: int = field(
         default=2048,
         metadata={"help": "The maximum sequence length the model can handle"},
     )
-    sample_packing_seq_len_multiplier: int = field(
-        default=1,
-        metadata={"help": "the multiplier for the max len for packed sequences"},
-    )
     relora_steps: Optional[int] = field(
         default=None,
         metadata={"help": "how often to reset for ReLoRA"},
@@ -219,6 +228,30 @@ class AxolotlTrainingArguments(TrainingArguments):
     )
 
 
+@dataclass
+class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
+    """
+    Training arguments for Causal trainer
+
+    This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
+    so it can't be used as a mixin.
+    """
+
+
+@dataclass
+class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
+    """
+    ORPO config for ORPO training
+    """
+
+
+@dataclass
+class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
+    """
+    KTO config for KTO training
+    """
+
+
 class AxolotlTrainer(Trainer):
     """
     Extend the base Trainer for axolotl helpers
@@ -346,11 +379,12 @@ class AxolotlTrainer(Trainer):
             )
             return MultipackBatchSampler(
                 RandomSampler(self.train_dataset),
-                batch_size=batch_size,
-                drop_last=True,
-                batch_max_len=batch_max_len,
                 lengths=get_dataset_lengths(self.train_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                batch_max_len=batch_max_len,
+                batch_size=batch_size,
+                group_size=self.args.sample_packing_group_size,
+                bin_size=self.args.sample_packing_bin_size,
+                drop_last=True,
             )
         if self.args.curriculum_sampling:
             return SequentialSampler(self.train_dataset)
@@ -370,11 +404,12 @@ class AxolotlTrainer(Trainer):
             )
             return MultipackBatchSampler(
                 SequentialSampler(eval_dataset),
-                batch_size=batch_size,
-                drop_last=True,
+                lengths=get_dataset_lengths(self.eval_dataset),
                 batch_max_len=batch_max_len,
-                lengths=get_dataset_lengths(eval_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                batch_size=batch_size,
+                group_size=self.args.sample_packing_group_size,
+                bin_size=self.args.sample_packing_bin_size,
+                drop_last=True,
             )
         return super()._get_eval_sampler(eval_dataset)
 
@@ -798,6 +833,40 @@ class AxolotlDPOTrainer(DPOTrainer):
 
     tag_names = ["axolotl", "dpo"]
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.optimizer = None
+
+    def create_optimizer(self):
+        if self.args.loraplus_lr_ratio is None:
+            return super().create_optimizer()
+
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
+                self.args,
+                opt_model,
+            )
+
+            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+            if loraplus_lr_ratio:
+                print("Using lora+")
+            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
+            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                opt_model,
+                optimizer_cls,
+                optimizer_kwargs,
+                loraplus_lr_ratio,
+                loraplus_lr_embedding,
+            )
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
+                self.optimizer
+            )
+
+        return self.optimizer
+
     @wraps(DPOTrainer.push_to_hub)
     def push_to_hub(self, *args, **kwargs) -> str:
         """
@@ -826,6 +895,14 @@ class AxolotlORPOTrainer(ORPOTrainer):
     tag_names = ["axolotl", "orpo"]
 
 
+class AxolotlKTOTrainer(KTOTrainer):
+    """
+    Extend the base KTOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "kto"]
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
@@ -945,7 +1022,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.loss_watchdog_threshold is not None:
             callbacks.append(LossWatchDogCallback(self.cfg))
 
-        callbacks.append(SaveModelOnTrainEndCallback())
+        callbacks.append(SaveModelCallback())
 
         return callbacks
 
@@ -1071,11 +1148,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.save_safetensors is not None:
             training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
 
-        if self.cfg.sample_packing_eff_est:
-            training_arguments_kwargs[
-                "sample_packing_efficiency"
-            ] = self.cfg.sample_packing_eff_est
-
         if self.cfg.dataloader_pin_memory is not None:
             training_arguments_kwargs[
                 "dataloader_pin_memory"
@@ -1123,6 +1195,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             # default to saving each epoch if not defined
             training_arguments_kwargs["save_strategy"] = "epoch"
 
+        training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
+
         if self.cfg.do_bench_eval:
             training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
             if self.cfg.bench_dataset:
@@ -1202,11 +1276,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
-        report_to = None
+        report_to = []
         if self.cfg.use_wandb:
-            report_to = "wandb"
+            report_to.append("wandb")
         if self.cfg.use_mlflow:
-            report_to = "mlflow"
+            report_to.append("mlflow")
+        if self.cfg.use_tensorboard:
+            report_to.append("tensorboard")
+
         training_arguments_kwargs["report_to"] = report_to
         training_arguments_kwargs["run_name"] = (
             self.cfg.wandb_name if self.cfg.use_wandb else None
@@ -1246,20 +1323,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
-        training_arguments_kwargs["sample_packing"] = (
-            self.cfg.sample_packing if self.cfg.sample_packing else False
-        )
-        training_arguments_kwargs["multipack_real_batches"] = (
-            self.cfg.flash_attention is not True
-        )
-        training_arguments_kwargs["eval_sample_packing"] = (
-            self.cfg.sample_packing
-            if self.cfg.eval_sample_packing is not False
-            else False
-        )
+        training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
         training_arguments_kwargs[
-            "sample_packing_seq_len_multiplier"
-        ] = self.cfg.micro_batch_size
+            "multipack_real_batches"
+        ] = not self.cfg.flash_attention
+        training_arguments_kwargs["eval_sample_packing"] = bool(
+            self.cfg.eval_sample_packing
+        )
+        if self.cfg.sample_packing_bin_size is not None:
+            training_arguments_kwargs[
+                "sample_packing_bin_size"
+            ] = self.cfg.sample_packing_bin_size
+        if self.cfg.sample_packing_group_size is not None:
+            training_arguments_kwargs[
+                "sample_packing_group_size"
+            ] = self.cfg.sample_packing_group_size
+        if self.cfg.sample_packing_eff_est:
+            training_arguments_kwargs[
+                "sample_packing_efficiency"
+            ] = self.cfg.sample_packing_eff_est
+
         if self.cfg.relora_steps:
             training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
             training_arguments_kwargs[
@@ -1429,7 +1513,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 
     def get_callbacks(self):
         callbacks = super().get_callbacks()
-        callbacks.append(SaveModelOnTrainEndCallback())
+        callbacks.append(SaveModelCallback())
 
         return callbacks
 
@@ -1470,6 +1554,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.bf16 or self.cfg.bfloat16:
             training_args_kwargs["bf16"] = True
 
+        training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
+        training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
         training_args_kwargs["lr_scheduler_type"] = (
             self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
         )
@@ -1522,20 +1608,36 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha
 
-        training_args_cls = TrainingArguments
+        training_args_cls = AxolotlTrainingArguments
         if self.cfg.rl == "orpo":
-            training_args_cls = ORPOConfig
+            training_args_cls = AxolotlORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
-        training_args = training_args_cls(
+        if self.cfg.rl == "kto":
+            training_args_cls = AxolotlKTOConfig
+
+            training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
+            training_args_kwargs["desirable_weight"] = (
+                self.cfg.kto_desirable_weight or 1.0
+            )
+            training_args_kwargs["undesirable_weight"] = (
+                self.cfg.kto_undesirable_weight or 1.0
+            )
+
+            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+            training_args_kwargs["max_length"] = self.cfg.sequence_len
+            if self.cfg.max_prompt_len:
+                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
+
+        training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
+            output_dir=self.cfg.output_dir,
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
-            output_dir=self.cfg.output_dir,
             warmup_steps=self.cfg.warmup_steps,
             logging_first_step=True,
             logging_steps=1,
@@ -1565,7 +1667,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             ] = self.cfg.precompute_ref_log_probs
         if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
             trainer_cls = AxolotlDPOTrainer
-            dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
+            dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
             trainer_cls_args = [self.model, self.model_ref]
 
             # these aren't used for the ORPO trainer
@@ -1578,6 +1680,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]
+        elif self.cfg.rl == "kto":
+            trainer_cls = AxolotlKTOTrainer
+            trainer_cls_args = [self.model]
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")
         dpo_trainer = trainer_cls(
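The refactor extracts the shared fields into `AxolotlTrainingMixins` so they can be layered onto `TrainingArguments`, `ORPOConfig`, or `KTOConfig`. With dataclasses, base-class field order matters: non-default fields (like HF's `output_dir`) must be collected before defaulted mixin fields, which is why the mixin comes first in the bases. A toy sketch with illustrative names (not the actual axolotl/transformers classes):

```python
from dataclasses import dataclass


@dataclass
class BaseConfig:
    output_dir: str  # required field, like HF TrainingArguments


@dataclass
class SharedMixin:
    sample_packing_bin_size: int = 200
    sample_packing_group_size: int = 100_000


@dataclass
class CausalArgs(SharedMixin, BaseConfig):
    """Mixin listed first: fields are collected in reverse MRO order, so the
    required output_dir still precedes the defaulted mixin fields."""


args = CausalArgs(output_dir="./outputs")
print(args.sample_packing_bin_size)  # 200
# Flipping the bases to (BaseConfig, SharedMixin) would raise a TypeError:
# "non-default argument 'output_dir' follows default argument".
```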
@@ -42,9 +42,9 @@ def patch_mixtral_moe_forward_zero3() -> None:
         return final_hidden_states, router_logits
 
     from transformers.models.mixtral.modeling_mixtral import (
-        MixtralBLockSparseTop2MLP,
+        MixtralBlockSparseTop2MLP,
         MixtralSparseMoeBlock,
     )
 
-    MixtralBLockSparseTop2MLP.forward = mlp_forward
+    MixtralBlockSparseTop2MLP.forward = mlp_forward
     MixtralSparseMoeBlock.forward = moe_forward
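This rename tracks an upstream transformers fix of the misspelled `MixtralBLockSparseTop2MLP` class. If a patch had to work across both transformers versions, one hedged option is to resolve whichever spelling exists (a sketch, assuming transformers with Mixtral support is installed):

```python
# Resolve the Mixtral MLP class under either spelling; older transformers
# releases shipped the misspelled MixtralBLockSparseTop2MLP.
import transformers.models.mixtral.modeling_mixtral as mixtral_mod

mlp_cls = getattr(
    mixtral_mod,
    "MixtralBlockSparseTop2MLP",
    getattr(mixtral_mod, "MixtralBLockSparseTop2MLP", None),
)
if mlp_cls is None:
    raise ImportError("no Mixtral block-sparse MLP class found")
print(mlp_cls.__name__)
```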
src/axolotl/monkeypatch/unsloth_.py (new file, 267 lines)
@@ -0,0 +1,267 @@
+"""module for patching with unsloth optimizations"""
+
+import inspect
+import logging
+import re
+import types
+from typing import Tuple
+
+from peft import PeftModelForCausalLM
+from transformers.models.llama.modeling_llama import (
+    LlamaFlashAttention2,
+    LlamaForCausalLM,
+)
+
+LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
+
+ORIGINAL_CEL_CODE = """    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+"""
+
+PATCHED_CEL_CODE = """    if labels is not None:
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        loss = fast_cross_entropy_loss(
+            logits = shift_logits,
+            labels = shift_labels,
+        )
+"""
+
+ORIGINAL_QKV_CODE = """
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
+""".lstrip(
+    "\n"
+)
+
+PATCHED_QKV_CODE = """
+    query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
+""".lstrip(
+    "\n"
+)
+
+ORIGINAL_O_CODE = """
+    attn_output = self.o_proj(attn_output)
+""".lstrip(
+    "\n"
+)
+
+PATCHED_O_CODE = """
+    attn_output = self.apply_o(self, attn_output)
+""".lstrip(
+    "\n"
+)
+
+
+def original_apply_qkv(self, hidden_states):
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
+    return query_states, key_states, value_states
+
+
+def original_apply_o(self, hidden_states):
+    attn_output = self.o_proj(hidden_states)
+    return attn_output
+
+
+def get_forward_code() -> str:
+    forward = inspect.getsource(LlamaForCausalLM.forward)
+    return forward
+
+
+def test_cel_is_patchable() -> bool:
+    forward = get_forward_code()
+    return ORIGINAL_CEL_CODE in forward
+
+
+def get_self_attn_code() -> str:
+    forward = inspect.getsource(LlamaFlashAttention2.forward)
+    return forward
+
+
+def test_self_attn_is_patchable() -> bool:
+    qkv = get_self_attn_code()
+    return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
+
+
+def integrate_cross_entropy_loss_patch():
+    forward = get_forward_code()
+    LlamaForCausalLM._original_forward = forward  # pylint: disable=protected-access
+    forward, _ = detab_code(forward)
+    assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
+
+    forward = forward.replace(
+        "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
+    )
+    forward = forward.replace(
+        "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
+        "",
+    )
+    forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
+    forward = forward.replace(
+        "def forward(",
+        "def fast_cross_entropy_loss_forward(",
+        1,
+    )
+
+    # load imports necessary
+    import transformers.models.llama.modeling_llama
+
+    items_to_import = []
+    for item in dir(transformers.models.llama.modeling_llama):
+        if item in forward:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
+        globals(),
+    )
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from transformers.models.llama.modeling_llama import ("
+        + ", ".join(x for x in items_to_import)
+        + ")",
+        globals(),
+    )
+    exec(forward, globals())  # pylint: disable=exec-used  # nosec B102
+    print("patching unsloth fast_cross_entropy_loss")
+    LlamaForCausalLM.forward = fast_cross_entropy_loss_forward  # pylint: disable=undefined-variable  # noqa: F821
+
+
+def detab_code(code: str) -> Tuple[str, str]:
+    spaces = re.match(r"([\s\t]{1,})", code).group(0)
+    code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
+    return code, spaces
+
+
+def patch_self_attn_lora():
+    self_attn_forward = get_self_attn_code()
+    LlamaFlashAttention2._original_forward = (  # pylint: disable=protected-access
+        self_attn_forward
+    )
+    self_attn_forward, _ = detab_code(self_attn_forward)
+    assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
+    assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"
+
+    self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
+    self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
+    self_attn_forward = self_attn_forward.replace(
+        "def forward(",
+        "def unsloth_attn_forward(",
+        1,
+    )
+
+    # load imports necessary
+    import transformers.models.llama.modeling_llama
+
+    items_to_import = []
+    for item in dir(transformers.models.llama.modeling_llama):
+        if item in self_attn_forward:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from transformers.models.llama.modeling_llama import ("
+        + ", ".join(x for x in items_to_import)
+        + ")",
+        globals(),
+    )
+    exec(self_attn_forward, globals())  # pylint: disable=exec-used  # nosec B102
+    print("patching unsloth attn lora")
+    LlamaFlashAttention2.forward = (
+        unsloth_attn_forward  # pylint: disable=undefined-variable  # noqa: F821
+    )
+
+
+def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
+    if peft_model.base_model.config.model_type in ["llama", "mistral"]:
+        from unsloth.kernels import apply_lora_mlp_swiglu
+
+        apply_lora_mlp = apply_lora_mlp_swiglu
+    elif peft_model.base_model.config.model_type == "gemma":
+        from unsloth.kernels import apply_lora_mlp_geglu_approx
+
+        apply_lora_mlp = apply_lora_mlp_geglu_approx
+    else:
+        raise NotImplementedError(
+            f"Model type {peft_model.base_model.config.model_type} not supported"
+        )
+
+    for idx, layer in enumerate(peft_model.model.model.layers):
+        layer_modules = [
+            getattr(layer.mlp, linear_proj)
+            for linear_proj in ["gate_proj", "up_proj", "down_proj"]
+        ]
+        is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
+        mlp_no_bias = all(
+            getattr(module, "base_layer", module).bias is None
+            for module in layer_modules
+        )
+        mlp_not_dora = all(
+            getattr(module, "lora_magnitude_vector", None) is None
+            for module in layer_modules
+        )
+
+        if is_mlp_lora and mlp_no_bias and mlp_not_dora:
+            layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
+        else:
+            logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
+
+
+def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
+    from unsloth.kernels import apply_lora_o, apply_lora_qkv
+
+    for idx, layer in enumerate(peft_model.model.model.layers):
+        if cfg.unsloth_lora_qkv:
+            layer_modules = [
+                getattr(layer.self_attn, linear_proj)
+                for linear_proj in ["q_proj", "k_proj", "v_proj"]
+            ]
+            is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
+            qkv_no_bias = all(
+                getattr(module, "base_layer", module).bias is None
+                for module in layer_modules
+            )
+            qkv_not_dora = all(
+                getattr(module, "lora_magnitude_vector", None) is None
+                for module in layer_modules
+            )
+
+            if is_qkv_lora and qkv_no_bias and qkv_not_dora:
+                layer.self_attn.apply_qkv = apply_lora_qkv
+            else:
+                layer.self_attn.apply_qkv = original_apply_qkv
+                logging.warning(
+                    "unable to apply unsloth lora qkv patch to layer %d", idx
+                )
+        if cfg.unsloth_lora_o:
+            layer_modules = [
+                getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
+            ]
+            is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
+            o_no_bias = all(
+                getattr(module, "base_layer", module).bias is None
+                for module in layer_modules
+            )
+            o_not_dora = all(
+                getattr(module, "lora_magnitude_vector", None) is None
+                for module in layer_modules
+            )
+
+            if is_o_lora and o_no_bias and o_not_dora:
+                layer.self_attn.apply_o = apply_lora_o
+            else:
+                layer.self_attn.apply_o = original_apply_o
+                logging.warning(
+                    "unable to apply unsloth lora o_proj patch to layer %d", idx
+                )
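The module patches by reading a function's source with `inspect.getsource`, string-replacing a known block, renaming the function, and `exec`-ing the result back into place. A self-contained toy demonstration of the same technique on a trivial class (run it as a script; `getsource` needs a real file):

```python
import inspect
import textwrap


class Greeter:
    def greet(self) -> str:
        return "hello"


src = inspect.getsource(Greeter.greet)                     # 1. grab the method source
src = textwrap.dedent(src)                                 #    same job as detab_code()
src = src.replace('return "hello"', 'return "patched"')    # 2. rewrite the body
src = src.replace("def greet(", "def patched_greet(", 1)   #    and rename the function
namespace = {}
exec(src, namespace)                                       # 3. exec the rewritten source
Greeter.greet = namespace["patched_greet"]                 # 4. swap it in

assert Greeter().greet() == "patched"
```

The trade-off is the same one the module's `test_*_is_patchable` helpers guard against: the string replace silently does nothing if the upstream source drifts, so the asserts on `ORIGINAL_*_CODE` are what keep the patch honest.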
@@ -1,24 +1,56 @@
 """
 HF Chat Templates prompt strategy
 """
-from typing import Any, Dict, Optional
+
+import logging
+from typing import Any, Dict, List, Optional
 
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
 from axolotl.prompters import Prompter
 from axolotl.utils.chat_templates import chat_templates
 
+LOG = logging.getLogger("axolotl")
+
+
 class ChatTemplatePrompter(Prompter):
     """prompter for HF chat templates"""
 
-    def __init__(self, tokenizer, chat_template=None, max_length=2048):
+    def __init__(
+        self,
+        tokenizer,
+        chat_template=None,
+        max_length=2048,
+        message_field_role: str = "from",
+        message_field_content: str = "value",
+        roles: Optional[Dict[str, List[str]]] = None,
+    ):
+        if roles:
+            self.roles = {s: t for t, sources in roles.items() for s in sources}
+        else:
+            self.roles = {
+                "human": "user",
+                "user": "user",
+                "assistant": "assistant",
+                "gpt": "assistant",
+                "system": "system",
+            }
+        self.message_field_role = message_field_role
+        self.message_field_content = message_field_content
         self.tokenizer = tokenizer
         self.chat_template = chat_template
         self.max_length = max_length
 
     def build_prompt(self, conversation, add_generation_prompt=False):
+        turns = [
+            {
+                "role": self.roles[t[self.message_field_role]],
+                "content": t[self.message_field_content],
+            }
+            for t in conversation
+        ]
+
         return self.tokenizer.apply_chat_template(
-            conversation,
+            turns,
             truncation=True,
             max_length=self.max_length,
             add_generation_prompt=add_generation_prompt,
@@ -31,9 +63,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
     Tokenizing strategy for instruction-based prompts.
     """
 
+    _messages = "conversations"
+
+    @property
+    def messages(self):
+        return self._messages
+
+    @messages.setter
+    def messages(self, messages):
+        self._messages = messages
+
     def tokenize_prompt(self, prompt):
         turns = self.get_conversation_thread(prompt)
-        prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
+        prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
         input_ids = self.prompter.build_prompt(turns)
 
         if not self.train_on_inputs:
@@ -51,28 +93,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         return tokenized_prompt
 
     def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
-        # remap roles - allow for assistant turn
-        role_map = {
-            "human": "user",
-            "user": "user",
-            "assistant": "assistant",
-            "gpt": "assistant",
-        }
-        turns = [
-            {"role": role_map[t["from"]], "content": t["value"]} for t in conversations
-        ]
-        return turns
+        return prompt[self.messages]
 
 
 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     chat_template = (
         ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
     )
+    message_field_role = (
+        ds_cfg["message_field_role"]
+        if ds_cfg and "message_field_role" in ds_cfg
+        else "from"
+    )
+    message_field_content = (
+        ds_cfg["message_field_content"]
+        if ds_cfg and "message_field_content" in ds_cfg
+        else "value"
+    )
+    roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
+
     strategy = ChatTemplateStrategy(
-        ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
+        ChatTemplatePrompter(
+            tokenizer,
+            chat_templates(chat_template),
+            message_field_role=message_field_role,
+            message_field_content=message_field_content,
+            roles=roles,
+        ),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+    if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
+        strategy.messages = ds_cfg["field_messages"]
     return strategy
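The new `roles` option arrives as `{target_role: [source aliases]}` and is inverted into a flat alias-to-target lookup by the dict comprehension in `__init__`. A quick illustration:

```python
# The ds_cfg "roles" option maps each canonical role to its source aliases;
# ChatTemplatePrompter inverts it into a flat alias -> role lookup.
roles = {
    "user": ["human", "user"],
    "assistant": ["gpt", "assistant"],
    "system": ["system"],
}
lookup = {s: t for t, sources in roles.items() for s in sources}
print(lookup["gpt"])    # assistant
print(lookup["human"])  # user
```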
src/axolotl/prompt_strategies/kto/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
+"""
+module for KTO style dataset transform strategies
+"""
+
+from functools import partial
+
+from ..base import load as load_base
+
+load = partial(load_base, module_base="axolotl.prompt_strategies.kto")
src/axolotl/prompt_strategies/kto/chatml.py (new file, 105 lines)
@@ -0,0 +1,105 @@
+"""
+KTO strategies for chatml
+"""
+# pylint: disable=duplicate-code
+
+
+def argilla(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["completion"] = f"{sample['completion']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def argilla_chat(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for argilla/kto-mix-15k conversations
+    """
+
+    def transform_fn(sample):
+        sample[
+            "prompt"
+        ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    For Intel Orca KTO
+    ex: argilla/distilabel-intel-orca-kto
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["completion"] = f"{sample['completion']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def prompt_pairs(
+    cfg, **kwargs
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["completion"] = f"{sample['completion']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for ultrafeedback binarized conversations
+    ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["completion"] = f"{sample['completion']}<|im_end|>"
+        return sample
+
+    return transform_fn
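Each factory above returns a `transform_fn` that rewrites a raw dataset row into the `prompt`/`completion` strings KTO training expects. A usage sketch with a made-up row, inlining the `intel`-style transform:

```python
# Illustration only: the row dict is invented, and the function body mirrors
# intel(cfg)'s transform_fn from the file above.
def intel_transform(sample):
    if sample.get("system"):
        sample["prompt"] = (
            f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
            f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
        )
    else:
        sample["prompt"] = (
            f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
        )
    sample["completion"] = f"{sample['completion']}<|im_end|>"
    return sample


row = {"question": "What is 2+2?", "completion": "4", "label": True, "system": ""}
print(intel_transform(row)["prompt"])
```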
src/axolotl/prompt_strategies/kto/llama3.py (new file, 105 lines)
@@ -0,0 +1,105 @@
+"""
+KTO strategies for llama-3 chat template
+"""
+# pylint: disable=duplicate-code
+
+
+def argilla(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["completion"] = f"{sample['completion']}<|eot_id|>"
+        return sample
+
+    return transform_fn
+
+
+def argilla_chat(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for argilla/kto-mix-15k conversations
+    """
+
+    def transform_fn(sample):
+        sample[
+            "prompt"
+        ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
+        return sample
+
+    return transform_fn
+
+
+def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    For Intel Orca KTO
+    ex: argilla/distilabel-intel-orca-kto
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["completion"] = f"{sample['completion']}<|eot_id|>"
+        return sample
+
+    return transform_fn
+
+
+def prompt_pairs(
+    cfg, **kwargs
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["completion"] = f"{sample['completion']}<|eot_id|>"
+        return sample
+
+    return transform_fn
+
+
+def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for ultrafeedback binarized conversations
+    ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["completion"] = f"{sample['completion']}<|eot_id|>"
+        return sample
+
+    return transform_fn
src/axolotl/prompt_strategies/kto/user_defined.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+"""
+User-defined KTO strategies
+"""
+# pylint: disable=duplicate-code
+
+
+def default(cfg, dataset_idx=0, **kwargs):  # pylint: disable=unused-argument
+    ds_cfg = cfg["datasets"][dataset_idx]["type"]
+    if not isinstance(ds_cfg, dict):
+        raise ValueError(
+            f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
+        )
+    field_prompt = ds_cfg.get("field_prompt", "prompt")
+    field_system = ds_cfg.get("field_system", "system")
+    field_completion = ds_cfg.get("field_completion", "completion")
+    field_label = ds_cfg.get("field_label", "label")
+    prompt_format = ds_cfg.get("prompt_format")
+    if not prompt_format:
+        prompt_format = "{" + field_prompt + "}"
+    completion_format = ds_cfg.get("completion_format")
+    if not completion_format:
+        completion_format = "{" + field_completion + "}"
+
+    def transform_fn(sample):
+        if (
+            "{" + field_system + "}" in prompt_format
+            and field_system in sample
+            and sample[field_system]
+        ):
+            sample["prompt"] = prompt_format.format(
+                system=sample[field_system], prompt=sample[field_prompt]
+            )
+        else:
+            sample["prompt"] = prompt_format.format(prompt=sample[field_prompt])
+        sample["completion"] = completion_format.format(
+            completion=sample[field_completion]
+        )
+        sample["label"] = sample[field_label]
+        return sample
+
+    return transform_fn
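`default()` reads the column names and format strings from the dataset's `type` entry in the axolotl config. A sketch of the expected config shape (the dataset path and field names are illustrative):

```python
# The dataset "type" entry is a dict naming the prompt/completion/label columns.
cfg = {
    "datasets": [
        {
            "path": "my/kto-dataset",  # hypothetical dataset path
            "type": {
                "field_prompt": "question",
                "field_completion": "answer",
                "field_label": "label",
                "prompt_format": "{prompt}",
                "completion_format": "{completion}",
            },
        }
    ]
}
ds_type = cfg["datasets"][0]["type"]
assert isinstance(ds_type, dict)  # default() raises ValueError otherwise
```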
@@ -86,6 +86,8 @@ def build_loader(
         )
         if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
             strategy.strict = ds_cfg["strict"]
+        if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
+            strategy.messages = ds_cfg["field_messages"]
         return strategy
 
     return _load
@@ -97,6 +99,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     """
 
     _strict = False
+    _messages = "conversations"
 
     @property
     def strict(self):
@@ -106,8 +109,16 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     def strict(self, strict):
         self._strict = strict
 
+    @property
+    def messages(self):
+        return self._messages
+
+    @messages.setter
+    def messages(self, messages):
+        self._messages = messages
+
     def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
+        conversations = prompt[self.messages]
         if self.strict:
             return conversations
         role_key = "from"
@@ -197,6 +197,13 @@ def train(
         trainer.accelerator.wait_for_everyone()
         unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
 
+        # the trainer saved a model.safetensors file in the output directory,
+        # but it is a proxy model and should be deleted
+        if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
+            LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
+            LOG.info("This is a proxy model and should be deleted")
+            os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
+
         # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
         # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
         # `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|||||||
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import logging
+import math
 import os
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
@@ -775,7 +776,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
         return control
 
 
-class SaveModelOnTrainEndCallback(TrainerCallback):
+class SaveModelCallback(TrainerCallback):
     """Callback to save model on train end"""
 
     def on_step_end(  # pylint: disable=unused-argument
@@ -788,6 +789,13 @@ class SaveModelOnTrainEndCallback(TrainerCallback):
         # Save
         if state.global_step >= state.max_steps:
             control.should_save = True
+        elif (
+            args.save_strategy == IntervalStrategy.STEPS
+            and state.save_steps < 1.0
+            and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
+        ):
+            # workaround to save model on fractional save_steps
+            control.should_save = True
 
     def on_train_end(  # pylint: disable=unused-argument
         self, args, state, control, **kwargs
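The elif exists because step-interval saving expects an integer interval, while a fractional save_steps is meant as a share of max_steps; the modulo against math.ceil(save_steps * max_steps) turns the fraction back into concrete step numbers. A quick arithmetic check (numbers illustrative):

import math

save_steps, max_steps = 0.25, 1000            # save every quarter of the run
interval = math.ceil(save_steps * max_steps)  # 250
print([s for s in range(1, max_steps + 1) if s % interval == 0])
# [250, 500, 750, 1000]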
@@ -24,7 +24,7 @@ def chat_templates(user_choice: str):
         "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
         "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
         "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
-        "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}",
+        "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
     }
 
     if user_choice in templates:
@@ -187,19 +187,22 @@ def normalize_cfg_datasets(cfg):
     helpers for mapping chat_template to various dataset configurations as necessary
     """
 
-    if cfg.chat_template and cfg.chat_template == "chatml":
+    if cfg.chat_template:
         if cfg.datasets:
             for idx, ds_cfg in enumerate(cfg.datasets):
                 if ds_cfg.type == "sharegpt" and not ds_cfg.conversation:
                     LOG.info(
-                        f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
+                        f"updating dataset {ds_cfg.path} with `conversation: {cfg.chat_template}` to match your chat_template"
                     )
-                    cfg.datasets[idx].conversation = "chatml"
+                    cfg.datasets[idx].conversation = cfg.chat_template
-                if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template:
+                if (
+                    ds_cfg.type in ["orpo.chat_template", "chat_template"]
+                    and not ds_cfg.chat_template
+                ):
                     LOG.info(
-                        f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template"
+                        f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
                     )
-                    cfg.datasets[idx].chat_template = "chatml"
+                    cfg.datasets[idx].chat_template = cfg.chat_template
 
 
 def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
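The normalization now propagates whatever chat_template is set, rather than special-casing chatml. A sketch of the effect, assuming a DictDefault-style cfg like the one this function receives:

# sketch; cfg shape assumed from the function above
cfg = DictDefault(
    {
        "chat_template": "llama3",
        "datasets": [{"path": "some/sharegpt-data", "type": "sharegpt"}],
    }
)
normalize_cfg_datasets(cfg)
# logs: updating dataset some/sharegpt-data with `conversation: llama3` ...
assert cfg.datasets[0].conversation == "llama3"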
@@ -24,6 +24,7 @@ class DeprecatedParameters(BaseModel):
     max_packed_sequence_len: Optional[int] = None
     rope_scaling: Optional[Any] = None
     noisy_embedding_alpha: Optional[float] = None
+    dpo_beta: Optional[float] = None
 
     @field_validator("max_packed_sequence_len")
     @classmethod
@@ -48,6 +49,13 @@ class DeprecatedParameters(BaseModel):
         LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
         return noisy_embedding_alpha
 
+    @field_validator("dpo_beta")
+    @classmethod
+    def validate_dpo_beta(cls, dpo_beta):
+        if dpo_beta is not None:
+            LOG.warning("dpo_beta is deprecated, use rl_beta instead")
+        return dpo_beta
+
 
 class RemappedParameters(BaseModel):
     """parameters that have been remapped to other names"""
@@ -101,6 +109,9 @@ class SFTDataset(BaseModel):
     field: Optional[str] = None
     field_human: Optional[str] = None
     field_model: Optional[str] = None
+    field_messages: Optional[str] = None
+    message_field_role: Optional[str] = None
+    message_field_content: Optional[str] = None
 
     roles: Optional[Dict[str, List[str]]] = None
 
@@ -126,6 +137,26 @@ class DPODataset(BaseModel):
     data_files: Optional[List[str]] = None
 
 
+class UserDefinedKTOType(BaseModel):
+    """User defined typing for KTO"""
+
+    field_system: Optional[str] = None
+    field_prompt: Optional[str] = None
+    field_completion: Optional[str] = None
+    field_label: Optional[bool] = None
+    prompt_format: Optional[str] = None
+    completion_format: Optional[str] = None
+
+
+class KTODataset(BaseModel):
+    """KTO configuration subset"""
+
+    path: Optional[str] = None
+    split: Optional[str] = None
+    type: Optional[Union[UserDefinedKTOType, str]] = None
+    data_files: Optional[List[str]] = None
+
+
 class RLType(str, Enum):
     """RL trainer type configuration subset"""
 
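A hedged example of a config instance these new models would accept. Values are illustrative; note that field_label is typed as a bool in this revision, so it is omitted here:

# illustrative values; path and field names are assumptions
kto_ds = KTODataset(
    path="my-org/kto-feedback",
    split="train",
    type=UserDefinedKTOType(
        field_prompt="question",
        field_completion="answer",
        prompt_format="{prompt}",
        completion_format="{completion}",
    ),
)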
@@ -133,6 +164,7 @@ class RLType(str, Enum):
     ipo = "ipo"  # pylint: disable=invalid-name
     kto_pair = "kto_pair"  # pylint: disable=invalid-name
     orpo = "orpo"  # pylint: disable=invalid-name
+    kto = "kto"  # pylint: disable=invalid-name
 
 
 class ChatTemplate(str, Enum):
@@ -183,10 +215,12 @@ class LoraConfig(BaseModel):
     lora_target_modules: Optional[List[str]] = None
     lora_target_linear: Optional[bool] = None
     lora_modules_to_save: Optional[List[str]] = None
-    lora_dropout: Optional[float] = None
+    lora_dropout: Optional[float] = 0.0
     peft_layers_to_transform: Optional[List[int]] = None
     peft: Optional[PeftConfig] = None
     peft_use_dora: Optional[bool] = None
+    peft_use_mora: Optional[bool] = None
+    peft_mora_type: Optional[int] = None
     peft_use_rslora: Optional[bool] = None
     peft_layer_replication: Optional[List[Tuple[int, int]]] = None
 
@@ -450,8 +484,8 @@ class AxolotlInputConfig(
 
     rl: Optional[RLType] = None
 
-    datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None  # type: ignore
-    test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None  # type: ignore
+    datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
+    test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
     shuffle_merged_datasets: Optional[bool] = True
     dataset_prepared_path: Optional[str] = None
     dataset_shard_num: Optional[int] = None
@@ -521,6 +555,8 @@ class AxolotlInputConfig(
         default=512, metadata={"help": "maximum prompt length for RL training"}
     )
     sample_packing: Optional[bool] = None
+    sample_packing_group_size: Optional[int] = 100_000
+    sample_packing_bin_size: Optional[int] = 200
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
     curriculum_sampling: Optional[bool] = None
@@ -549,6 +585,11 @@ class AxolotlInputConfig(
     flash_attn_fuse_mlp: Optional[bool] = None
     flash_optimum: Optional[bool] = None
 
+    unsloth_cross_entropy_loss: Optional[bool] = None
+    unsloth_lora_mlp: Optional[bool] = None
+    unsloth_lora_qkv: Optional[bool] = None
+    unsloth_lora_o: Optional[bool] = None
+
     deepspeed: Optional[Union[str, Dict[str, Any]]] = None
     fsdp: Optional[List[str]] = None
     fsdp_config: Optional[Dict[str, Any]] = None
@@ -574,11 +615,17 @@ class AxolotlInputConfig(
     logging_steps: Optional[int] = None
     early_stopping_patience: Optional[int] = None
     load_best_model_at_end: Optional[bool] = False
+    save_only_model: Optional[bool] = False
+    use_tensorboard: Optional[bool] = None
 
     neftune_noise_alpha: Optional[float] = None
 
     orpo_alpha: Optional[float] = None
 
+    kto_desirable_weight: Optional[float] = None
+    kto_undesirable_weight: Optional[float] = None
+    rl_beta: Optional[float] = None
+
     max_memory: Optional[
         Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
     ] = None
@@ -878,6 +925,13 @@ class AxolotlInputConfig(
             raise ValueError("neftune_noise_alpha must be > 0.0")
         return neftune_noise_alpha
 
+    @model_validator(mode="after")
+    def check(self):
+        if self.dpo_beta and not self.rl_beta:
+            self.rl_beta = self.dpo_beta
+            del self.dpo_beta
+        return self
+
     @model_validator(mode="before")
     @classmethod
     def check_frozen(cls, data):
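The mode="after" validator is what actually migrates the deprecated value onto rl_beta. The same pattern in isolation, as a toy pydantic v2 model (the real validator lives on AxolotlInputConfig):

from typing import Optional
from pydantic import BaseModel, model_validator

class Cfg(BaseModel):
    # toy stand-in for AxolotlInputConfig, just to show the remap
    dpo_beta: Optional[float] = None
    rl_beta: Optional[float] = None

    @model_validator(mode="after")
    def check(self):
        if self.dpo_beta and not self.rl_beta:
            self.rl_beta = self.dpo_beta
            self.dpo_beta = None  # the diff uses `del self.dpo_beta`
        return self

assert Cfg(dpo_beta=0.1).rl_beta == 0.1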
@@ -150,6 +150,8 @@ def wrap_pretraining_dataset(
         max_seq_length=max_tokens,
         batch_size=batch_size,
         multipack_attn=cfg.pretrain_multipack_attn,
+        group_size=cfg.sample_packing_group_size,
+        bin_size=cfg.sample_packing_bin_size,
     )
     # set this to 1 so downstream data_loader doesn't try to increase the batch again
     cfg.micro_batch_size = 1
@@ -189,6 +191,8 @@ def encode_packed_pretraining(
     max_seq_length: int = 2048,
     batch_size: int = 4,
     multipack_attn: Optional[bool] = False,
+    group_size: int = 100000,
+    bin_size: int = 200,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples
@@ -202,11 +206,13 @@ def encode_packed_pretraining(
     )
 
     sampler = MultipackBatchSampler(
-        RandomSampler(train_dataset),
-        batch_size=1,
-        drop_last=True,
-        batch_max_len=batch_size * max_seq_length,
+        sampler=RandomSampler(train_dataset),
         lengths=get_dataset_lengths(train_dataset),
+        batch_size=1,
+        batch_max_len=batch_size * max_seq_length,
+        group_size=group_size,
+        bin_size=bin_size,
+        drop_last=True,
     )
 
     chunked_data = defaultdict(list)
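group_size and bin_size now thread from the config into the batch sampler. For intuition, the general idea behind multipack-style sampling is bin packing of variable-length samples so each batch wastes little padding; a generic first-fit sketch follows (illustrative only, not axolotl's actual algorithm):

# generic first-fit bin packing, for intuition only
def first_fit(lengths, capacity):
    bins = []  # each bin: [remaining_capacity, [sample_indices]]
    for idx, length in enumerate(lengths):
        for b in bins:
            if b[0] >= length:
                b[0] -= length
                b[1].append(idx)
                break
        else:
            bins.append([capacity - length, [idx]])
    return [b[1] for b in bins]

print(first_fit([512, 1500, 900, 300, 1700], capacity=2048))
# [[0, 1], [2, 3], [4]] — three packed "batches" instead of five padded ones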
@@ -10,6 +10,7 @@ from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.prompt_strategies.kto import load as load_kto
 from axolotl.prompt_strategies.orpo import load as load_orpo
 from axolotl.utils.data.utils import md5
 from axolotl.utils.dict import DictDefault
@@ -55,6 +56,22 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
         dataset.save_to_disk(str(prepared_ds_path))
 
 
+def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
+    sig = inspect.signature(ds_transform_fn)
+    if "tokenizer" in sig.parameters:
+        if not tokenizer:
+            tokenizer = load_tokenizer(cfg)
+        ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
+
+    data_set = data_set.map(
+        ds_transform_fn,
+        desc="Mapping RL Dataset",
+    )
+    if isinstance(data_set, DatasetDict):
+        data_set = data_set["train"]
+    return data_set
+
+
 def load_prepare_dpo_datasets(cfg):
     def load_split(dataset_cfgs, _cfg):
         split_datasets: List[Any] = []
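The new map_dataset helper centralizes a signature-inspection trick: a transform only receives a tokenizer if it declares one. The pattern in isolation, with toy transforms standing in for the ones load_dpo/load_orpo/load_kto return:

import inspect
from functools import partial

def needs_tok(sample, tokenizer):  # e.g. a chat-template transform
    return {"n_tokens": len(tokenizer(sample["text"]))}

def plain(sample):                 # e.g. a simple field-mapping transform
    return {"text": sample["text"].strip()}

def bind(fn, tokenizer):
    # bind the tokenizer only when the transform's signature asks for one
    if "tokenizer" in inspect.signature(fn).parameters:
        return partial(fn, tokenizer=tokenizer)
    return fn

fake_tokenizer = str.split  # stand-in: "tokenizes" by whitespace
print(bind(needs_tok, fake_tokenizer)({"text": "a b c"}))  # {'n_tokens': 3}
print(bind(plain, fake_tokenizer)({"text": " hi "}))       # {'text': 'hi'}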
@@ -76,6 +93,7 @@ def load_prepare_dpo_datasets(cfg):
             split_datasets.insert(i, ds)
 
         tokenizer = None
+
        for i, data_set in enumerate(split_datasets):
             _type = dataset_cfgs[i]["type"]
             if _type:
@@ -83,21 +101,19 @@ def load_prepare_dpo_datasets(cfg):
                     _type = "user_defined.default"
                 if _cfg.rl == "orpo":
                     ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
+                elif _cfg.rl == "kto":
+                    ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
                 else:
                     ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
-                sig = inspect.signature(ds_transform_fn)
-                if "tokenizer" in sig.parameters:
-                    if not tokenizer:
-                        tokenizer = load_tokenizer(_cfg)
-                    ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
 
-                data_set = data_set.map(
-                    ds_transform_fn,
-                    desc="Mapping RL Dataset",
+                split_datasets[i] = map_dataset(
+                    cfg, data_set, ds_transform_fn, tokenizer
+                )
+            elif _cfg.rl == "kto":
+                ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
+                split_datasets[i] = map_dataset(
+                    cfg, data_set, ds_transform_fn, tokenizer
                 )
-                if isinstance(data_set, DatasetDict):
-                    data_set = data_set["train"]
-                split_datasets[i] = data_set
             else:
                 # If no `type` is provided, assume the dataset is already in the expected format with
                 # "prompt", "chosen" and "rejected" already preprocessed
@@ -308,12 +308,16 @@ def load_tokenized_prepared_datasets(
                 "unhandled dataset load: local path exists, but is neither a directory or a file"
             )
         elif ds_from_hub:
+            load_ds_kwargs = {}
+            if config_dataset.split:
+                load_ds_kwargs = {"split": config_dataset.split}
             ds = load_dataset(
                 config_dataset.path,
                 name=config_dataset.name,
                 streaming=False,
                 data_files=config_dataset.data_files,
                 token=use_auth_token,
+                **load_ds_kwargs,
             )
         elif ds_from_cloud and remote_file_system:
             if remote_file_system.isdir(config_dataset.path):
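Context for the kwargs dance: datasets.load_dataset returns a DatasetDict when called without a split and a single Dataset when one is given, so the kwarg is only added when the config sets it. A minimal illustration (dataset name illustrative):

from datasets import load_dataset

dd = load_dataset("tatsu-lab/alpaca")                 # DatasetDict({"train": ...})
ds = load_dataset("tatsu-lab/alpaca", split="train")  # Dataset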
@@ -494,7 +498,9 @@ def load_prepare_datasets(
         test_fingerprint = md5(to_hash_test)
 
         dataset = dataset.train_test_split(
-            test_size=cfg.val_set_size,
+            test_size=int(cfg.val_set_size)
+            if cfg.val_set_size == int(cfg.val_set_size)
+            else cfg.val_set_size,
             shuffle=False,
             seed=cfg.seed or 42,
             train_new_fingerprint=train_fingerprint,
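The ternary mirrors how datasets.train_test_split interprets test_size: an int is an absolute row count, a float a fraction, so a YAML value like 1000.0 must be coerced before the split. Restated as a tiny helper for checking:

def coerce_val_set_size(v):
    # ints (or whole-number floats) mean "this many rows"; other floats mean a fraction
    return int(v) if v == int(v) else v

assert coerce_val_set_size(0.04) == 0.04     # 4% of the dataset
assert coerce_val_set_size(1000.0) == 1000   # exactly 1000 rows
assert isinstance(coerce_val_set_size(1000.0), int)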
@@ -390,6 +390,16 @@ def load_model(
             "Shifted-sparse attention not currently implemented without flash attention."
         )
 
+    if cfg.unsloth_cross_entropy_loss:
+        from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
+
+        integrate_cross_entropy_loss_patch()
+
+    if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
+        from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
+
+        patch_self_attn_lora()
+
     # Modify mistral derived models
     if (
         cfg.model_config_type == "mistral"
@@ -793,7 +803,11 @@ def load_model(
     if not reference_model or cfg.lora_model_dir:
         # if we're not loading the reference model, then we're loading the model for training
         # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
-        if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
+        if (
+            cfg.adapter
+            and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"]
+            and not cfg.merge_lora
+        ):
             _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
         else:
             model, lora_config = load_adapter(model, cfg, cfg.adapter)
@@ -828,6 +842,15 @@ def load_model(
     if cfg.adapter is not None:
         log_gpu_memory_usage(LOG, "after adapters", model.device)
 
+    if cfg.unsloth_lora_mlp:
+        from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
+
+        integrate_lora_mlp_patch(model)
+    if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
+        from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
+
+        integrate_lora_patch(model, cfg)
+
     # TODO resume_from_checkpoint handling
     return model, lora_config
 
@@ -930,6 +953,8 @@ def load_lora(model, cfg, inference=False, config_only=False):
 
     lora_config_kwargs = {}
     loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
+    if cfg.lora_alpha:
+        lora_config_kwargs["lora_alpha"] = cfg.lora_alpha
     if loftq_bits:
         lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
         lora_config_kwargs["init_lora_weights"] = "loftq"
@@ -937,12 +962,14 @@ def load_lora(model, cfg, inference=False, config_only=False):
         lora_config_kwargs["use_dora"] = cfg.peft_use_dora
     if cfg.peft_use_rslora:
         lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
+    if cfg.peft_use_mora and cfg.peft_mora_type is not None:
+        lora_config_kwargs["use_mora"] = cfg.peft_use_mora
+        lora_config_kwargs["mora_type"] = cfg.peft_mora_type
     if cfg.peft_layer_replication:
         lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
 
     lora_config = LoraConfig(
         r=cfg.lora_r,
-        lora_alpha=cfg.lora_alpha,
         target_modules=lora_target_modules,
         layers_to_transform=cfg.peft_layers_to_transform,
         lora_dropout=cfg.lora_dropout,
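How the kwargs assemble after these two hunks, as a sketch: lora_alpha moves into the conditional kwargs dict, and use_mora/mora_type are presumably consumed by a MoRA-enabled peft build that this branch targets; upstream peft's LoraConfig does not accept them. The sketch assumes lora_config_kwargs is ultimately splatted into the LoraConfig call:

# sketch under the assumptions above
lora_config_kwargs = {}
if cfg.lora_alpha:
    lora_config_kwargs["lora_alpha"] = cfg.lora_alpha  # only set when provided
if cfg.peft_use_mora and cfg.peft_mora_type is not None:
    lora_config_kwargs["use_mora"] = cfg.peft_use_mora
    lora_config_kwargs["mora_type"] = cfg.peft_mora_type

lora_config = LoraConfig(
    r=cfg.lora_r,
    target_modules=lora_target_modules,
    lora_dropout=cfg.lora_dropout,
    **lora_config_kwargs,  # lora_alpha (and MoRA options) arrive here
)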
Some files were not shown because too many files have changed in this diff.