Compare commits
9 Commits
llama-mult
...
nca-pair
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
317761406e | ||
|
|
6a9ac4ad27 | ||
|
|
027f7d54f0 | ||
|
|
0554105baa | ||
|
|
f58fcd09ec | ||
|
|
60fecac367 | ||
|
|
b301068098 | ||
|
|
df645906eb | ||
|
|
7fea5822f0 |
2
.github/workflows/base.yml
vendored
2
.github/workflows/base.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
|||||||
- cuda: "121"
|
- cuda: "121"
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.2
|
pytorch: 2.2.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
- cuda: "121"
|
- cuda: "121"
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
|
|||||||
46
.github/workflows/main.yml
vendored
46
.github/workflows/main.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.2
|
pytorch: 2.2.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
@@ -89,7 +89,7 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.2
|
pytorch: 2.2.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
@@ -125,45 +125,3 @@ jobs:
|
|||||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|
||||||
build-axolotl-cloud-no-tmux:
|
|
||||||
needs: build-axolotl
|
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
include:
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.3.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: axolotl-gpu-runner
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
- name: Docker metadata
|
|
||||||
id: metadata
|
|
||||||
uses: docker/metadata-action@v5
|
|
||||||
with:
|
|
||||||
images: winglian/axolotl-cloud-term
|
|
||||||
- name: Login to Docker Hub
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v2
|
|
||||||
- name: Build
|
|
||||||
uses: docker/build-push-action@v5
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
build-args: |
|
|
||||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
|
||||||
CUDA=${{ matrix.cuda }}
|
|
||||||
file: ./docker/Dockerfile-cloud-no-tmux
|
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
|
||||||
tags: |
|
|
||||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
|
||||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
|
||||||
|
|||||||
4
.github/workflows/nightlies.yml
vendored
4
.github/workflows/nightlies.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.2
|
pytorch: 2.2.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
@@ -89,7 +89,7 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.2
|
pytorch: 2.2.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
|
|||||||
7
.github/workflows/tests.yml
vendored
7
.github/workflows/tests.yml
vendored
@@ -82,12 +82,7 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.2
|
pytorch: 2.2.1
|
||||||
num_gpus: 1
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.3.0
|
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -176,9 +176,3 @@ qlora-out/*
|
|||||||
mlruns/*
|
mlruns/*
|
||||||
|
|
||||||
/.quarto/
|
/.quarto/
|
||||||
prepared-datasets/
|
|
||||||
submit.sh
|
|
||||||
*.out*
|
|
||||||
|
|
||||||
typings/
|
|
||||||
out/
|
|
||||||
|
|||||||
43
README.md
43
README.md
@@ -34,7 +34,6 @@ Features:
|
|||||||
- [Mac](#mac)
|
- [Mac](#mac)
|
||||||
- [Google Colab](#google-colab)
|
- [Google Colab](#google-colab)
|
||||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||||
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
|
|
||||||
- [Dataset](#dataset)
|
- [Dataset](#dataset)
|
||||||
- [Config](#config)
|
- [Config](#config)
|
||||||
- [Train](#train)
|
- [Train](#train)
|
||||||
@@ -124,11 +123,11 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
|||||||
|
|
||||||
# inference
|
# inference
|
||||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
--lora_model_dir="./outputs/lora-out"
|
--lora_model_dir="./lora-out"
|
||||||
|
|
||||||
# gradio
|
# gradio
|
||||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
--lora_model_dir="./outputs/lora-out" --gradio
|
--lora_model_dir="./lora-out" --gradio
|
||||||
|
|
||||||
# remote yaml files - the yaml config can be hosted on a public URL
|
# remote yaml files - the yaml config can be hosted on a public URL
|
||||||
# Note: the yaml config must directly link to the **raw** yaml
|
# Note: the yaml config must directly link to the **raw** yaml
|
||||||
@@ -293,42 +292,6 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
|||||||
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Launching on public clouds via dstack
|
|
||||||
To launch on GPU instance (both on-demand and spot instances) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
|
|
||||||
|
|
||||||
Write a job description in YAML as below:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# dstack.yaml
|
|
||||||
type: task
|
|
||||||
|
|
||||||
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
|
|
||||||
|
|
||||||
env:
|
|
||||||
- HUGGING_FACE_HUB_TOKEN
|
|
||||||
- WANDB_API_KEY
|
|
||||||
|
|
||||||
commands:
|
|
||||||
- accelerate launch -m axolotl.cli.train config.yaml
|
|
||||||
|
|
||||||
ports:
|
|
||||||
- 6006
|
|
||||||
|
|
||||||
resources:
|
|
||||||
gpu:
|
|
||||||
memory: 24GB..
|
|
||||||
count: 2
|
|
||||||
```
|
|
||||||
|
|
||||||
then, simply run the job with `dstack run` command. Append `--spot` option if you want spot instance. `dstack run` command will show you the instance with cheapest price across multi cloud services:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install dstack
|
|
||||||
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
|
|
||||||
```
|
|
||||||
|
|
||||||
For further and fine-grained use cases, please refer to the official [dstack documents](https://dstack.ai/docs/) and the detailed description of [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) on the official repository.
|
|
||||||
|
|
||||||
### Dataset
|
### Dataset
|
||||||
|
|
||||||
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
||||||
@@ -609,7 +572,7 @@ If you decode a prompt constructed by axolotl, you might see spaces between toke
|
|||||||
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
|
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
|
||||||
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
|
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
|
||||||
|
|
||||||
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/finetuning/05_tokenizer_gotchas.html) for a concrete example.
|
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
|
||||||
|
|
||||||
## Debugging Axolotl
|
## Debugging Axolotl
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
|
||||||
|
|
||||||
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||||
pytest /workspace/axolotl/tests/e2e/patched/
|
pytest /workspace/axolotl/tests/e2e/patched/
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ ARG PYTORCH_VERSION="2.1.2"
|
|||||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs
|
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
ARG BASE_TAG=main
|
|
||||||
FROM winglian/axolotl:$BASE_TAG
|
|
||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
|
||||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
|
||||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
|
||||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
|
||||||
|
|
||||||
EXPOSE 8888
|
|
||||||
EXPOSE 22
|
|
||||||
|
|
||||||
COPY scripts/cloud-entrypoint-term.sh /root/cloud-entrypoint.sh
|
|
||||||
COPY scripts/motd /etc/motd
|
|
||||||
|
|
||||||
RUN pip install jupyterlab notebook ipywidgets && \
|
|
||||||
jupyter lab clean
|
|
||||||
RUN apt install --yes --no-install-recommends openssh-server tmux sudo && \
|
|
||||||
pip3 install -U --no-cache-dir grpcio ray[default]==2.9.3 && \
|
|
||||||
mkdir -p ~/.ssh && \
|
|
||||||
chmod 700 ~/.ssh && \
|
|
||||||
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
|
||||||
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
|
|
||||||
chmod +x /root/cloud-entrypoint.sh
|
|
||||||
|
|
||||||
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
|
|
||||||
CMD ["sleep", "infinity"]
|
|
||||||
@@ -138,7 +138,7 @@ test_datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- /workspace/data/eval.jsonl
|
- /workspace/data/eval.jsonl
|
||||||
|
|
||||||
# use RL training: 'dpo', 'ipo', 'kto_pair'
|
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard', 'nca_pair'
|
||||||
rl:
|
rl:
|
||||||
|
|
||||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||||
@@ -186,11 +186,6 @@ eval_sample_packing:
|
|||||||
# The trainer will provide recommended values for these values.
|
# The trainer will provide recommended values for these values.
|
||||||
sample_packing_eff_est:
|
sample_packing_eff_est:
|
||||||
total_num_tokens:
|
total_num_tokens:
|
||||||
# Increasing the following values helps with packing, but usually only slightly (<%1.)
|
|
||||||
# The number of samples packed at a time.
|
|
||||||
sample_packing_group_size: 100000
|
|
||||||
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
|
|
||||||
sample_packing_bin_size: 200
|
|
||||||
|
|
||||||
# Passed through to transformers when loading the model when launched without accelerate
|
# Passed through to transformers when loading the model when launched without accelerate
|
||||||
# Use `sequential` when training w/ model parallelism to limit memory
|
# Use `sequential` when training w/ model parallelism to limit memory
|
||||||
@@ -290,7 +285,7 @@ lr_quadratic_warmup:
|
|||||||
logging_steps:
|
logging_steps:
|
||||||
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||||
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
|
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
|
||||||
save_strategy: # Set to `"no"` to skip checkpoint saves
|
save_strategy: # Set to `no` to skip checkpoint saves
|
||||||
save_steps: # Leave empty to save at each epoch
|
save_steps: # Leave empty to save at each epoch
|
||||||
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||||
save_total_limit: # Checkpoints saved at a time
|
save_total_limit: # Checkpoints saved at a time
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
output_dir: ./outputs/btlm-out
|
output_dir: btlm-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
micro_batch_size: 4
|
micro_batch_size: 4
|
||||||
num_epochs: 2
|
num_epochs: 2
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -1,223 +1,216 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "AKjdG7tbTb-n"
|
"id": "AKjdG7tbTb-n"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"# Example notebook for running Axolotl on google colab"
|
"# Example notebook for running Axolotl on google colab"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "RcbNpOgWRcii"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import torch\n",
|
|
||||||
"# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n",
|
|
||||||
"assert (torch.cuda.is_available()==True)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "h3nLav8oTRA5"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Install Axolotl and dependencies"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
},
|
||||||
"id": "3c3yGAwnOIdi",
|
{
|
||||||
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
|
"cell_type": "code",
|
||||||
},
|
"execution_count": null,
|
||||||
"outputs": [],
|
"metadata": {
|
||||||
"source": [
|
"id": "RcbNpOgWRcii"
|
||||||
"!pip install torch==\"2.1.2\"\n",
|
},
|
||||||
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
|
"outputs": [],
|
||||||
"!pip install flash-attn==\"2.5.0\"\n",
|
"source": [
|
||||||
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
"import torch\n",
|
||||||
]
|
"# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n",
|
||||||
},
|
"assert (torch.cuda.is_available()==True)"
|
||||||
{
|
]
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "BW2MFr7HTjub"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Create an yaml config file"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "9pkF2dSoQEUN"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import yaml\n",
|
|
||||||
"\n",
|
|
||||||
"# Your YAML string\n",
|
|
||||||
"yaml_string = \"\"\"\n",
|
|
||||||
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
|
|
||||||
"model_type: LlamaForCausalLM\n",
|
|
||||||
"tokenizer_type: LlamaTokenizer\n",
|
|
||||||
"\n",
|
|
||||||
"load_in_8bit: false\n",
|
|
||||||
"load_in_4bit: true\n",
|
|
||||||
"strict: false\n",
|
|
||||||
"\n",
|
|
||||||
"datasets:\n",
|
|
||||||
" - path: mhenrichsen/alpaca_2k_test\n",
|
|
||||||
" type: alpaca\n",
|
|
||||||
"dataset_prepared_path:\n",
|
|
||||||
"val_set_size: 0.05\n",
|
|
||||||
"output_dir: ./outputs/qlora-out\n",
|
|
||||||
"\n",
|
|
||||||
"adapter: qlora\n",
|
|
||||||
"lora_model_dir:\n",
|
|
||||||
"\n",
|
|
||||||
"sequence_len: 4096\n",
|
|
||||||
"sample_packing: true\n",
|
|
||||||
"eval_sample_packing: false\n",
|
|
||||||
"pad_to_sequence_len: true\n",
|
|
||||||
"\n",
|
|
||||||
"lora_r: 32\n",
|
|
||||||
"lora_alpha: 16\n",
|
|
||||||
"lora_dropout: 0.05\n",
|
|
||||||
"lora_target_modules:\n",
|
|
||||||
"lora_target_linear: true\n",
|
|
||||||
"lora_fan_in_fan_out:\n",
|
|
||||||
"\n",
|
|
||||||
"wandb_project:\n",
|
|
||||||
"wandb_entity:\n",
|
|
||||||
"wandb_watch:\n",
|
|
||||||
"wandb_name:\n",
|
|
||||||
"wandb_log_model:\n",
|
|
||||||
"\n",
|
|
||||||
"gradient_accumulation_steps: 4\n",
|
|
||||||
"micro_batch_size: 2\n",
|
|
||||||
"num_epochs: 4\n",
|
|
||||||
"optimizer: paged_adamw_32bit\n",
|
|
||||||
"lr_scheduler: cosine\n",
|
|
||||||
"learning_rate: 0.0002\n",
|
|
||||||
"\n",
|
|
||||||
"train_on_inputs: false\n",
|
|
||||||
"group_by_length: false\n",
|
|
||||||
"bf16: auto\n",
|
|
||||||
"fp16:\n",
|
|
||||||
"tf32: false\n",
|
|
||||||
"\n",
|
|
||||||
"gradient_checkpointing: true\n",
|
|
||||||
"early_stopping_patience:\n",
|
|
||||||
"resume_from_checkpoint:\n",
|
|
||||||
"local_rank:\n",
|
|
||||||
"logging_steps: 1\n",
|
|
||||||
"xformers_attention:\n",
|
|
||||||
"flash_attention: true\n",
|
|
||||||
"\n",
|
|
||||||
"warmup_steps: 10\n",
|
|
||||||
"evals_per_epoch: 4\n",
|
|
||||||
"saves_per_epoch: 1\n",
|
|
||||||
"debug:\n",
|
|
||||||
"deepspeed:\n",
|
|
||||||
"weight_decay: 0.0\n",
|
|
||||||
"fsdp:\n",
|
|
||||||
"fsdp_config:\n",
|
|
||||||
"special_tokens:\n",
|
|
||||||
"\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
"# Convert the YAML string to a Python dictionary\n",
|
|
||||||
"yaml_dict = yaml.safe_load(yaml_string)\n",
|
|
||||||
"\n",
|
|
||||||
"# Specify your file path\n",
|
|
||||||
"file_path = 'test_axolotl.yaml'\n",
|
|
||||||
"\n",
|
|
||||||
"# Write the YAML file\n",
|
|
||||||
"with open(file_path, 'w') as file:\n",
|
|
||||||
" yaml.dump(yaml_dict, file)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "bidoj8YLTusD"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Launch the training"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
},
|
||||||
"id": "ydTI2Jk2RStU",
|
{
|
||||||
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
|
"cell_type": "markdown",
|
||||||
},
|
"metadata": {
|
||||||
"outputs": [],
|
"id": "h3nLav8oTRA5"
|
||||||
"source": [
|
},
|
||||||
"# Buy using the ! the comand will be executed as a bash command\n",
|
"source": [
|
||||||
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
"## Install Axolotl and dependencies"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "3c3yGAwnOIdi",
|
||||||
|
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!pip install torch==\"2.1.2\"\n",
|
||||||
|
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
|
||||||
|
"!pip install flash-attn==\"2.5.0\"\n",
|
||||||
|
"!pip install deepspeed==\"0.13.1\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "BW2MFr7HTjub"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Create an yaml config file"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "9pkF2dSoQEUN"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import yaml\n",
|
||||||
|
"\n",
|
||||||
|
"# Your YAML string\n",
|
||||||
|
"yaml_string = \"\"\"\n",
|
||||||
|
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
|
||||||
|
"model_type: LlamaForCausalLM\n",
|
||||||
|
"tokenizer_type: LlamaTokenizer\n",
|
||||||
|
"is_llama_derived_model: true\n",
|
||||||
|
"\n",
|
||||||
|
"load_in_8bit: false\n",
|
||||||
|
"load_in_4bit: true\n",
|
||||||
|
"strict: false\n",
|
||||||
|
"\n",
|
||||||
|
"datasets:\n",
|
||||||
|
" - path: mhenrichsen/alpaca_2k_test\n",
|
||||||
|
" type: alpaca\n",
|
||||||
|
"dataset_prepared_path:\n",
|
||||||
|
"val_set_size: 0.05\n",
|
||||||
|
"output_dir: ./qlora-out\n",
|
||||||
|
"\n",
|
||||||
|
"adapter: qlora\n",
|
||||||
|
"lora_model_dir:\n",
|
||||||
|
"\n",
|
||||||
|
"sequence_len: 1096\n",
|
||||||
|
"sample_packing: true\n",
|
||||||
|
"pad_to_sequence_len: true\n",
|
||||||
|
"\n",
|
||||||
|
"lora_r: 32\n",
|
||||||
|
"lora_alpha: 16\n",
|
||||||
|
"lora_dropout: 0.05\n",
|
||||||
|
"lora_target_modules:\n",
|
||||||
|
"lora_target_linear: true\n",
|
||||||
|
"lora_fan_in_fan_out:\n",
|
||||||
|
"\n",
|
||||||
|
"wandb_project:\n",
|
||||||
|
"wandb_entity:\n",
|
||||||
|
"wandb_watch:\n",
|
||||||
|
"wandb_name:\n",
|
||||||
|
"wandb_log_model:\n",
|
||||||
|
"\n",
|
||||||
|
"mlflow_experiment_name: colab-example\n",
|
||||||
|
"\n",
|
||||||
|
"gradient_accumulation_steps: 1\n",
|
||||||
|
"micro_batch_size: 1\n",
|
||||||
|
"num_epochs: 4\n",
|
||||||
|
"max_steps: 20\n",
|
||||||
|
"optimizer: paged_adamw_32bit\n",
|
||||||
|
"lr_scheduler: cosine\n",
|
||||||
|
"learning_rate: 0.0002\n",
|
||||||
|
"\n",
|
||||||
|
"train_on_inputs: false\n",
|
||||||
|
"group_by_length: false\n",
|
||||||
|
"bf16: false\n",
|
||||||
|
"fp16: true\n",
|
||||||
|
"tf32: false\n",
|
||||||
|
"\n",
|
||||||
|
"gradient_checkpointing: true\n",
|
||||||
|
"early_stopping_patience:\n",
|
||||||
|
"resume_from_checkpoint:\n",
|
||||||
|
"local_rank:\n",
|
||||||
|
"logging_steps: 1\n",
|
||||||
|
"xformers_attention:\n",
|
||||||
|
"flash_attention: false\n",
|
||||||
|
"\n",
|
||||||
|
"warmup_steps: 10\n",
|
||||||
|
"evals_per_epoch:\n",
|
||||||
|
"saves_per_epoch:\n",
|
||||||
|
"debug:\n",
|
||||||
|
"deepspeed:\n",
|
||||||
|
"weight_decay: 0.0\n",
|
||||||
|
"fsdp:\n",
|
||||||
|
"fsdp_config:\n",
|
||||||
|
"special_tokens:\n",
|
||||||
|
"\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Convert the YAML string to a Python dictionary\n",
|
||||||
|
"yaml_dict = yaml.safe_load(yaml_string)\n",
|
||||||
|
"\n",
|
||||||
|
"# Specify your file path\n",
|
||||||
|
"file_path = 'test_axolotl.yaml'\n",
|
||||||
|
"\n",
|
||||||
|
"# Write the YAML file\n",
|
||||||
|
"with open(file_path, 'w') as file:\n",
|
||||||
|
" yaml.dump(yaml_dict, file)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "bidoj8YLTusD"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Launch the training"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "ydTI2Jk2RStU",
|
||||||
|
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Buy using the ! the comand will be executed as a bash command\n",
|
||||||
|
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Play with inference"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Buy using the ! the comand will be executed as a bash command\n",
|
||||||
|
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
||||||
|
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"accelerator": "GPU",
|
||||||
|
"colab": {
|
||||||
|
"gpuType": "T4",
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
"nbformat": 4,
|
||||||
"cell_type": "markdown",
|
"nbformat_minor": 0
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Play with inference"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Buy using the ! the comand will be executed as a bash command\n",
|
|
||||||
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
|
||||||
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"accelerator": "GPU",
|
|
||||||
"colab": {
|
|
||||||
"gpuType": "T4",
|
|
||||||
"provenance": []
|
|
||||||
},
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3 (ipykernel)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.12.1"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 4
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 512
|
sequence_len: 512
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 512
|
sequence_len: 512
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 512
|
sequence_len: 512
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/falcon-7b
|
output_dir: ./falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
# QLoRA paper Table 9
|
# QLoRA paper Table 9
|
||||||
# - 16 for 7b & 13b
|
# - 16 for 7b & 13b
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/falcon-7b
|
output_dir: ./falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 2
|
num_epochs: 2
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/jeopardy-bot-7b
|
output_dir: ./jeopardy-bot-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ wandb_project:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/model-out
|
output_dir: ./model-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/lisa-out
|
output_dir: ./lisa-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/relora-out
|
output_dir: ./relora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 8192
|
sequence_len: 8192
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -1,76 +0,0 @@
|
|||||||
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
|
||||||
model_type: LlamaForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
chat_template: llama3
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
chat_template: llama3
|
|
||||||
field_messages: messages
|
|
||||||
message_field_role: role
|
|
||||||
message_field_content: content
|
|
||||||
roles:
|
|
||||||
user:
|
|
||||||
- user
|
|
||||||
assistant:
|
|
||||||
- assistant
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.05
|
|
||||||
output_dir: ./outputs/lora-out
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
s2_attention:
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 4
|
|
||||||
eval_table_size:
|
|
||||||
eval_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
@@ -24,9 +24,6 @@ lora_alpha: 16
|
|||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
lora_modules_to_save:
|
|
||||||
- embed_tokens
|
|
||||||
- lm_head
|
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/out/qlora-llama3-70b
|
output_dir: ./out/qlora-llama3-70b
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0
|
val_set_size: 0
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 8192
|
sequence_len: 8192
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0
|
val_set_size: 0
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
eval_sample_packing: false
|
eval_sample_packing: false
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
model_config:
|
model_config:
|
||||||
output_router_logits: true
|
output_router_logits: true
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ datasets:
|
|||||||
type: chat_template.argilla
|
type: chat_template.argilla
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./outputs/mistral-qlora-orpo-out
|
output_dir: ./mistral-qlora-orpo-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
model_config:
|
model_config:
|
||||||
output_router_logits: true
|
output_router_logits: true
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
model_config:
|
model_config:
|
||||||
output_router_logits: true
|
output_router_logits: true
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||||
unfrozen_parameters:
|
unfrozen_parameters:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ model_config:
|
|||||||
datasets:
|
datasets:
|
||||||
- path: yahma/alpaca-cleaned
|
- path: yahma/alpaca-cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 8000
|
sequence_len: 8000
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/mpt-alpaca-7b
|
output_dir: ./mpt-alpaca-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/openllama-out
|
output_dir: ./openllama-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/phi-sft-out
|
output_dir: ./phi-sft-out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/phi-sft-out
|
output_dir: ./phi-sft-out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/phi-sft-out
|
output_dir: ./phi-sft-out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -1,83 +0,0 @@
|
|||||||
base_model: microsoft/Phi-3-mini-4k-instruct
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
|
||||||
type: alpaca
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0
|
|
||||||
output_dir: ./phi-sft-out
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
trust_remote_code: true
|
|
||||||
|
|
||||||
adapter:
|
|
||||||
lora_model_dir:
|
|
||||||
lora_r:
|
|
||||||
lora_alpha:
|
|
||||||
lora_dropout:
|
|
||||||
lora_target_linear:
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
|
|
||||||
wandb_project: phi3
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
|
||||||
micro_batch_size: 12
|
|
||||||
num_epochs: 2
|
|
||||||
optimizer: adamw_torch
|
|
||||||
adam_beta2: 0.95
|
|
||||||
adam_epsilon: 0.00001
|
|
||||||
max_grad_norm: 1.0
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.000003
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 100
|
|
||||||
evals_per_epoch: 4
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.1
|
|
||||||
fsdp:
|
|
||||||
- full_shard
|
|
||||||
- auto_wrap
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_limit_all_gathers: true
|
|
||||||
fsdp_sync_module_states: true
|
|
||||||
fsdp_offload_params: true
|
|
||||||
fsdp_use_orig_params: false
|
|
||||||
fsdp_cpu_ram_efficient_loading: true
|
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
fsdp_transformer_layer_cls_to_wrap: Phi3DecoderLayer
|
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
|
||||||
fsdp_sharding_strategy: FULL_SHARD
|
|
||||||
resize_token_embeddings_to_32x: true
|
|
||||||
special_tokens:
|
|
||||||
pad_token: "<|endoftext|>"
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
base_model: microsoft/Phi-3-mini-4k-instruct
|
|
||||||
trust_remote_code: true
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
chat_template: phi_3
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: garage-bAInd/Open-Platypus
|
|
||||||
type: alpaca:phi
|
|
||||||
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.01
|
|
||||||
output_dir: ./out
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
lora_r: 64
|
|
||||||
lora_alpha: 32
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch
|
|
||||||
adam_beta2: 0.95
|
|
||||||
adam_epsilon: 0.00001
|
|
||||||
max_grad_norm: 1.0
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 5.0e-6
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: True
|
|
||||||
early_stopping_patience: 3
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
eval_steps: 1000
|
|
||||||
save_steps: 5000
|
|
||||||
eval_table_size: 2
|
|
||||||
eval_batch_size: 2
|
|
||||||
eval_sample_packing: false
|
|
||||||
eval_max_new_tokens: 32
|
|
||||||
eval_causal_lm_metrics: ["perplexity"]
|
|
||||||
do_causal_lm_eval: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.2
|
|
||||||
debug: true
|
|
||||||
weight_decay: 0.1
|
|
||||||
resize_token_embeddings_to_32x: true
|
|
||||||
@@ -26,7 +26,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/pythia-12b
|
output_dir: ./pythia-12b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 5
|
num_epochs: 5
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/lora-alpaca-pythia
|
output_dir: ./lora-alpaca-pythia
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 4
|
micro_batch_size: 4
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 2048 # supports up to 8192
|
sequence_len: 2048 # supports up to 8192
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 2048 # supports up to 8192
|
sequence_len: 2048 # supports up to 8192
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 1024 # supports up to 32k
|
sequence_len: 1024 # supports up to 32k
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 1024 # supports up to 32k
|
sequence_len: 1024 # supports up to 32k
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -1,75 +0,0 @@
|
|||||||
base_model: Qwen/Qwen2-7B
|
|
||||||
trust_remote_code: true
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: tatsu-lab/alpaca
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.05
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
eval_sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 64
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_torch
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 4
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
- full_shard
|
|
||||||
- auto_wrap
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_limit_all_gathers: true
|
|
||||||
fsdp_sync_module_states: true
|
|
||||||
fsdp_offload_params: true
|
|
||||||
fsdp_use_orig_params: false
|
|
||||||
fsdp_cpu_ram_efficient_loading: true
|
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
|
||||||
special_tokens:
|
|
||||||
@@ -24,7 +24,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/redpajama-alpaca-3b
|
output_dir: ./redpajama-alpaca-3b
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/lora-replit
|
output_dir: ./lora-replit
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.2
|
val_set_size: 0.2
|
||||||
output_dir: ./outputs/qlora
|
output_dir: ./qlora
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0
|
val_set_size: 0
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ pretraining_dataset:
|
|||||||
type: pretrain
|
type: pretrain
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/model-out
|
output_dir: ./model-out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,14 +11,13 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
# QLoRA paper Table 9
|
# QLoRA paper Table 9
|
||||||
# - 16 for 7b & 13b
|
# - 16 for 7b & 13b
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ eval_sample_packing: false
|
|||||||
eval_batch_size: 1
|
eval_batch_size: 1
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
output_dir: ./outputs/qlora-out
|
output_dir: ./qlora-out
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
|
|||||||
@@ -1,22 +1,22 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.11.1
|
peft==0.10.0
|
||||||
transformers==4.41.1
|
transformers @ git+https://github.com/huggingface/transformers.git@43d17c18360ac9c3d3491389328e2fe55fe8f9ce
|
||||||
tokenizers==0.19.1
|
tokenizers==0.15.0
|
||||||
bitsandbytes==0.43.1
|
bitsandbytes==0.43.0
|
||||||
accelerate==0.30.1
|
accelerate==0.28.0
|
||||||
deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
|
deepspeed==0.13.1
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
requests
|
requests
|
||||||
datasets==2.19.1
|
datasets==2.15.0
|
||||||
flash-attn==2.5.8
|
flash-attn==2.5.5
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
xformers==0.0.26.post1
|
xformers==0.0.22
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
colorama
|
colorama
|
||||||
@@ -28,7 +28,7 @@ scipy
|
|||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
|
||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
|
|
||||||
@@ -39,6 +39,6 @@ s3fs
|
|||||||
gcsfs
|
gcsfs
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@f18253bf2d747f68acc9cd89da95c85ebf59dbb9
|
trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Export specific ENV variables to /etc/rp_environment
|
|
||||||
echo "Exporting environment variables..."
|
|
||||||
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
|
||||||
conda init
|
|
||||||
# this needs to come after conda init
|
|
||||||
echo 'source /etc/rp_environment' >> ~/.bashrc
|
|
||||||
|
|
||||||
add_keys_to_authorized() {
|
|
||||||
local key_value=$1
|
|
||||||
|
|
||||||
# Create the ~/.ssh directory and set permissions
|
|
||||||
mkdir -p ~/.ssh
|
|
||||||
chmod 700 ~/.ssh
|
|
||||||
|
|
||||||
# Create the authorized_keys file if it doesn't exist
|
|
||||||
touch ~/.ssh/authorized_keys
|
|
||||||
|
|
||||||
# Initialize an empty key variable
|
|
||||||
local key=""
|
|
||||||
|
|
||||||
# Read the key variable word by word
|
|
||||||
for word in $key_value; do
|
|
||||||
# Check if the word looks like the start of a key
|
|
||||||
if [[ $word == ssh-* ]]; then
|
|
||||||
# If there's a key being built, add it to the authorized_keys file
|
|
||||||
if [[ -n $key ]]; then
|
|
||||||
echo $key >> ~/.ssh/authorized_keys
|
|
||||||
fi
|
|
||||||
# Start a new key
|
|
||||||
key=$word
|
|
||||||
else
|
|
||||||
# Append the word to the current key
|
|
||||||
key="$key $word"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
# Add the last key to the authorized_keys file
|
|
||||||
if [[ -n $key ]]; then
|
|
||||||
echo $key >> ~/.ssh/authorized_keys
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Set the correct permissions
|
|
||||||
chmod 600 ~/.ssh/authorized_keys
|
|
||||||
chmod 700 -R ~/.ssh
|
|
||||||
}
|
|
||||||
|
|
||||||
if [[ $PUBLIC_KEY ]]; then
|
|
||||||
# runpod
|
|
||||||
add_keys_to_authorized "$PUBLIC_KEY"
|
|
||||||
# Start the SSH service in the background
|
|
||||||
service ssh start
|
|
||||||
elif [[ $SSH_KEY ]]; then
|
|
||||||
# latitude.sh
|
|
||||||
add_keys_to_authorized "$SSH_KEY"
|
|
||||||
# Start the SSH service in the background
|
|
||||||
service ssh start
|
|
||||||
else
|
|
||||||
echo "No PUBLIC_KEY or SSH_KEY environment variable provided, not starting openSSH daemon"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Check if JUPYTER_PASSWORD is set and not empty
|
|
||||||
if [ -n "$JUPYTER_PASSWORD" ]; then
|
|
||||||
# Set JUPYTER_TOKEN to the value of JUPYTER_PASSWORD
|
|
||||||
export JUPYTER_TOKEN="$JUPYTER_PASSWORD"
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
|
||||||
# Run Jupyter Lab in the background
|
|
||||||
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ ! -d "/workspace/data/axolotl-artifacts" ]; then
|
|
||||||
mkdir -p /workspace/data/axolotl-artifacts
|
|
||||||
fi
|
|
||||||
if [ ! -L "/workspace/axolotl/outputs" ]; then
|
|
||||||
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Execute the passed arguments (CMD)
|
|
||||||
exec "$@"
|
|
||||||
@@ -5,53 +5,20 @@ echo "Exporting environment variables..."
|
|||||||
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
||||||
echo 'source /etc/rp_environment' >> ~/.bashrc
|
echo 'source /etc/rp_environment' >> ~/.bashrc
|
||||||
|
|
||||||
add_keys_to_authorized() {
|
|
||||||
local key_value=$1
|
|
||||||
|
|
||||||
# Create the ~/.ssh directory and set permissions
|
|
||||||
mkdir -p ~/.ssh
|
|
||||||
chmod 700 ~/.ssh
|
|
||||||
|
|
||||||
# Create the authorized_keys file if it doesn't exist
|
|
||||||
touch ~/.ssh/authorized_keys
|
|
||||||
|
|
||||||
# Initialize an empty key variable
|
|
||||||
local key=""
|
|
||||||
|
|
||||||
# Read the key variable word by word
|
|
||||||
for word in $key_value; do
|
|
||||||
# Check if the word looks like the start of a key
|
|
||||||
if [[ $word == ssh-* ]]; then
|
|
||||||
# If there's a key being built, add it to the authorized_keys file
|
|
||||||
if [[ -n $key ]]; then
|
|
||||||
echo $key >> ~/.ssh/authorized_keys
|
|
||||||
fi
|
|
||||||
# Start a new key
|
|
||||||
key=$word
|
|
||||||
else
|
|
||||||
# Append the word to the current key
|
|
||||||
key="$key $word"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
# Add the last key to the authorized_keys file
|
|
||||||
if [[ -n $key ]]; then
|
|
||||||
echo $key >> ~/.ssh/authorized_keys
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Set the correct permissions
|
|
||||||
chmod 600 ~/.ssh/authorized_keys
|
|
||||||
chmod 700 -R ~/.ssh
|
|
||||||
}
|
|
||||||
|
|
||||||
if [[ $PUBLIC_KEY ]]; then
|
if [[ $PUBLIC_KEY ]]; then
|
||||||
# runpod
|
# runpod
|
||||||
add_keys_to_authorized "$PUBLIC_KEY"
|
mkdir -p ~/.ssh
|
||||||
|
chmod 700 ~/.ssh
|
||||||
|
echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
|
||||||
|
chmod 700 -R ~/.ssh
|
||||||
# Start the SSH service in the background
|
# Start the SSH service in the background
|
||||||
service ssh start
|
service ssh start
|
||||||
elif [[ $SSH_KEY ]]; then
|
elif [ -n "$SSH_KEY" ]; then
|
||||||
# latitude.sh
|
# latitude.sh
|
||||||
add_keys_to_authorized "$SSH_KEY"
|
mkdir -p ~/.ssh
|
||||||
|
chmod 700 ~/.ssh
|
||||||
|
echo $SSH_KEY >> ~/.ssh/authorized_keys
|
||||||
|
chmod 700 -R ~/.ssh
|
||||||
# Start the SSH service in the background
|
# Start the SSH service in the background
|
||||||
service ssh start
|
service ssh start
|
||||||
else
|
else
|
||||||
@@ -69,12 +36,5 @@ if [ "$JUPYTER_DISABLE" != "1" ]; then
|
|||||||
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
|
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ! -d "/workspace/data/axolotl-artifacts" ]; then
|
|
||||||
mkdir -p /workspace/data/axolotl-artifacts
|
|
||||||
fi
|
|
||||||
if [ ! -L "/workspace/axolotl/outputs" ]; then
|
|
||||||
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Execute the passed arguments (CMD)
|
# Execute the passed arguments (CMD)
|
||||||
exec "$@"
|
exec "$@"
|
||||||
|
|||||||
25
setup.py
25
setup.py
@@ -30,11 +30,8 @@ def parse_requirements():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if "Darwin" in platform.system():
|
if "Darwin" in platform.system():
|
||||||
# don't install xformers on MacOS
|
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
|
||||||
else:
|
else:
|
||||||
# detect the version of torch already installed
|
|
||||||
# and set it so dependencies don't clobber the torch version
|
|
||||||
torch_version = version("torch")
|
torch_version = version("torch")
|
||||||
_install_requires.append(f"torch=={torch_version}")
|
_install_requires.append(f"torch=={torch_version}")
|
||||||
|
|
||||||
@@ -48,15 +45,9 @@ def parse_requirements():
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
if (major, minor) >= (2, 3):
|
if (major, minor) >= (2, 1):
|
||||||
pass
|
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||||
elif (major, minor) >= (2, 2):
|
_install_requires.append("xformers>=0.0.23")
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
|
||||||
_install_requires.append("xformers>=0.0.25.post1")
|
|
||||||
else:
|
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
|
||||||
_install_requires.append("xformers>=0.0.23.post1")
|
|
||||||
|
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -68,7 +59,7 @@ install_requires, dependency_links = parse_requirements()
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="axolotl",
|
name="axolotl",
|
||||||
version="0.4.1",
|
version="0.4.0",
|
||||||
description="LLM Trainer",
|
description="LLM Trainer",
|
||||||
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
|
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
@@ -77,13 +68,13 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.5.8",
|
"flash-attn==2.5.5",
|
||||||
],
|
],
|
||||||
"fused-dense-lib": [
|
"fused-dense-lib": [
|
||||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
|
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
|
"deepspeed==0.13.1",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
],
|
],
|
||||||
"mamba-ssm": [
|
"mamba-ssm": [
|
||||||
|
|||||||
@@ -25,8 +25,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
load_in_8bit=False,
|
load_in_8bit=False,
|
||||||
load_in_4bit=False,
|
load_in_4bit=False,
|
||||||
flash_attention=False,
|
flash_attention=False,
|
||||||
deepspeed=None,
|
|
||||||
fsdp=None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -42,7 +40,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
parsed_cfg.flash_attention = False
|
parsed_cfg.flash_attention = False
|
||||||
parsed_cfg.deepspeed = None
|
parsed_cfg.deepspeed = None
|
||||||
parsed_cfg.fsdp = None
|
parsed_cfg.fsdp = None
|
||||||
parsed_cfg.fsdp_config = None
|
|
||||||
|
|
||||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -7,9 +7,7 @@ from typing import Union
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import init_empty_weights
|
|
||||||
from colorama import Fore
|
from colorama import Fore
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
@@ -21,10 +19,7 @@ from axolotl.cli import (
|
|||||||
)
|
)
|
||||||
from axolotl.common.cli import PreprocessCliArgs
|
from axolotl.common.cli import PreprocessCliArgs
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.prompt_strategies.sharegpt import (
|
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
||||||
register_chatml_template,
|
|
||||||
register_llama3_template,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||||
|
|
||||||
@@ -41,22 +36,13 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if parsed_cfg.chat_template == "chatml":
|
if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
|
||||||
if parsed_cfg.default_system_message:
|
LOG.info(
|
||||||
LOG.info(
|
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||||
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
)
|
||||||
)
|
register_chatml_template(parsed_cfg.default_system_message)
|
||||||
register_chatml_template(parsed_cfg.default_system_message)
|
else:
|
||||||
else:
|
register_chatml_template()
|
||||||
register_chatml_template()
|
|
||||||
elif parsed_cfg.chat_template == "llama3":
|
|
||||||
if parsed_cfg.default_system_message:
|
|
||||||
LOG.info(
|
|
||||||
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
|
|
||||||
)
|
|
||||||
register_llama3_template(parsed_cfg.default_system_message)
|
|
||||||
else:
|
|
||||||
register_llama3_template()
|
|
||||||
|
|
||||||
if not parsed_cfg.dataset_prepared_path:
|
if not parsed_cfg.dataset_prepared_path:
|
||||||
msg = (
|
msg = (
|
||||||
@@ -73,11 +59,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
else:
|
else:
|
||||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
if parsed_cli_args.download:
|
|
||||||
model_name = parsed_cfg.base_model
|
|
||||||
with init_empty_weights():
|
|
||||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
Fore.GREEN
|
Fore.GREEN
|
||||||
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
||||||
|
|||||||
@@ -19,10 +19,7 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.prompt_strategies.sharegpt import (
|
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
||||||
register_chatml_template,
|
|
||||||
register_llama3_template,
|
|
||||||
)
|
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.train")
|
LOG = logging.getLogger("axolotl.cli.train")
|
||||||
@@ -50,14 +47,6 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
else:
|
else:
|
||||||
register_chatml_template()
|
register_chatml_template()
|
||||||
|
|
||||||
if cfg.chat_template == "llama3" and cfg.default_system_message:
|
|
||||||
LOG.info(
|
|
||||||
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
|
|
||||||
)
|
|
||||||
register_llama3_template(cfg.default_system_message)
|
|
||||||
else:
|
|
||||||
register_llama3_template()
|
|
||||||
|
|
||||||
if cfg.rl: # and cfg.rl != "orpo":
|
if cfg.rl: # and cfg.rl != "orpo":
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ class PreprocessCliArgs:
|
|||||||
debug_text_only: bool = field(default=False)
|
debug_text_only: bool = field(default=False)
|
||||||
debug_num_examples: int = field(default=1)
|
debug_num_examples: int = field(default=1)
|
||||||
prompter: Optional[str] = field(default=None)
|
prompter: Optional[str] = field(default=None)
|
||||||
download: Optional[bool] = field(default=True)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_tokenizer(
|
def load_model_and_tokenizer(
|
||||||
|
|||||||
212
src/axolotl/core/trainer_builder.py
Executable file → Normal file
212
src/axolotl/core/trainer_builder.py
Executable file → Normal file
@@ -30,7 +30,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
|
from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
from axolotl.loraplus import create_loraplus_optimizer
|
||||||
@@ -43,7 +43,7 @@ from axolotl.utils.callbacks import (
|
|||||||
LossWatchDogCallback,
|
LossWatchDogCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SaveModelCallback,
|
SaveModelOnTrainEndCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
causal_lm_bench_eval_callback_factory,
|
causal_lm_bench_eval_callback_factory,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
@@ -91,12 +91,11 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingMixins:
|
class AxolotlTrainingArguments(TrainingArguments):
|
||||||
"""
|
"""
|
||||||
Mixin class for the Axolotl training args.
|
Extend the base TrainingArguments for axolotl helpers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
default=None, metadata={"help": "HF model configuration model_type."}
|
default=None, metadata={"help": "HF model configuration model_type."}
|
||||||
)
|
)
|
||||||
@@ -126,22 +125,14 @@ class AxolotlTrainingMixins:
|
|||||||
default=1.0,
|
default=1.0,
|
||||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||||
)
|
)
|
||||||
sample_packing_bin_size: int = field(
|
|
||||||
default=200,
|
|
||||||
metadata={
|
|
||||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
sample_packing_group_size: int = field(
|
|
||||||
default=100000,
|
|
||||||
metadata={
|
|
||||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=2048,
|
default=2048,
|
||||||
metadata={"help": "The maximum sequence length the model can handle"},
|
metadata={"help": "The maximum sequence length the model can handle"},
|
||||||
)
|
)
|
||||||
|
sample_packing_seq_len_multiplier: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||||
|
)
|
||||||
relora_steps: Optional[int] = field(
|
relora_steps: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "how often to reset for ReLoRA"},
|
metadata={"help": "how often to reset for ReLoRA"},
|
||||||
@@ -228,37 +219,6 @@ class AxolotlTrainingMixins:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
|
||||||
"""
|
|
||||||
Training arguments for Causal trainer
|
|
||||||
|
|
||||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
|
||||||
so it can't be used as a mixin.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
|
||||||
"""
|
|
||||||
DPO config for DPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
|
||||||
"""
|
|
||||||
ORPO config for ORPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
|
||||||
"""
|
|
||||||
KTO config for KTO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
class AxolotlTrainer(Trainer):
|
||||||
"""
|
"""
|
||||||
Extend the base Trainer for axolotl helpers
|
Extend the base Trainer for axolotl helpers
|
||||||
@@ -386,13 +346,11 @@ class AxolotlTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
RandomSampler(self.train_dataset),
|
RandomSampler(self.train_dataset),
|
||||||
|
batch_size=batch_size,
|
||||||
|
drop_last=True,
|
||||||
|
batch_max_len=batch_max_len,
|
||||||
lengths=get_dataset_lengths(self.train_dataset),
|
lengths=get_dataset_lengths(self.train_dataset),
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
batch_max_len=batch_max_len,
|
|
||||||
batch_size=batch_size,
|
|
||||||
group_size=self.args.sample_packing_group_size,
|
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
)
|
||||||
if self.args.curriculum_sampling:
|
if self.args.curriculum_sampling:
|
||||||
return SequentialSampler(self.train_dataset)
|
return SequentialSampler(self.train_dataset)
|
||||||
@@ -412,13 +370,11 @@ class AxolotlTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
SequentialSampler(eval_dataset),
|
SequentialSampler(eval_dataset),
|
||||||
lengths=get_dataset_lengths(self.eval_dataset),
|
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
|
||||||
batch_max_len=batch_max_len,
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
group_size=self.args.sample_packing_group_size,
|
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
|
batch_max_len=batch_max_len,
|
||||||
|
lengths=get_dataset_lengths(eval_dataset),
|
||||||
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
)
|
)
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
@@ -459,8 +415,6 @@ class AxolotlTrainer(Trainer):
|
|||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.eval_data_collator
|
self.eval_data_collator
|
||||||
)
|
)
|
||||||
if eval_dataset:
|
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
|
||||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.train_data_collator
|
self.train_data_collator
|
||||||
@@ -844,40 +798,6 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.optimizer = None
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
|
||||||
if self.args.loraplus_lr_ratio is None:
|
|
||||||
return super().create_optimizer()
|
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
|
||||||
self.args,
|
|
||||||
opt_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
|
||||||
if loraplus_lr_ratio:
|
|
||||||
print("Using lora+")
|
|
||||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
optimizer_kwargs,
|
|
||||||
loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.optimizer
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
@wraps(DPOTrainer.push_to_hub)
|
@wraps(DPOTrainer.push_to_hub)
|
||||||
def push_to_hub(self, *args, **kwargs) -> str:
|
def push_to_hub(self, *args, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -906,14 +826,6 @@ class AxolotlORPOTrainer(ORPOTrainer):
|
|||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(KTOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
Base class for trainer builder
|
Base class for trainer builder
|
||||||
@@ -1033,7 +945,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.loss_watchdog_threshold is not None:
|
if self.cfg.loss_watchdog_threshold is not None:
|
||||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||||
|
|
||||||
callbacks.append(SaveModelCallback())
|
callbacks.append(SaveModelOnTrainEndCallback())
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -1159,6 +1071,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.save_safetensors is not None:
|
if self.cfg.save_safetensors is not None:
|
||||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||||
|
|
||||||
|
if self.cfg.sample_packing_eff_est:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"sample_packing_efficiency"
|
||||||
|
] = self.cfg.sample_packing_eff_est
|
||||||
|
|
||||||
if self.cfg.dataloader_pin_memory is not None:
|
if self.cfg.dataloader_pin_memory is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"dataloader_pin_memory"
|
"dataloader_pin_memory"
|
||||||
@@ -1206,8 +1123,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_arguments_kwargs["save_strategy"] = "epoch"
|
training_arguments_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
|
|
||||||
|
|
||||||
if self.cfg.do_bench_eval:
|
if self.cfg.do_bench_eval:
|
||||||
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||||
if self.cfg.bench_dataset:
|
if self.cfg.bench_dataset:
|
||||||
@@ -1287,14 +1202,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||||
report_to = []
|
report_to = None
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
report_to.append("wandb")
|
report_to = "wandb"
|
||||||
if self.cfg.use_mlflow:
|
if self.cfg.use_mlflow:
|
||||||
report_to.append("mlflow")
|
report_to = "mlflow"
|
||||||
if self.cfg.use_tensorboard:
|
|
||||||
report_to.append("tensorboard")
|
|
||||||
|
|
||||||
training_arguments_kwargs["report_to"] = report_to
|
training_arguments_kwargs["report_to"] = report_to
|
||||||
training_arguments_kwargs["run_name"] = (
|
training_arguments_kwargs["run_name"] = (
|
||||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||||
@@ -1334,27 +1246,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["weight_decay"] = (
|
training_arguments_kwargs["weight_decay"] = (
|
||||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||||
)
|
)
|
||||||
|
training_arguments_kwargs["sample_packing"] = (
|
||||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||||
training_arguments_kwargs[
|
|
||||||
"multipack_real_batches"
|
|
||||||
] = not self.cfg.flash_attention
|
|
||||||
training_arguments_kwargs["eval_sample_packing"] = bool(
|
|
||||||
self.cfg.eval_sample_packing
|
|
||||||
)
|
)
|
||||||
if self.cfg.sample_packing_bin_size is not None:
|
training_arguments_kwargs["multipack_real_batches"] = (
|
||||||
training_arguments_kwargs[
|
self.cfg.flash_attention is not True
|
||||||
"sample_packing_bin_size"
|
)
|
||||||
] = self.cfg.sample_packing_bin_size
|
training_arguments_kwargs["eval_sample_packing"] = (
|
||||||
if self.cfg.sample_packing_group_size is not None:
|
self.cfg.sample_packing
|
||||||
training_arguments_kwargs[
|
if self.cfg.eval_sample_packing is not False
|
||||||
"sample_packing_group_size"
|
else False
|
||||||
] = self.cfg.sample_packing_group_size
|
)
|
||||||
if self.cfg.sample_packing_eff_est:
|
training_arguments_kwargs[
|
||||||
training_arguments_kwargs[
|
"sample_packing_seq_len_multiplier"
|
||||||
"sample_packing_efficiency"
|
] = self.cfg.micro_batch_size
|
||||||
] = self.cfg.sample_packing_eff_est
|
|
||||||
|
|
||||||
if self.cfg.relora_steps:
|
if self.cfg.relora_steps:
|
||||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
@@ -1524,7 +1429,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
callbacks = super().get_callbacks()
|
callbacks = super().get_callbacks()
|
||||||
callbacks.append(SaveModelCallback())
|
callbacks.append(SaveModelOnTrainEndCallback())
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -1565,8 +1470,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||||
training_args_kwargs["bf16"] = True
|
training_args_kwargs["bf16"] = True
|
||||||
|
|
||||||
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
|
||||||
training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
|
|
||||||
training_args_kwargs["lr_scheduler_type"] = (
|
training_args_kwargs["lr_scheduler_type"] = (
|
||||||
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
||||||
)
|
)
|
||||||
@@ -1619,38 +1522,20 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
|
|
||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = TrainingArguments
|
||||||
if self.cfg.rpo_alpha is not None:
|
|
||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
|
||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
training_args_cls = AxolotlORPOConfig
|
training_args_cls = ORPOConfig
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
|
||||||
if self.cfg.max_prompt_len:
|
training_args_cls = DPOConfig
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
|
||||||
|
|
||||||
if self.cfg.rl == "kto":
|
|
||||||
training_args_cls = AxolotlKTOConfig
|
|
||||||
|
|
||||||
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
|
||||||
training_args_kwargs["desirable_weight"] = (
|
|
||||||
self.cfg.kto_desirable_weight or 1.0
|
|
||||||
)
|
|
||||||
training_args_kwargs["undesirable_weight"] = (
|
|
||||||
self.cfg.kto_undesirable_weight or 1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
|
||||||
if self.cfg.max_prompt_len:
|
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
|
||||||
|
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls(
|
||||||
output_dir=self.cfg.output_dir,
|
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
max_steps=self.cfg.max_steps or total_num_steps,
|
max_steps=self.cfg.max_steps or total_num_steps,
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||||
learning_rate=self.cfg.learning_rate,
|
learning_rate=self.cfg.learning_rate,
|
||||||
|
output_dir=self.cfg.output_dir,
|
||||||
warmup_steps=self.cfg.warmup_steps,
|
warmup_steps=self.cfg.warmup_steps,
|
||||||
logging_first_step=True,
|
logging_first_step=True,
|
||||||
logging_steps=1,
|
logging_steps=1,
|
||||||
@@ -1668,8 +1553,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
dpo_trainer_kwargs["loss_type"] = "ipo"
|
||||||
if self.cfg.dpo_label_smoothing:
|
if self.cfg.dpo_label_smoothing:
|
||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
elif self.cfg.rl == "kto_pair":
|
elif self.cfg.rl in ["kto_pair", "sppo_hard", "nca_pair"]:
|
||||||
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
dpo_trainer_kwargs["loss_type"] = self.cfg.rl
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
@@ -1678,9 +1563,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|
||||||
# these aren't used for the ORPO trainer
|
# these aren't used for the ORPO trainer
|
||||||
@@ -1693,9 +1578,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
elif self.cfg.rl == "kto":
|
|
||||||
trainer_cls = AxolotlKTOTrainer
|
|
||||||
trainer_cls_args = [self.model]
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
dpo_trainer = trainer_cls(
|
dpo_trainer = trainer_cls(
|
||||||
|
|||||||
@@ -123,17 +123,6 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
if self.sep_style == SeparatorStyle.LLAMA3:
|
|
||||||
if self.system_message:
|
|
||||||
# For llama3, the system message is NOT incorporated into the first human instruction
|
|
||||||
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
|
|
||||||
yield "", system_prompt
|
|
||||||
for i, (role, message) in enumerate(self.messages):
|
|
||||||
if message:
|
|
||||||
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
|
|
||||||
else:
|
|
||||||
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
|
|
||||||
return
|
|
||||||
if self.sep_style == SeparatorStyle.GEMMA:
|
if self.sep_style == SeparatorStyle.GEMMA:
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
raise ValueError("Gemma chat template does not support system messages")
|
raise ValueError("Gemma chat template does not support system messages")
|
||||||
|
|||||||
@@ -42,9 +42,9 @@ def patch_mixtral_moe_forward_zero3() -> None:
|
|||||||
return final_hidden_states, router_logits
|
return final_hidden_states, router_logits
|
||||||
|
|
||||||
from transformers.models.mixtral.modeling_mixtral import (
|
from transformers.models.mixtral.modeling_mixtral import (
|
||||||
MixtralBlockSparseTop2MLP,
|
MixtralBLockSparseTop2MLP,
|
||||||
MixtralSparseMoeBlock,
|
MixtralSparseMoeBlock,
|
||||||
)
|
)
|
||||||
|
|
||||||
MixtralBlockSparseTop2MLP.forward = mlp_forward
|
MixtralBLockSparseTop2MLP.forward = mlp_forward
|
||||||
MixtralSparseMoeBlock.forward = moe_forward
|
MixtralSparseMoeBlock.forward = moe_forward
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
|||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
from axolotl.monkeypatch.utils import get_unpad_data
|
||||||
|
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||||
"llama",
|
|
||||||
"mixtral",
|
"mixtral",
|
||||||
"qwen2",
|
"qwen2",
|
||||||
"qwen2_moe",
|
"qwen2_moe",
|
||||||
@@ -19,7 +18,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"gemma",
|
"gemma",
|
||||||
"gemmoe",
|
"gemmoe",
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
"deepseek_v2",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -30,10 +28,6 @@ def patch_for_multipack(model_type, model_name=None):
|
|||||||
)
|
)
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
patch_mixtral_moe_forward_zero3()
|
patch_mixtral_moe_forward_zero3()
|
||||||
elif model_type == "llama":
|
|
||||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "qwen2":
|
elif model_type == "qwen2":
|
||||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
@@ -62,8 +56,6 @@ def patch_for_multipack(model_type, model_name=None):
|
|||||||
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
||||||
elif model_type == "jamba":
|
elif model_type == "jamba":
|
||||||
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
|
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
|
||||||
elif model_type == "deepseek_v2":
|
|
||||||
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
|
|
||||||
|
|
||||||
|
|
||||||
def patch_remote(model_name, config_name, modeling_name):
|
def patch_remote(model_name, config_name, modeling_name):
|
||||||
|
|||||||
@@ -1,267 +0,0 @@
|
|||||||
"""module for patching with unsloth optimizations"""
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import types
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from peft import PeftModelForCausalLM
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
|
||||||
LlamaFlashAttention2,
|
|
||||||
LlamaForCausalLM,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
|
|
||||||
|
|
||||||
ORIGINAL_CEL_CODE = """ if labels is not None:
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
||||||
shift_labels = shift_labels.view(-1)
|
|
||||||
# Enable model parallelism
|
|
||||||
shift_labels = shift_labels.to(shift_logits.device)
|
|
||||||
loss = loss_fct(shift_logits, shift_labels)
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCHED_CEL_CODE = """ if labels is not None:
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
loss = fast_cross_entropy_loss(
|
|
||||||
logits = shift_logits,
|
|
||||||
labels = shift_labels,
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
ORIGINAL_QKV_CODE = """
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
""".lstrip(
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
PATCHED_QKV_CODE = """
|
|
||||||
query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
|
|
||||||
""".lstrip(
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
ORIGINAL_O_CODE = """
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
""".lstrip(
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
PATCHED_O_CODE = """
|
|
||||||
attn_output = self.apply_o(self, attn_output)
|
|
||||||
""".lstrip(
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def original_apply_qkv(self, hidden_states):
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
return query_states, key_states, value_states
|
|
||||||
|
|
||||||
|
|
||||||
def original_apply_o(self, hidden_states):
|
|
||||||
attn_output = self.o_proj(hidden_states)
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
|
|
||||||
def get_forward_code() -> str:
|
|
||||||
forward = inspect.getsource(LlamaForCausalLM.forward)
|
|
||||||
return forward
|
|
||||||
|
|
||||||
|
|
||||||
def test_cel_is_patchable() -> bool:
|
|
||||||
forward = get_forward_code()
|
|
||||||
return ORIGINAL_CEL_CODE in forward
|
|
||||||
|
|
||||||
|
|
||||||
def get_self_attn_code() -> str:
|
|
||||||
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
|
||||||
return forward
|
|
||||||
|
|
||||||
|
|
||||||
def test_self_attn_is_patchable() -> bool:
|
|
||||||
qkv = get_self_attn_code()
|
|
||||||
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
|
|
||||||
|
|
||||||
|
|
||||||
def integrate_cross_entropy_loss_patch():
|
|
||||||
forward = get_forward_code()
|
|
||||||
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
|
||||||
forward, _ = detab_code(forward)
|
|
||||||
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
|
||||||
|
|
||||||
forward = forward.replace(
|
|
||||||
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
|
||||||
)
|
|
||||||
forward = forward.replace(
|
|
||||||
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
|
||||||
forward = forward.replace(
|
|
||||||
"def forward(",
|
|
||||||
"def fast_cross_entropy_loss_forward(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load imports necessary
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(transformers.models.llama.modeling_llama):
|
|
||||||
if item in forward:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
"from transformers.models.llama.modeling_llama import ("
|
|
||||||
+ ", ".join(x for x in items_to_import)
|
|
||||||
+ ")",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
|
||||||
print("patching unsloth fast_cross_entropy_loss")
|
|
||||||
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
|
||||||
|
|
||||||
|
|
||||||
def detab_code(code: str) -> Tuple[str, str]:
|
|
||||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
|
||||||
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
|
||||||
return code, spaces
|
|
||||||
|
|
||||||
|
|
||||||
def patch_self_attn_lora():
|
|
||||||
self_attn_forward = get_self_attn_code()
|
|
||||||
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
|
|
||||||
self_attn_forward
|
|
||||||
)
|
|
||||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
|
||||||
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
|
|
||||||
assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"
|
|
||||||
|
|
||||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
|
||||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
|
||||||
self_attn_forward = self_attn_forward.replace(
|
|
||||||
"def forward(",
|
|
||||||
"def unsloth_attn_forward(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load imports necessary
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(transformers.models.llama.modeling_llama):
|
|
||||||
if item in self_attn_forward:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
"from transformers.models.llama.modeling_llama import ("
|
|
||||||
+ ", ".join(x for x in items_to_import)
|
|
||||||
+ ")",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
|
||||||
print("patching unsloth attn lora")
|
|
||||||
LlamaFlashAttention2.forward = (
|
|
||||||
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
|
||||||
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
|
|
||||||
from unsloth.kernels import apply_lora_mlp_swiglu
|
|
||||||
|
|
||||||
apply_lora_mlp = apply_lora_mlp_swiglu
|
|
||||||
elif peft_model.base_model.config.model_type == "gemma":
|
|
||||||
from unsloth.kernels import apply_lora_mlp_geglu_approx
|
|
||||||
|
|
||||||
apply_lora_mlp = apply_lora_mlp_geglu_approx
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Model type {peft_model.base_model.config.model_type} not supported"
|
|
||||||
)
|
|
||||||
|
|
||||||
for idx, layer in enumerate(peft_model.model.model.layers):
|
|
||||||
layer_modules = [
|
|
||||||
getattr(layer.mlp, linear_proj)
|
|
||||||
for linear_proj in ["gate_proj", "up_proj", "down_proj"]
|
|
||||||
]
|
|
||||||
is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
|
||||||
mlp_no_bias = all(
|
|
||||||
getattr(module, "base_layer", module).bias is None
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
mlp_not_dora = all(
|
|
||||||
getattr(module, "lora_magnitude_vector", None) is None
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
|
|
||||||
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
|
|
||||||
else:
|
|
||||||
logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
|
|
||||||
|
|
||||||
|
|
||||||
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
|
||||||
from unsloth.kernels import apply_lora_o, apply_lora_qkv
|
|
||||||
|
|
||||||
for idx, layer in enumerate(peft_model.model.model.layers):
|
|
||||||
if cfg.unsloth_lora_qkv:
|
|
||||||
layer_modules = [
|
|
||||||
getattr(layer.self_attn, linear_proj)
|
|
||||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
|
||||||
]
|
|
||||||
is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
|
||||||
qkv_no_bias = all(
|
|
||||||
getattr(module, "base_layer", module).bias is None
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
qkv_not_dora = all(
|
|
||||||
getattr(module, "lora_magnitude_vector", None) is None
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_qkv_lora and qkv_no_bias and qkv_not_dora:
|
|
||||||
layer.self_attn.apply_qkv = apply_lora_qkv
|
|
||||||
else:
|
|
||||||
layer.self_attn.apply_qkv = original_apply_qkv
|
|
||||||
logging.warning(
|
|
||||||
"unable to apply unsloth lora qkv patch to layer %d", idx
|
|
||||||
)
|
|
||||||
if cfg.unsloth_lora_o:
|
|
||||||
layer_modules = [
|
|
||||||
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
|
||||||
]
|
|
||||||
is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
|
||||||
o_no_bias = all(
|
|
||||||
getattr(module, "base_layer", module).bias is None
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
o_not_dora = all(
|
|
||||||
getattr(module, "lora_magnitude_vector", None) is None
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_o_lora and o_no_bias and o_not_dora:
|
|
||||||
layer.self_attn.apply_o = apply_lora_o
|
|
||||||
else:
|
|
||||||
layer.self_attn.apply_o = original_apply_o
|
|
||||||
logging.warning(
|
|
||||||
"unable to apply unsloth lora o_proj patch to layer %d", idx
|
|
||||||
)
|
|
||||||
@@ -2,12 +2,9 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
|
|
||||||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.prompt_strategies")
|
|
||||||
|
|
||||||
|
|
||||||
def load(strategy, tokenizer, cfg, ds_cfg):
|
def load(strategy, tokenizer, cfg, ds_cfg):
|
||||||
try:
|
try:
|
||||||
@@ -25,8 +22,5 @@ def load(strategy, tokenizer, cfg, ds_cfg):
|
|||||||
if "ds_cfg" in sig.parameters:
|
if "ds_cfg" in sig.parameters:
|
||||||
load_kwargs["ds_cfg"] = ds_cfg
|
load_kwargs["ds_cfg"] = ds_cfg
|
||||||
return func(tokenizer, cfg, **load_kwargs)
|
return func(tokenizer, cfg, **load_kwargs)
|
||||||
except ModuleNotFoundError:
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
return None
|
|
||||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
|
||||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,56 +1,24 @@
|
|||||||
"""
|
"""
|
||||||
HF Chat Templates prompt strategy
|
HF Chat Templates prompt strategy
|
||||||
"""
|
"""
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
import logging
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import Prompter
|
from axolotl.prompters import Prompter
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplatePrompter(Prompter):
|
class ChatTemplatePrompter(Prompter):
|
||||||
"""prompter for HF chat templates"""
|
"""prompter for HF chat templates"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, tokenizer, chat_template=None, max_length=2048):
|
||||||
self,
|
|
||||||
tokenizer,
|
|
||||||
chat_template=None,
|
|
||||||
max_length=2048,
|
|
||||||
message_field_role: str = "from",
|
|
||||||
message_field_content: str = "value",
|
|
||||||
roles: Optional[Dict[str, List[str]]] = None,
|
|
||||||
):
|
|
||||||
if roles:
|
|
||||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
|
||||||
else:
|
|
||||||
self.roles = {
|
|
||||||
"human": "user",
|
|
||||||
"user": "user",
|
|
||||||
"assistant": "assistant",
|
|
||||||
"gpt": "assistant",
|
|
||||||
"system": "system",
|
|
||||||
}
|
|
||||||
self.message_field_role = message_field_role
|
|
||||||
self.message_field_content = message_field_content
|
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
|
||||||
def build_prompt(self, conversation, add_generation_prompt=False):
|
def build_prompt(self, conversation, add_generation_prompt=False):
|
||||||
turns = [
|
|
||||||
{
|
|
||||||
"role": self.roles[t[self.message_field_role]],
|
|
||||||
"content": t[self.message_field_content],
|
|
||||||
}
|
|
||||||
for t in conversation
|
|
||||||
]
|
|
||||||
|
|
||||||
return self.tokenizer.apply_chat_template(
|
return self.tokenizer.apply_chat_template(
|
||||||
turns,
|
conversation,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
add_generation_prompt=add_generation_prompt,
|
add_generation_prompt=add_generation_prompt,
|
||||||
@@ -63,19 +31,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
Tokenizing strategy for instruction-based prompts.
|
Tokenizing strategy for instruction-based prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_messages = "conversations"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def messages(self):
|
|
||||||
return self._messages
|
|
||||||
|
|
||||||
@messages.setter
|
|
||||||
def messages(self, messages):
|
|
||||||
self._messages = messages
|
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
turns = self.get_conversation_thread(prompt)
|
turns = self.get_conversation_thread(prompt)
|
||||||
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
|
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
|
||||||
input_ids = self.prompter.build_prompt(turns)
|
input_ids = self.prompter.build_prompt(turns)
|
||||||
|
|
||||||
if not self.train_on_inputs:
|
if not self.train_on_inputs:
|
||||||
@@ -93,37 +51,28 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
return tokenized_prompt
|
return tokenized_prompt
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self, prompt):
|
||||||
return prompt[self.messages]
|
conversations = prompt["conversations"]
|
||||||
|
# remap roles - allow for assistant turn
|
||||||
|
role_map = {
|
||||||
|
"human": "user",
|
||||||
|
"user": "user",
|
||||||
|
"assistant": "assistant",
|
||||||
|
"gpt": "assistant",
|
||||||
|
}
|
||||||
|
turns = [
|
||||||
|
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
|
||||||
|
]
|
||||||
|
return turns
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
chat_template = (
|
chat_template = (
|
||||||
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
|
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
|
||||||
)
|
)
|
||||||
message_field_role = (
|
|
||||||
ds_cfg["message_field_role"]
|
|
||||||
if ds_cfg and "message_field_role" in ds_cfg
|
|
||||||
else "from"
|
|
||||||
)
|
|
||||||
message_field_content = (
|
|
||||||
ds_cfg["message_field_content"]
|
|
||||||
if ds_cfg and "message_field_content" in ds_cfg
|
|
||||||
else "value"
|
|
||||||
)
|
|
||||||
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
|
||||||
|
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
|
||||||
tokenizer,
|
|
||||||
chat_templates(chat_template),
|
|
||||||
message_field_role=message_field_role,
|
|
||||||
message_field_content=message_field_content,
|
|
||||||
roles=roles,
|
|
||||||
),
|
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
|
||||||
strategy.messages = ds_cfg["field_messages"]
|
|
||||||
return strategy
|
return strategy
|
||||||
|
|||||||
@@ -1,133 +0,0 @@
|
|||||||
"""
|
|
||||||
DPO strategies for llama-3 chat template
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def argilla(
|
|
||||||
cfg,
|
|
||||||
**kwargs,
|
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
def transform_fn(sample):
|
|
||||||
if "system" in sample and sample["system"]:
|
|
||||||
sample["prompt"] = (
|
|
||||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
|
||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sample[
|
|
||||||
"prompt"
|
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
|
||||||
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
|
|
||||||
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
|
|
||||||
return sample
|
|
||||||
|
|
||||||
return transform_fn
|
|
||||||
|
|
||||||
|
|
||||||
def argilla_chat(
|
|
||||||
cfg,
|
|
||||||
**kwargs,
|
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
"""
|
|
||||||
for argilla/dpo-mix-7k conversations
|
|
||||||
"""
|
|
||||||
|
|
||||||
def transform_fn(sample):
|
|
||||||
sample[
|
|
||||||
"prompt"
|
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
|
||||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
|
||||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
|
||||||
return sample
|
|
||||||
|
|
||||||
return transform_fn
|
|
||||||
|
|
||||||
|
|
||||||
def icr(
|
|
||||||
cfg,
|
|
||||||
**kwargs,
|
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
"""
|
|
||||||
chatml transforms for datasets with system, input, chosen, rejected
|
|
||||||
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
|
|
||||||
"""
|
|
||||||
|
|
||||||
def transform_fn(sample):
|
|
||||||
if "system" in sample and sample["system"]:
|
|
||||||
sample["prompt"] = (
|
|
||||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
|
||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sample[
|
|
||||||
"prompt"
|
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
|
||||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
|
||||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
|
||||||
return sample
|
|
||||||
|
|
||||||
return transform_fn
|
|
||||||
|
|
||||||
|
|
||||||
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
"""
|
|
||||||
For Intel Orca DPO Pairs
|
|
||||||
"""
|
|
||||||
|
|
||||||
def transform_fn(sample):
|
|
||||||
if "system" in sample and sample["system"]:
|
|
||||||
sample["prompt"] = (
|
|
||||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
|
||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sample[
|
|
||||||
"prompt"
|
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
|
||||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
|
||||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
|
||||||
return sample
|
|
||||||
|
|
||||||
return transform_fn
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_pairs(
|
|
||||||
cfg, **kwargs
|
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
def transform_fn(sample):
|
|
||||||
if "system" in sample and sample["system"]:
|
|
||||||
sample["prompt"] = (
|
|
||||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
|
||||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sample[
|
|
||||||
"prompt"
|
|
||||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
|
||||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
|
||||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
|
||||||
return sample
|
|
||||||
|
|
||||||
return transform_fn
|
|
||||||
|
|
||||||
|
|
||||||
def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for ultrafeedback binarized conversations

    Returns a transform that rewrites each sample in place: "prompt" (plus
    optional "system") becomes a llama-3 chat prompt, while "chosen" and
    "rejected" are reduced from two-turn conversations to the assistant
    reply (index 1) with the end-of-turn token appended.
    """

    def transform_fn(sample):
        # Capture the raw user message before "prompt" is overwritten below.
        question = sample["prompt"]
        formatted = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            f"{question}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        if "system" in sample and sample["system"]:
            # Prepend the system turn only when the dataset row provides one.
            formatted = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
                + formatted
            )
        sample["prompt"] = formatted
        # Second message of each conversation holds the assistant response.
        sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
        sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
        return sample

    return transform_fn
|
|
||||||
30
src/axolotl/prompt_strategies/dpo/mistral.py
Normal file
30
src/axolotl/prompt_strategies/dpo/mistral.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""
|
||||||
|
DPO strategies for mistral instruct
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_pairs(
    cfg, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    For generic prompt/chosen/rejected pair datasets, mistral instruct format.

    Accepts ``**kwargs`` for consistency with the other DPO strategies
    (``argilla_chat`` and the llama variants all take ``cfg, **kwargs``), so
    a loader that forwards extra keyword arguments no longer raises
    ``TypeError``. Returns a transform that rewrites each sample in place.
    """

    def transform_fn(sample):
        # Wrap the user message in mistral instruct delimiters.
        sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
        # NOTE(review): unlike argilla_chat, no "</s>" is appended here —
        # presumably the tokenizer adds EOS; confirm this is intentional.
        sample["chosen"] = f"{sample['chosen']}"
        sample["rejected"] = f"{sample['rejected']}"
        return sample

    return transform_fn
|
||||||
|
|
||||||
|
|
||||||
|
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/dpo-mix-7k conversations

    Returns a transform that rewrites each sample in place: the shared user
    message (first turn of "chosen") becomes a mistral instruct "prompt",
    and "chosen"/"rejected" are reduced to their assistant replies (index 1)
    terminated with "</s>".
    """

    def transform_fn(sample):
        chosen_turns = sample["chosen"]
        rejected_turns = sample["rejected"]
        # Both conversations share the same opening user message.
        sample["prompt"] = f"[INST] {chosen_turns[0]['content']} [/INST]"
        sample["chosen"] = f"{chosen_turns[1]['content']}</s>"
        sample["rejected"] = f"{rejected_turns[1]['content']}</s>"
        return sample

    return transform_fn
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
"""
|
|
||||||
module for KTO style dataset transform strategies
|
|
||||||
"""
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from ..base import load as load_base
|
|
||||||
|
|
||||||
load = partial(load_base, module_base="axolotl.prompt_strategies.kto")
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user