Compare commits

..

9 Commits

Author SHA1 Message Date
Wing Lian
17330c05a3 shampoo checkpoint save workaround 2024-09-23 15:21:00 -04:00
Wing Lian
992ea517b7 setup precision config for bf16 2024-09-18 12:01:38 -07:00
Wing Lian
beaee36191 ddp shampoo 2024-09-18 10:50:46 -07:00
Wing Lian
69a29382e1 fix casting of optim args 2024-09-18 10:48:15 -07:00
Wing Lian
84dad0bd12 ensure epsilon is cast to float 2024-09-18 10:42:26 -07:00
Wing Lian
05f61a0ea5 remove accidental duplidcated line 2024-09-18 10:41:03 -07:00
Wing Lian
5334d0fc01 fixes 2024-09-18 10:38:59 -07:00
Wing Lian
52e6249d2e additional grafting config types and basic example doc 2024-09-18 08:16:11 -07:00
Wing Lian
eb3eab3450 wip shampoo optim support 2024-09-18 08:10:52 -07:00
110 changed files with 1205 additions and 5984 deletions

View File

@@ -28,19 +28,7 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout

View File

@@ -27,12 +27,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -89,12 +84,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -21,17 +21,10 @@ jobs:
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
- cuda: 124
cuda_version: 12.4.1
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"

View File

@@ -26,12 +26,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -88,12 +83,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -27,7 +27,7 @@ jobs:
run: |
pip3 install wheel packaging
pip3 install -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pip3 install -r requirements-tests.txt
- name: Extract tag name
id: tag

View File

@@ -25,7 +25,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
pytorch_version: ["2.3.1", "2.4.0"]
timeout-minutes: 20
steps:
@@ -47,14 +47,13 @@ jobs:
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pip3 install -r requirements-tests.txt
- name: Run tests
run: |
@@ -92,14 +91,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"

View File

@@ -36,7 +36,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
pytorch_version: ["2.3.1", "2.4.0"]
timeout-minutes: 20
steps:
@@ -49,20 +49,16 @@ jobs:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
- name: Install dependencies
run: |
pip3 show torch
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pip3 install -r requirements-tests.txt
- name: Run tests
run: |
@@ -76,7 +72,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
timeout-minutes: 60
needs: [pre-commit, pytest]
strategy:
@@ -98,13 +94,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
num_gpus: 1
axolotl_extras:
steps:

View File

@@ -1,3 +1,3 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_third_party=wandb

295
1991.yml
View File

@@ -1,295 +0,0 @@
base_model: Qwen/Qwen2.5-14B-Instruct
model_type: AutoModelForCausalLM #nohup accelerate launch -m axolotl.cli.train /home/ubuntu/qwen2.5_14B.yml > training_output.log 2>&1 &
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
chat_template: chatml
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
# input_layernorm layers
- model.layers.0.input_layernorm
- model.layers.1.input_layernorm
- model.layers.2.input_layernorm
- model.layers.3.input_layernorm
- model.layers.4.input_layernorm
- model.layers.5.input_layernorm
- model.layers.6.input_layernorm
- model.layers.7.input_layernorm
- model.layers.8.input_layernorm
- model.layers.9.input_layernorm
- model.layers.10.input_layernorm
- model.layers.11.input_layernorm
- model.layers.12.input_layernorm
- model.layers.13.input_layernorm
- model.layers.14.input_layernorm
- model.layers.15.input_layernorm
- model.layers.16.input_layernorm
- model.layers.17.input_layernorm
- model.layers.18.input_layernorm
- model.layers.19.input_layernorm
- model.layers.20.input_layernorm
- model.layers.21.input_layernorm
- model.layers.22.input_layernorm
- model.layers.23.input_layernorm
# lm_head layers
# mlp.down_proj layers
- model.layers.1.mlp.down_proj
- model.layers.35.mlp.down_proj
- model.layers.38.mlp.down_proj
- model.layers.37.mlp.down_proj
- model.layers.36.mlp.down_proj
- model.layers.15.mlp.down_proj
- model.layers.11.mlp.down_proj
- model.layers.12.mlp.down_proj
- model.layers.34.mlp.down_proj
- model.layers.44.mlp.down_proj
- model.layers.45.mlp.down_proj
- model.layers.9.mlp.down_proj
- model.layers.41.mlp.down_proj
- model.layers.33.mlp.down_proj
- model.layers.43.mlp.down_proj
- model.layers.40.mlp.down_proj
- model.layers.13.mlp.down_proj
- model.layers.8.mlp.down_proj
- model.layers.39.mlp.down_proj
- model.layers.10.mlp.down_proj
- model.layers.14.mlp.down_proj
- model.layers.16.mlp.down_proj
- model.layers.31.mlp.down_proj
- model.layers.32.mlp.down_proj
# mlp.gate_proj layers
- model.layers.1.mlp.gate_proj
- model.layers.44.mlp.gate_proj
- model.layers.46.mlp.gate_proj
- model.layers.45.mlp.gate_proj
- model.layers.43.mlp.gate_proj
- model.layers.47.mlp.gate_proj
- model.layers.42.mlp.gate_proj
- model.layers.32.mlp.gate_proj
- model.layers.27.mlp.gate_proj
- model.layers.33.mlp.gate_proj
- model.layers.28.mlp.gate_proj
- model.layers.39.mlp.gate_proj
- model.layers.41.mlp.gate_proj
- model.layers.40.mlp.gate_proj
- model.layers.30.mlp.gate_proj
- model.layers.29.mlp.gate_proj
- model.layers.31.mlp.gate_proj
- model.layers.26.mlp.gate_proj
- model.layers.37.mlp.gate_proj
- model.layers.10.mlp.gate_proj
- model.layers.38.mlp.gate_proj
- model.layers.12.mlp.gate_proj
- model.layers.36.mlp.gate_proj
- model.layers.13.mlp.gate_proj
# mlp.up_proj layers
- model.layers.1.mlp.up_proj
- model.layers.13.mlp.up_proj
- model.layers.11.mlp.up_proj
- model.layers.14.mlp.up_proj
- model.layers.15.mlp.up_proj
- model.layers.12.mlp.up_proj
- model.layers.8.mlp.up_proj
- model.layers.16.mlp.up_proj
- model.layers.9.mlp.up_proj
- model.layers.19.mlp.up_proj
- model.layers.10.mlp.up_proj
- model.layers.7.mlp.up_proj
- model.layers.17.mlp.up_proj
- model.layers.20.mlp.up_proj
- model.layers.21.mlp.up_proj
- model.layers.18.mlp.up_proj
- model.layers.38.mlp.up_proj
- model.layers.37.mlp.up_proj
- model.layers.39.mlp.up_proj
- model.layers.42.mlp.up_proj
- model.layers.41.mlp.up_proj
- model.layers.27.mlp.up_proj
- model.layers.28.mlp.up_proj
- model.layers.34.mlp.up_proj
# model.norm layers
# post_attention_layernorm layers
- model.layers.0.post_attention_layernorm
- model.layers.1.post_attention_layernorm
- model.layers.2.post_attention_layernorm
- model.layers.3.post_attention_layernorm
- model.layers.4.post_attention_layernorm
- model.layers.5.post_attention_layernorm
- model.layers.6.post_attention_layernorm
- model.layers.7.post_attention_layernorm
- model.layers.8.post_attention_layernorm
- model.layers.9.post_attention_layernorm
- model.layers.10.post_attention_layernorm
- model.layers.11.post_attention_layernorm
- model.layers.12.post_attention_layernorm
- model.layers.13.post_attention_layernorm
- model.layers.14.post_attention_layernorm
- model.layers.15.post_attention_layernorm
- model.layers.16.post_attention_layernorm
- model.layers.17.post_attention_layernorm
- model.layers.18.post_attention_layernorm
- model.layers.19.post_attention_layernorm
- model.layers.20.post_attention_layernorm
- model.layers.21.post_attention_layernorm
- model.layers.22.post_attention_layernorm
- model.layers.23.post_attention_layernorm
# self_attn.k_proj layers
- model.layers.47.self_attn.k_proj
- model.layers.39.self_attn.k_proj
- model.layers.41.self_attn.k_proj
- model.layers.37.self_attn.k_proj
- model.layers.35.self_attn.k_proj
- model.layers.44.self_attn.k_proj
- model.layers.38.self_attn.k_proj
- model.layers.14.self_attn.k_proj
- model.layers.7.self_attn.k_proj
- model.layers.12.self_attn.k_proj
- model.layers.11.self_attn.k_proj
- model.layers.32.self_attn.k_proj
- model.layers.10.self_attn.k_proj
- model.layers.8.self_attn.k_proj
- model.layers.9.self_attn.k_proj
- model.layers.6.self_attn.k_proj
- model.layers.45.self_attn.k_proj
- model.layers.42.self_attn.k_proj
- model.layers.5.self_attn.k_proj
- model.layers.40.self_attn.k_proj
- model.layers.33.self_attn.k_proj
- model.layers.0.self_attn.k_proj
- model.layers.34.self_attn.k_proj
- model.layers.13.self_attn.k_proj
# self_attn.o_proj layers
- model.layers.12.self_attn.o_proj
- model.layers.5.self_attn.o_proj
- model.layers.14.self_attn.o_proj
- model.layers.16.self_attn.o_proj
- model.layers.20.self_attn.o_proj
- model.layers.13.self_attn.o_proj
- model.layers.11.self_attn.o_proj
- model.layers.4.self_attn.o_proj
- model.layers.6.self_attn.o_proj
- model.layers.19.self_attn.o_proj
- model.layers.7.self_attn.o_proj
- model.layers.18.self_attn.o_proj
- model.layers.8.self_attn.o_proj
- model.layers.38.self_attn.o_proj
- model.layers.15.self_attn.o_proj
- model.layers.17.self_attn.o_proj
- model.layers.9.self_attn.o_proj
- model.layers.10.self_attn.o_proj
- model.layers.21.self_attn.o_proj
- model.layers.28.self_attn.o_proj
- model.layers.32.self_attn.o_proj
- model.layers.35.self_attn.o_proj
- model.layers.39.self_attn.o_proj
- model.layers.3.self_attn.o_proj
# self_attn.q_proj layers
- model.layers.1.self_attn.q_proj
- model.layers.2.self_attn.q_proj
- model.layers.3.self_attn.q_proj
- model.layers.44.self_attn.q_proj
- model.layers.29.self_attn.q_proj
- model.layers.45.self_attn.q_proj
- model.layers.43.self_attn.q_proj
- model.layers.32.self_attn.q_proj
- model.layers.38.self_attn.q_proj
- model.layers.19.self_attn.q_proj
- model.layers.42.self_attn.q_proj
- model.layers.34.self_attn.q_proj
- model.layers.36.self_attn.q_proj
- model.layers.40.self_attn.q_proj
- model.layers.26.self_attn.q_proj
- model.layers.20.self_attn.q_proj
- model.layers.39.self_attn.q_proj
- model.layers.28.self_attn.q_proj
- model.layers.35.self_attn.q_proj
- model.layers.41.self_attn.q_proj
- model.layers.33.self_attn.q_proj
- model.layers.25.self_attn.q_proj
- model.layers.30.self_attn.q_proj
- model.layers.27.self_attn.q_proj
# self_attn.v_proj layers
- model.layers.0.self_attn.v_proj
- model.layers.7.self_attn.v_proj
- model.layers.39.self_attn.v_proj
- model.layers.31.self_attn.v_proj
- model.layers.15.self_attn.v_proj
- model.layers.10.self_attn.v_proj
- model.layers.32.self_attn.v_proj
- model.layers.41.self_attn.v_proj
- model.layers.6.self_attn.v_proj
- model.layers.33.self_attn.v_proj
- model.layers.42.self_attn.v_proj
- model.layers.29.self_attn.v_proj
- model.layers.14.self_attn.v_proj
- model.layers.9.self_attn.v_proj
- model.layers.35.self_attn.v_proj
- model.layers.38.self_attn.v_proj
- model.layers.13.self_attn.v_proj
- model.layers.30.self_attn.v_proj
- model.layers.5.self_attn.v_proj
- model.layers.34.self_attn.v_proj
- model.layers.28.self_attn.v_proj
- model.layers.37.self_attn.v_proj
- model.layers.27.self_attn.v_proj
- model.layers.11.self_attn.v_proj
# model.embed_tokens layers
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: linear
learning_rate: 5e-6
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
gradient_checkpointing: unsloth
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 2
saves_per_epoch: 1
save_total_limit: 4
debug:
deepspeed: deepspeed_configs/zero3_bf16.json
weight_decay: 0.05
special_tokens:
eos_token: <|im_end|>

View File

@@ -14,7 +14,7 @@ Features:
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet
- Log results and optionally checkpoints to wandb or mlflow
- And more!
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
@@ -121,7 +121,7 @@ Features:
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt
@@ -515,22 +515,6 @@ wandb_name:
wandb_log_model:
```
##### Comet Logging
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
- wandb options
```yaml
use_comet:
comet_api_key:
comet_workspace:
comet_project_name:
comet_experiment_key:
comet_mode:
comet_online:
comet_experiment_config:
```
##### Special Tokens
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:

View File

@@ -23,11 +23,11 @@ RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
@@ -37,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
fi
# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
RUN pip install -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -1,6 +1,6 @@
#!/bin/bash
set -e
pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -64,7 +64,7 @@ def run_cmd(cmd: str, run_folder: str):
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
timeout=45 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
)

View File

@@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str):
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
timeout=45 * 60,
cpu=8.0,
memory=131072,
)

View File

@@ -14,6 +14,15 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -24,6 +24,15 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -20,6 +20,15 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -20,6 +20,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -83,14 +83,13 @@ lora_on_cpu: true
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -124,48 +123,6 @@ datasets:
# For `completion` datsets only, uses the provided field instead of `text` column
field:
# Using chat template
- path: ...
# Set type to `chat_template` to use this strategy
type: chat_template
# Specify the name of the chat template to use
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
chat_template: tokenizer_default
# Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
chat_template_jinja:
# The key in the data example that contains the messages. Default is "messages".
field_messages: messages
# The key in the message turn that contains the role. Default is "role".
message_field_role: role
# The key in the message turn that contains the content. Default is "content".
message_field_content: content
# Optional[Dict[str, List]]. Roles mapping for the messages.
roles:
user: ["human", "user"]
assistant: ["gpt", "assistant", "ai"]
system: ["system"]
## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["gpt", "assistant"]
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
# - all: train on all EOS tokens
# - turn: train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
train_on_eos: last
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
# See example at `docs/dataset-formats/conversation.qmd`
message_field_training_detail: train_detail
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
@@ -184,16 +141,9 @@ test_datasets:
# use RL training: 'dpo', 'ipo', 'kto'
rl:
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
# Currently supports chatml and inst (mistral/mixtral)
chat_template: chatml
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
@@ -315,21 +265,8 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
mlflow_run_name: # Your run name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Comet configuration if you're using it
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
use_comet: # Enable or disable Comet integration.
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
# Where to save the full-finetuned model to
output_dir: ./completed-model
@@ -364,7 +301,7 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

View File

@@ -6,8 +6,6 @@ order: 3
## sharegpt
UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.
conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
```{.json filename="data.jsonl"}
@@ -71,138 +69,3 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f
```{.json filename="data.jsonl"}
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```
## chat_template
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "content": "..."}]}
```
See `config.qmd` for full configs and supported templates.
### Migrating from sharegpt
Most configs can be adapted as follows:
```yaml
# old
chat_template: chatml
datasets:
- path: ...
type: sharegpt
conversation: chatml
# new (if using tokenizer's chat_template)
datasets:
- path: ...
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
datasets:
- path: ...
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
```
We recommend checking the below examples for other usecases.
### Examples
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
- path: ...
type: chat_template
```
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
```{.json filename="data.jsonl"}
{
"conversations": [
{"from": "system", "value": "You are an AI assistant.", "train": false},
{"from": "human", "value": "Hello", "train": false},
{"from": "assistant", "value": "Hello", "train": true},
{"from": "human", "value": "How are you?", "train": true},
{
"from": "assistant",
"value": "I'm doing very well, thank you!",
"train_detail": [
{"begin_offset": 0, "end_offset": 8, "train": false},
{"begin_offset": 9, "end_offset": 18, "train": true},
{"begin_offset": 19, "end_offset": 30, "train": false},
],
},
{
"from": "human",
"value": "I'm doing very well, thank you!",
"train": true,
},
{"from": "assistant", "value": "Hi there!", "train": true}
]
}
```
The configuration would look like:
```yaml
datasets:
- path: ...
type: chat_template
chat_template: tokenizer_default
field_messages: conversations
message_field_role: from
message_field_content: value
roles_to_train: []
train_on_eos: turn
message_field_training: train
message_field_training_detail: train_detail
```
Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.

View File

@@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ignored by comparing the labels
We can check that the right tokens are ingored by comparing the labels
to each token:
```python

View File

@@ -1,28 +0,0 @@
# MultiModal / Vision Language Models (BETA)
### Supported Models
- Mllama, i.e. llama with vision models
### Usage
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
```yaml
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
skip_prepare_dataset: true
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
remove_unused_columns: false
sample_packing: false
# only finetune the Language model, leave the vision model and vision tower frozen
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
```

17
docs/optimizers.qmd Normal file
View File

@@ -0,0 +1,17 @@
# Optimizers
## Shampoo
```yaml
optimizer: shampoo
optim_shampoo_betas: [0.9, 0.999]
optim_args:
epsilon: 1e-12
max_preconditioner_dim: 8192
precondition_frequency: 100
use_decoupled_weight_decay: true
optim_shampoo_grafting_config_type: adam
optim_shampoo_grafting_config_kwargs:
beta2: 0.999
epsilon: 1e-12
```

View File

@@ -1,63 +0,0 @@
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
reward_model: true
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,63 +0,0 @@
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -11,6 +11,7 @@ rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
chat_template: llama3
field_messages: conversation
field_chosen: chosen
field_rejected: rejected

View File

@@ -10,6 +10,7 @@ chat_template: llama3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: llama3
field_messages: messages
message_field_role: role
message_field_content: content

View File

@@ -1,77 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -2,4 +2,3 @@ pre-commit
black
mypy
types-requests
tbparse

View File

@@ -1,12 +1,12 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.46.0
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.0.1
datasets==3.0.1
deepspeed==0.15.3
peft==0.12.0
transformers @ git+https://github.com/huggingface/transformers.git@0963229e287501bed52ae1dabc17922524de6992
tokenizers>=0.19.1
bitsandbytes==0.43.3
accelerate==0.34.2
datasets==2.21.0
deepspeed==0.14.4
pydantic==2.6.3
addict
fire
@@ -16,7 +16,7 @@ flash-attn==2.6.3
sentencepiece
wandb
einops
xformers>=0.0.23.post1
xformers==0.0.27
optimum==1.16.2
hf_transfer
colorama
@@ -34,7 +34,8 @@ tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
triton>=2.3.0
liger-kernel==0.3.0
liger-kernel==0.2.1
distributed_shampoo @ git+https://github.com/facebookresearch/optimizers.git@main
mamba-ssm==1.2.0.post1
@@ -43,14 +44,6 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
trl==0.9.6
zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.4
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.5.0

View File

@@ -1,315 +0,0 @@
accelerate==0.34.1
addict==2.4.0
aiofiles==23.2.1
aiohttp==3.9.0
aiosignal==1.3.1
aiostream==0.5.2
alembic==1.13.1
annotated-types==0.6.0
annoy==1.17.3
ansible==6.7.0
ansible-core==2.13.13
ansible-vault==2.1.0
anyio==3.7.1
appdirs==1.4.4
art==6.0
asgiref==3.7.2
async-timeout==4.0.2
attrdict==2.0.1
attrs==22.2.0
awscli==1.32.75
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
backoff==2.2.1
base58==2.1.1
beartype==0.17.2
bitnet==0.2.1
bitsandbytes==0.42.0
bittensor==6.7.0
black==23.7.0
blinker==1.7.0
boto3==1.34.75
botocore==1.34.75
cachetools==5.3.3
cachy==0.1.1
certifi==2023.7.22
cffi==1.16.0
cfgv==3.3.1
chai-guanaco==1.2.4
charset-normalizer==3.2.0
cleo==0.6.8
click==8.1.7
cloudpickle==2.0.0
cohere==4.11.2
colorama==0.4.4
coloredlogs==15.0.1
CoLT5-attention==0.10.20
contextlib2==21.6.0
contourpy==1.2.0
cryptography==41.0.3
cycler==0.12.1
cytoolz==0.12.3
databricks-cli==0.18.0
dataclasses-json==0.5.7
datasets==2.11.0
ddt==1.6.0
decorator==5.1.1
deepspeed==0.15.0
# Editable Git install with no remote (dialogpt==0.1)
-e /Users/wing/Projects/ml/dialogpt/src
dill==0.3.6
distlib==0.3.6
docker==7.0.0
docker-pycreds==0.4.0
docstring-parser==0.15
docutils==0.16
ecdsa==0.18.0
einops==0.7.0
einops-exts==0.0.4
einx==0.1.3
entrypoints==0.4
eth-hash==0.6.0
eth-keys==0.5.0
eth-typing==4.0.0
eth-utils==2.3.1
evaluate==0.4.0
exceptiongroup==1.1.1
fastapi==0.109.2
fastcore==1.5.29
ffmpy==0.4.0
filelock==3.12.2
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
fire==0.5.0
first==2.0.2
flake8==7.0.0
Flask==3.0.1
fonttools==4.47.2
frozendict==2.4.1
frozenlist==1.3.3
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
fsspec==2023.6.0
fuzzywuzzy==0.18.0
gitdb==4.0.10
GitPython==3.1.31
google-pasta==0.2.0
gradio==4.42.0
gradio_client==1.3.0
greenlet==2.0.2
grpclib==0.4.7
gunicorn==21.2.0
h11==0.14.0
h2==4.1.0
hpack==4.0.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.23.4
humanfriendly==10.0
hyperframe==6.0.1
identify==2.5.24
idna==3.4
immutables==0.20
importlib-metadata==6.7.0
importlib-resources==6.1.1
inflection==0.5.1
iniconfig==2.0.0
itsdangerous==2.1.2
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
jsonlines==3.1.0
jsonschema==2.6.0
kiwisolver==1.4.5
langchain==0.0.144
Levenshtein==0.24.0
libcst==1.1.0
liger-kernel==0.0.0
lion-pytorch==0.1.2
llama-cpp-python==0.1.36
llvmlite==0.40.1
local-attention==1.9.0
loguru==0.7.0
Mako==1.3.2
Markdown==3.5.2
markdown-it-py==3.0.0
markdown2==2.4.10
MarkupSafe==2.1.2
marshmallow==3.19.0
marshmallow-enum==1.5.1
matplotlib==3.8.2
mccabe==0.7.0
mdurl==0.1.2
MEGABYTE-pytorch==0.0.7
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
mlflow==2.10.0
modal==0.62.77
more-itertools==10.2.0
mpmath==1.2.1
msgpack==1.0.7
msgpack-numpy-opentensor==0.5.0
multidict==6.0.4
multiprocess==0.70.14
munch==2.5.0
mypy==1.3.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
netaddr==0.10.1
networkx==3.0rc1
nh3==0.2.14
nodeenv==1.8.0
nomic==2.0.2
numba==0.57.1
numexpr==2.8.4
numpy==1.24.4
oauthlib==3.2.2
openai==0.27.4
openapi==1.1.0
openapi-schema-pydantic==1.2.4
optimum==1.8.6
orjson==3.10.7
packaging==23.1
pandas==2.0.0
parameterized==0.9.0
password-strength==0.0.3.post2
pastel==0.1.1
pathos==0.3.0
pathspec==0.11.1
pathtools==0.1.2
peft==0.11.1
pendulum==3.0.0
Pillow==9.5.0
pip-tools==1.11.0
platformdirs==3.2.0
pluggy==1.4.0
poetry==0.7.1
pox==0.3.2
ppft==1.7.6.6
pre-commit==3.3.2
prettytable==3.10.0
prompt-toolkit==3.0.39
protobuf==3.20.2
protobuf3-to-dict==0.1.5
psutil==5.9.5
psycopg==3.1.18
PuLP==2.8.0
py==1.11.0
py-bip39-bindings==0.1.11
py-cpuinfo==9.0.0
py-ed25519-zebra-bindings==1.0.1
py-sr25519-bindings==0.2.0
pyarrow==11.0.0
pyasn1==0.6.0
pycodestyle==2.11.1
pycparser==2.21
pycryptodome==3.20.0
pydantic==2.5.3
pydantic_core==2.14.6
pydub==0.25.1
pyfiglet==0.8.post1
pyflakes==3.2.0
Pygments==2.15.1
PyJWT==2.8.0
pylev==1.4.0
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==2.4.7
pyrsistent==0.14.11
pytest==8.0.2
pytest-asyncio==0.23.4
python-dateutil==2.8.2
python-dotenv==1.0.1
python-Levenshtein==0.24.0
python-multipart==0.0.9
pytz==2023.3
PyYAML==6.0.1
querystring-parser==1.2.4
rapidfuzz==3.6.1
regex==2023.6.3
requests==2.31.0
requests-toolbelt==0.8.0
resolvelib==0.8.1
responses==0.18.0
retry==0.9.2
rich==13.7.0
rsa==4.7.2
ruff==0.6.3
s3transfer==0.10.1
safetensors==0.4.5
sagemaker==2.148.0
scalecodec==1.2.7
schedulefree==1.2.1
schema==0.7.5
scikit-learn==1.4.0
scipy==1.9.3
seaborn==0.13.2
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==1.19.1
setproctitle==1.3.2
shellingham==1.5.4
shortuuid==1.0.11
shtab==1.6.5
sigtools==4.0.1
six==1.16.0
skypilot==0.4.1
smdebug-rulesconfig==1.0.1
smmap==5.0.0
sniffio==1.3.0
SQLAlchemy==1.4.47
sqlparse==0.4.4
starlette==0.36.3
substrate-interface==1.5.2
svgwrite==1.4.3
sympy==1.11.1
synchronicity==0.6.7
tabulate==0.9.0
tblib==1.7.0
tenacity==8.2.2
tensor-parallel==2.0.0
termcolor==2.2.0
text2art==0.2.0
threadpoolctl==3.2.0
tiktoken==0.6.0
time-machine==2.14.1
timm==0.9.16
tokenizers==0.19.1
tokenmonster==1.1.12
toml==0.9.6
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.2.0
torchdata==0.6.1
torchdiffeq==0.2.3
TorchFix==0.4.0
torchtext==0.15.2
torchvision==0.17.0
tqdm==4.66.2
transformers==4.44.2
trl==0.9.6
typer==0.12.5
types-certifi==2021.10.8.3
types-requests==2.31.0.20240125
types-setuptools==69.0.0.20240125
types-toml==0.10.8.7
typing==3.7.4.3
typing-inspect==0.8.0
typing_extensions==4.9.0
tyro==0.5.18
tzdata==2023.3
unique-names-generator==1.0.2
urllib3==2.2.2
uvicorn==0.22.0
vector_quantize_pytorch==1.14.1
virtualenv==20.23.0
voyager==2.0.2
wandb==0.16.2
watchfiles==0.21.0
wavedrom==2.0.3.post3
wcwidth==0.2.6
websocket-client==1.7.0
websockets==12.0
Werkzeug==3.0.1
wonderwords==2.2.0
xxhash==3.2.0
yarl==1.8.2
zetascale==2.2.7
zipp==3.15.0

View File

@@ -1,60 +0,0 @@
"""
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset
@click.command()
@click.argument("dataset", type=str)
@click.option("--split", type=str, default="train")
def parse_dataset(dataset=None, split="train"):
ds_cfg = {}
ds_cfg["path"] = dataset
ds_cfg["split"] = split
ds_cfg["type"] = "chat_template"
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"
dataset = load_dataset(dataset, split=split)
features = dataset.features
feature_keys = features.keys()
field_messages = None
for key in ["conversation", "conversations", "messages"]:
if key in feature_keys:
field_messages = key
break
if not field_messages:
raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}'
)
ds_cfg["field_messages"] = field_messages
message_fields = features["conversations"][0].keys()
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:
message_field_role = key
break
if not message_field_role:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_role"] = message_field_role
message_field_content = None
for key in ["content", "text", "value"]:
if key in message_fields:
message_field_content = key
break
if not message_field_content:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_content"] = message_field_content
print(yaml.dump({"datasets": [ds_cfg]}))
if __name__ == "__main__":
parse_dataset()

View File

@@ -30,9 +30,6 @@ def parse_requirements():
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
@@ -52,35 +49,20 @@ def parse_requirements():
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if (major, minor) >= (2, 3):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
@@ -109,7 +91,6 @@ setup(
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",

View File

@@ -30,8 +30,6 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
@@ -41,7 +39,7 @@ from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -55,22 +53,8 @@ LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_legacy_axolotl_text_art(suffix=None):
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
@@ -83,13 +67,6 @@ def print_legacy_axolotl_text_art(suffix=None):
print_dep_versions()
def print_axolotl_text_art(
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
print(AXOLOTL_LOGO)
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
@@ -257,8 +234,7 @@ def do_inference_gradio(
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
default_tokens: Dict[str, str] = {}
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
@@ -266,13 +242,10 @@ def do_inference_gradio(
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -286,24 +259,7 @@ def do_inference_gradio(
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
@@ -326,7 +282,6 @@ def do_inference_gradio(
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
@@ -443,8 +398,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg
@@ -454,20 +407,12 @@ def load_datasets(
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
cfg, tokenizer
)
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(

View File

@@ -27,7 +27,6 @@ from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
)
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -71,11 +70,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching():
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.download:
model_name = parsed_cfg.base_model

View File

@@ -3,11 +3,13 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Union
from typing import Tuple, Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.cli import (
check_accelerate_default_config,
@@ -18,7 +20,6 @@ from axolotl.cli import (
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.integrations.base import PluginManager
from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
@@ -38,7 +39,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return do_train(parsed_cfg, parsed_cli_args)
def do_train(cfg, cli_args) -> None:
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
@@ -63,13 +64,7 @@ def do_train(cfg, cli_args) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
plugin_manager.post_train_unload(cfg)
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":

View File

@@ -23,7 +23,7 @@ class TrainerCliArgs:
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)

View File

@@ -1,34 +0,0 @@
"""
ChatML transformation functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages
from .shared import wrap_tools
def format_message(
message: Messages,
message_index: Optional[int] = None, # pylint: disable=unused-argument
) -> Messages:
if message.is_chat_formatted:
return message
# prepend the role prefix within a MessageContents to message.content
message.content.insert(
0,
MessageContents(
type="text",
value=f"<|im_start|>{message.role}\n",
weight=0,
),
)
message.content.append(
MessageContents(type="text", value="<|im_end|>", weight=message.weight)
)
message.content.append(MessageContents(type="text", value="\n", weight=0))
message = wrap_tools(message)
message.is_chat_formatted = True
return message

View File

@@ -1,45 +0,0 @@
"""
Llama 3.x chat formatting functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages
from .shared import wrap_tools
def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
if message.is_chat_formatted:
return message
message_role = message.role
if message.role == "tool":
message_role = "ipython"
# prepend the role prefix within a MessageContents to message.content
message.content.insert(
0,
MessageContents(
type="text",
value=f"<|start_header_id|>{message_role}<|end_header_id|>\n\n",
weight=0,
),
)
message.content.append(
MessageContents(type="text", value="<|eot_id|>", weight=message.weight)
)
message = wrap_tools(message)
if message_index == 0:
message.content.insert(
0,
MessageContents(
type="text",
value="<|begin_of_text|>",
weight=0,
),
)
message.is_chat_formatted = True
return message

View File

@@ -1,47 +0,0 @@
"""
shared functions for format transforms
"""
from axolotl.core.chat.messages import MessageContents, Messages
def wrap_tools(message: Messages):
# loop over message.content by index to find tool calls, we need to wrap each with tags,
# so be wary of indexing issues when changing the list while iterating.
# iterate over the range in reverse order to avoid index shifting
for i in range(len(message.content) - 1, -1, -1):
if message.content[i].type == "tool_call":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_call>\n", weight=message.weight
),
)
# make sure the actual tool call content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_call>\n", weight=message.weight
),
)
elif message.content[i].type == "tool_response":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_response>\n", weight=message.weight
),
)
# make sure the actual tool response content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_response>\n", weight=message.weight
),
)
return message

View File

@@ -1,230 +0,0 @@
"""
internal message representations of chat messages
"""
import json
from enum import Enum
from typing import Any, Callable, List, Optional, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizer
class MessageRoles(str, Enum):
"""
Message roles for the system, user, assistant, and tools
"""
system = "system" # pylint: disable=invalid-name
user = "user" # pylint: disable=invalid-name
assistant = "assistant" # pylint: disable=invalid-name
tool = "tool" # pylint: disable=invalid-name
ipython = ( # pylint: disable=invalid-name
# for responses from builtin tools
"ipython"
)
class MessageContentTypes(str, Enum):
"""
Message content types for text, image, audio, tool calls, and tool responses
"""
special_token = "special_token" # pylint: disable=invalid-name # nosec B105
text = "text" # pylint: disable=invalid-name
image = "image" # pylint: disable=invalid-name
audio = "audio" # pylint: disable=invalid-name
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
tool_response = "tool_response" # pylint: disable=invalid-name
class SpecialToken(str, Enum):
"""
Special tokens for beginning of string and end of string
"""
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
class ToolCallFunction(BaseModel):
"""
Tool call function with name and arguments
"""
name: str
arguments: dict[str, str]
class Tool(BaseModel):
"""
Tool with description, function, and parameters
"""
description: str
function: ToolCallFunction
parameters: dict[str, str] # .properties
class ToolCallContents(BaseModel):
"""
Tool call contents with name, arguments, and optional id
"""
name: str
arguments: dict[str, Union[str, int]]
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "arguments": self.arguments}
if self.id is not None:
data["id"] = self.id
return json.dumps(data)
class ToolResponseContents(BaseModel):
"""
Tool response contents with name, content, and optional id
"""
name: str
content: Union[str, dict[str, Union[str, int, float]]]
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "content": self.content}
if self.id is not None:
data["id"] = self.id
return json.dumps(data)
class MessageContents(BaseModel):
"""
Message contents with type, value, metadata, weight, newline, and end of contents
"""
type: Union[str, MessageContentTypes]
value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
weight: Optional[Union[int, float]] = None
has_newline: bool = False
eoc: bool = False # end of contents
def __str__(self) -> str:
str_val = str(self.value)
if self.has_newline and not str_val.endswith("\n"):
str_val += "\n"
return str_val
class Messages(BaseModel):
"""
Messages with role, content, metadata, weight, and chat formatting
"""
role: Union[MessageRoles, str] # allows for arbitrary roles
content: List["MessageContents"]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
weight: Optional[Union[int, float]] = None
is_chat_formatted: bool = False
def __str__(self) -> str:
return "".join(str(c) for c in self.content)
def tokenized(
self, tokenizer: PreTrainedTokenizer, ignore_index=-100
) -> dict[str, List[int]]:
# iterate over the contents, tokenizing the concatenated string values up to the current MessageContents
# returns a dictionary mapping w input_ids, attention_mask, and labels
input_ids: List[int] = []
labels: List[int] = []
pending_input_ids: List[int] = []
pending_weight = self.weight
running_content = ""
for _, msg_content in enumerate(self.content):
# TODO also handle non-text content types
if msg_content.type in [
MessageContentTypes.text.value,
MessageContentTypes.tool_call.value,
MessageContentTypes.tool_response.value,
]:
running_content += str(msg_content)
tok_results = tokenizer(running_content, add_special_tokens=False)
tok_input_ids = tok_results["input_ids"]
if pending_input_ids:
new_pending_inputs = tok_input_ids[
len(input_ids) : len(input_ids) + len(pending_input_ids)
]
if new_pending_inputs != pending_input_ids:
# logging.warning("tokenization mismatch from concatenation.")
pending_input_ids = new_pending_inputs
input_ids.extend(pending_input_ids)
if pending_weight:
labels.extend(pending_input_ids)
else:
labels.extend([ignore_index] * len(pending_input_ids))
pending_input_ids = tok_results["input_ids"][len(input_ids) :]
pending_weight = self.weight and msg_content.weight not in [0, 0.0]
input_ids.extend(pending_input_ids)
if pending_weight:
labels.extend(pending_input_ids)
else:
labels.extend([ignore_index] * len(pending_input_ids))
attention_mask = [1] * len(input_ids)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class Chats(BaseModel):
"""
top level data structure for chat conversations
"""
conversation: List[Messages]
def __str__(self) -> str:
return "".join(str(c) for c in self.conversation)
def tokenized(
self, tokenizer: Callable[[str], dict[str, List[int]]], ignore_index=-100
) -> dict[str, List[int]]:
input_ids = []
attention_mask = []
labels = []
for msg in self.conversation:
msg_results = msg.tokenized(tokenizer, ignore_index)
input_ids.extend(msg_results["input_ids"])
attention_mask.extend(msg_results["attention_mask"])
labels.extend(msg_results["labels"])
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class ChatFormattedChats(Chats):
"""
Chat formatted chats with formatter and optional train on inputs
"""
formatter: Callable # [[Union[dict, Chats]], Chats]
train_on_inputs: bool = False
def model_post_init(self, __context):
for i, msg in enumerate(self.conversation):
self.conversation[i] = self.formatter(msg, message_index=i)
if self.train_on_inputs:
self.conversation[i].weight = 1
class PreferenceChats(BaseModel):
"""
representation for preference data for chat
"""
prompt: List[Messages]
chosen: Messages
rejected: Messages

View File

@@ -1,55 +0,0 @@
"""
chat dataset module
"""
import os
from typing import Callable, Optional, Union
from datasets import Dataset
from transformers import PreTrainedTokenizer
from axolotl.core.chat.messages import ChatFormattedChats
class TokenizedChatDataset(Dataset):
"""
Tokenized chat dataset
"""
def __init__(
self,
data: Dataset,
model_transform: Union[PreTrainedTokenizer, Callable],
*args,
message_transform: Optional[Callable] = None,
formatter=None,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,
):
def map_fn(ex):
if message_transform is not None:
ex = message_transform(ex)
if formatter is not None:
ex = ChatFormattedChats(
formatter=formatter,
**ex,
)
else:
ex = ChatFormattedChats(
**ex,
)
return ex.tokenized(model_transform)
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(64, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,
num_proc=num_proc,
keep_in_memory=keep_in_memory,
remove_columns=features,
desc="Tokenizing Chats",
)
super().__init__(tokenized_data.data, *args, **kwargs)

View File

@@ -1,150 +0,0 @@
"""
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
"""
from typing import Any, Mapping, Union
def chat_message_transform_builder( # pylint: disable=dangerous-default-value
train_on_inputs=False,
conversations_field: str = "conversations",
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
message_field_content: Union[str, list[str]] = [
"value",
"text",
"content",
], # commonly "content"
message_field_training: Union[str, list[str]] = [
"train",
"weight",
], # commonly "weight"
):
"""Builds a transform that takes a row from the dataset and converts it to a Chat
Args:
train_on_inputs (bool, optional):
If True, the transform will train on the inputs. If False, the transform will train on the targets.
Defaults to False.
conversations_field (str, optional):
The field name of the conversations. Defaults to "conversations".
message_field_role (str | list[str], optional):
The field name of the role. Defaults to "role".
message_field_content (str | list[str], optional):
The field name of the message content. Defaults to "content".
message_field_training (str | list[str], optional):
The field name of the train/weight. Defaults to "weight".
Returns:
Callable:
A function that takes a list of conversations and returns a list of messages.
"""
message_field_role = (
[message_field_role]
if isinstance(message_field_role, str)
else message_field_role
)
message_field_content = (
[message_field_content]
if isinstance(message_field_content, str)
else message_field_content
)
message_weight_fields = (
[message_field_training]
if isinstance(message_field_training, str)
else message_field_training
)
role_value_mappings = {
"system": "system",
"user": "user",
"human": "user",
"assistant": "assistant",
"gpt": "assistant",
"tool": "tool",
"ipython": "ipython",
}
if train_on_inputs:
role_default_weights_mappings = {
"system": 1,
"user": 1,
"assistant": 1,
"tool": 1,
"ipython": 1,
}
else:
role_default_weights_mappings = {
"system": 0,
"user": 0,
"assistant": 1,
"tool": 0,
"ipython": 0,
}
def transform_builder(sample: Mapping[str, Any]):
if conversations_field not in sample:
raise ValueError(f"Field '{conversations_field}' not found in sample.")
# if none of the role fields are in the message, raise an error
if not any(
role in sample[conversations_field][0] for role in message_field_role
):
raise ValueError("No role field found in message.")
role_field = next(
role
for role in message_field_role
if role in sample[conversations_field][0]
)
if not any(
field in sample[conversations_field][0] for field in message_field_content
):
raise ValueError("No message_content field found in message.")
message_content_field = next(
field
for field in message_field_content
if field in sample[conversations_field][0]
)
if not any(
field in sample[conversations_field][0] for field in message_field_training
):
message_weight_field = None
else:
message_weight_field = next(
field
for field in message_weight_fields
if field in sample[conversations_field][0]
)
messages = []
for message in sample[conversations_field]:
role = role_value_mappings[message[role_field]]
weight = (
int(message[message_weight_field])
if message_weight_field
else role_default_weights_mappings[role]
)
# TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents
if isinstance(message[message_content_field], str):
messages.append(
{
"role": role,
"content": [
{
"type": "text",
"value": message[message_content_field],
}
],
"weight": weight,
}
)
else:
messages.append(
{
"role": role,
"content": message[message_content_field],
"weight": weight,
}
)
return {"conversation": messages}
return transform_builder

View File

@@ -7,7 +7,6 @@ import abc
import gc
import importlib
import importlib.util
import inspect
import logging
import math
import os
@@ -17,17 +16,17 @@ from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
import torch
import transformers
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
EarlyStoppingCallback,
PreTrainedModel,
Trainer,
TrainerCallback,
TrainingArguments,
@@ -43,14 +42,13 @@ from trl import (
KTOTrainer,
ORPOConfig,
ORPOTrainer,
RewardConfig,
RewardTrainer,
)
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils import is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
@@ -63,14 +61,12 @@ from axolotl.utils.callbacks import (
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
@@ -254,10 +250,11 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
optim_shampoo_grafting_config_type: Optional[
Literal["adam", "sgd", "adagrad"]
] = None
optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
optim_shampoo_betas: Optional[Tuple[float, float]] = None
@dataclass
@@ -303,13 +300,6 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
)
@dataclass
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""
Reward config for Reward training
"""
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
@@ -407,10 +397,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
def __init__(
self,
*_args,
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs,
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
@@ -435,7 +427,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"shampoo",
]
):
return super().create_optimizer()
@@ -469,14 +467,110 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
self.args, "loraplus_lr_embedding", None
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
elif self.args.alternate_optimizer == "shampoo":
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdamGraftingConfig,
CommunicationDType,
DDPShampooConfig,
FSDPShampooConfig,
PrecisionConfig,
SGDGraftingConfig,
)
from distributed_shampoo.utils.shampoo_fsdp_utils import (
compile_fsdp_parameter_metadata,
)
# parse args.optim_args
optim_args = {}
if self.args.optim_args:
for mapping in self.args.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optim_args[key] = value
optim_args["betas"] = self.args.optim_shampoo_betas
if "max_preconditioner_dim" in optim_args:
optim_args["max_preconditioner_dim"] = int(
optim_args["max_preconditioner_dim"]
)
if "precondition_frequency" in optim_args:
optim_args["precondition_frequency"] = int(
optim_args["precondition_frequency"]
)
if "use_decoupled_weight_decay" in optim_args:
optim_args["use_decoupled_weight_decay"] = bool(
optim_args["use_decoupled_weight_decay"]
)
if isinstance(optim_args["epsilon"], str):
optim_args["epsilon"] = float(optim_args["epsilon"])
optim_args["lr"] = self.args.learning_rate
optim_args["weight_decay"] = self.args.weight_decay
if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs:
if isinstance(
self.args.optim_shampoo_grafting_config_kwargs["epsilon"], str
):
self.args.optim_shampoo_grafting_config_kwargs[
"epsilon"
] = float(
self.args.optim_shampoo_grafting_config_kwargs["epsilon"]
)
if self.args.optim_shampoo_grafting_config_type == "adam":
grafting_config = AdamGraftingConfig(
**self.args.optim_shampoo_grafting_config_kwargs
)
elif self.args.optim_shampoo_grafting_config_type == "sgd":
grafting_config = SGDGraftingConfig(
**self.args.optim_shampoo_grafting_config_kwargs
)
elif self.args.optim_shampoo_grafting_config_type == "adagrad":
grafting_config = AdaGradGraftingConfig(
**self.args.optim_shampoo_grafting_config_kwargs
)
distributed_config = None
if self.args.world_size > 1:
if self.args.fsdp and self.args.fsdp_config:
distributed_config = FSDPShampooConfig(
param_to_metadata=compile_fsdp_parameter_metadata(
self.model_wrapped
)
)
else:
distributed_config = DDPShampooConfig(
communication_dtype=CommunicationDType.BF16,
num_trainers_per_group=self.args.world_size,
communicate_params=False,
)
precision_config = None
if self.args.bf16:
precision_config = PrecisionConfig(
computation_dtype=torch.bfloat16,
factor_matrix_dtype=torch.bfloat16,
inv_factor_matrix_dtype=torch.bfloat16,
filtered_grad_dtype=torch.bfloat16,
momentum_dtype=torch.bfloat16,
grafting_state_dtype=torch.bfloat16,
)
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
DistributedShampoo(
optimizer_grouped_parameters,
grafting_config=grafting_config,
distributed_config=distributed_config,
precision_config=precision_config,
**optim_args,
)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
@@ -666,9 +760,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
@@ -676,18 +768,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
return super().compute_loss(model, inputs, return_outputs=return_outputs)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
@@ -783,13 +865,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
).squeeze(2)
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
def orpo_compute_loss(
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
):
def orpo_compute_loss(self, model, inputs, return_outputs=False):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
label_pad_token=-100,
@@ -895,13 +971,17 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial):
def _save_checkpoint(self, model, trial, metrics=None):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial)
try:
return super()._save_checkpoint(model, trial, metrics=metrics)
except NotImplementedError as exc:
LOG.warning(f"Failed to save checkpoint: {exc}")
return None
class AxolotlMambaTrainer(AxolotlTrainer):
@@ -916,7 +996,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
@@ -1001,9 +1080,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
if is_sagemaker_mp_enabled():
@@ -1024,32 +1103,18 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs)
def tokenize_row(
self,
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
) -> Dict:
res = super().tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
res = super().tokenize_row(feature, model=model)
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
return res
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
loss: torch.Tensor = super().training_step(model, inputs)
gc.collect()
torch.cuda.empty_cache()
return loss
@@ -1079,14 +1144,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -1097,11 +1154,10 @@ class TrainerBuilderBase(abc.ABC):
_model_ref = None
_peft_config = None
def __init__(self, cfg, model, tokenizer, processor=None):
def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
self.processor = processor
# in case the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
@@ -1152,23 +1208,12 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from transformers.integrations.integration_utils import MLflowCallback
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
MLflowCallback,
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
return callbacks
@@ -1238,11 +1283,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
@@ -1267,8 +1307,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
return AxolotlTrainer
def build(self, total_num_steps):
@@ -1490,22 +1528,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.wandb_name:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_arguments_kwargs["report_to"] = report_to
if self.cfg.use_wandb:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_name if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
@@ -1521,6 +1552,21 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
# shampoo optimizer config
if self.cfg.optim_shampoo_betas:
training_arguments_kwargs[
"optim_shampoo_betas"
] = self.cfg.optim_shampoo_betas
if self.cfg.optim_shampoo_grafting_config_type:
training_arguments_kwargs[
"optim_shampoo_grafting_config_type"
] = self.cfg.optim_shampoo_grafting_config_type
if self.cfg.optim_shampoo_grafting_config_kwargs:
training_arguments_kwargs[
"optim_shampoo_grafting_config_kwargs"
] = self.cfg.optim_shampoo_grafting_config_kwargs
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
@@ -1593,10 +1639,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template
)
if self.cfg.rl == "orpo":
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
@@ -1608,14 +1650,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.optimizer in [
# pylint: disable=duplicate-code
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"shampoo",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
@@ -1654,22 +1695,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"accelerator_config"
] = self.cfg.accelerator_config
training_args_cls = (
AxolotlTrainingArguments
if not self.cfg.reward_model
else AxolotlRewardConfig
)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
)
training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
@@ -1682,37 +1714,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not self.cfg.reward_model:
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not self.cfg.reward_model:
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
eval_data_collator=self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
@@ -1746,14 +1768,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
RewardDataCollatorWithPadding,
]
]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
@@ -1764,12 +1781,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
else:
collator = DataCollatorForSeq2Seq
collator = DataCollatorForSeq2Seq
return collator(
self.tokenizer,
@@ -1955,7 +1967,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
dpo_trainer_kwargs["generate_during_eval"] = True
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
@@ -1967,17 +1979,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
dpo_trainer_kwargs["processing_class"] = self.tokenizer
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
tokenizer=self.tokenizer,
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)

View File

@@ -159,29 +159,6 @@ class BasePlugin:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
def post_train(self, cfg, model):
"""
Performs actions after training is complete.
Parameters:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
"""
def post_train_unload(self, cfg):
"""
Performs actions after training is complete and the model is unloaded.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def load_plugin(plugin_name: str) -> BasePlugin:
"""
@@ -404,17 +381,3 @@ class PluginManager:
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.post_train_unload(cfg)

View File

@@ -1,13 +0,0 @@
# LM Eval Harness
### Usage
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
```

View File

@@ -1,42 +0,0 @@
"""
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
class LMEvalPlugin(BasePlugin):
"""
Plugin for LM Evaluation Harness integraton with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
tasks = ",".join(cfg.lm_eval_tasks)
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
output_path = cfg.output_dir
output_path += "" if cfg.output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
subprocess.run( # nosec
[
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={cfg.output_dir}{fa2}{dtype}",
"--tasks",
tasks,
"--batch_size",
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)

View File

@@ -1,15 +0,0 @@
"""
Module for handling lm eval harness input arguments.
"""
from typing import List, Optional
from pydantic import BaseModel
class LMEvalArgs(BaseModel):
"""
Input args for lm eval harness
"""
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8

133
src/axolotl/loraplus.py Normal file
View File

@@ -0,0 +1,133 @@
"""Module for LoRA+"""
# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus
import logging
from functools import reduce
from peft.tuners import lora
from torch import nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
LOG = logging.getLogger("axolotl.loraplus")
def get_module(name, opt_model):
"""
Retrieve a module from a model using its parameter name.
Args:
name (str): Full name of the parameter, typically including module path.
opt_model (torch.nn.Module): The model from which to retrieve the module.
Returns:
Module corresponding to the given name.
"""
parent_idx = 2 if "lora" in name else 1
module_names = name.split(sep=".")[:-parent_idx]
module = reduce(getattr, module_names, opt_model)
return module
def create_loraplus_optimizer(
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding=None,
):
"""
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
Args:
opt_model (torch.nn.Module): The model for which the optimizer is being created.
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
Returns:
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
"""
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
if loraplus_lr_embedding is None:
loraplus_lr_embedding = 1e-6
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
param_groups = {
"groupA": {},
"groupB": {},
"groupB_no_decay": {},
"embedding": {},
}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
module = get_module(name, opt_model)
if isinstance(module, lora.Embedding):
param_groups["embedding"][name] = param
elif "lora_B" in name or param.ndim == 1:
if name in decay_parameters:
param_groups["groupB"][name] = param
else:
param_groups["groupB_no_decay"][name] = param
else:
param_groups["groupA"][name] = param
assigned_param_groups = ""
for group, group_params in param_groups.items():
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
LOG.info(assigned_param_groups)
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
optimizer_grouped_parameters = [
{
"params": list(param_groups["groupA"].values()),
"weight_decay": weight_decay,
"lr": lr,
},
{
"params": list(param_groups["embedding"].values()),
"weight_decay": weight_decay,
"lr": loraplus_lr_embedding,
},
{
"params": list(param_groups["groupB"].values()),
"weight_decay": weight_decay,
"lr": lr * loraplus_lr_ratio,
},
{
"params": list(param_groups["groupB_no_decay"].values()),
"weight_decay": 0.0,
"lr": lr * loraplus_lr_ratio,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
return optimizer

View File

@@ -1,229 +0,0 @@
"""
Monkeypatch for Vision Llama for FA2 support
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple
import torch
from flash_attn.flash_attn_interface import flash_attn_func
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
from transformers.models.mllama.modeling_mllama import (
MllamaTextCrossAttention,
MllamaTextSelfAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import is_flash_attn_greater_or_equal_2_10
class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention):
"""
Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and
implements the forward pass using Flash Attention for improved performance.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check if flash attention version is greater or equal to 2.1
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[ # pylint: disable=unused-argument
torch.Tensor
] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
query_states = self.q_norm(query_states)
if cross_attention_states is not None:
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
key_states = self.k_norm(key_states)
if past_key_value is not None:
key_states, value_states = past_key_value.update(
key_states,
value_states,
self.layer_idx,
{"cache_position": cache_position},
)
elif cache_position[0] != 0:
key_states, value_states = (
past_key_value.key_cache[self.layer_idx],
past_key_value.value_cache[self.layer_idx],
)
else:
raise ValueError(
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)
# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# Apply Flash Attention
dropout_rate = self.dropout if self.training else 0.0
output = flash_attn_func(
query_states,
key_states,
value_states,
dropout_p=dropout_rate,
softmax_scale=None,
causal=False,
return_attn_probs=output_attentions,
)
attn_output = output.contiguous().view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention):
"""
Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and
implements the forward pass using Flash Attention for improved performance.
"""
def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs):
super().__init__(config, layer_idx, *args, **kwargs)
# Check if flash attention version is greater or equal to 2.1
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
past_key_value=None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.dropout if self.training else 0.0
# Handle potential silent casting to float32
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = (
self.config._pre_quantization_dtype # pylint: disable=protected-access
)
else:
target_dtype = self.q_proj.weight.dtype
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=True,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def patch_mllama():
from transformers.models.mllama.modeling_mllama import (
MLLAMA_TEXT_ATTENTION_CLASSES,
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES,
MLLAMA_VISION_ATTENTION_CLASSES,
MllamaPreTrainedModel,
)
MllamaPreTrainedModel._supports_flash_attn_2 = ( # pylint: disable=protected-access
True
)
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
"flash_attention_2"
] = MllamaTextCrossFlashAttention2
# fallback to SDPA
MLLAMA_VISION_ATTENTION_CLASSES[
"flash_attention_2"
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]

View File

@@ -22,6 +22,7 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
repeat_kv,
)
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
@@ -43,19 +44,7 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def is_xformers_available() -> bool:
try:
import xformers # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
def is_xformers_swiglu_available() -> bool:
if not is_xformers_available():
return False
from xformers.ops.common import get_xformers_operator
try:
@@ -68,11 +57,6 @@ def is_xformers_swiglu_available() -> bool:
def replace_llama_mlp_with_swiglu(model):
if is_xformers_swiglu_available():
from axolotl.monkeypatch.xformers_ import FusedMLP
else:
raise RuntimeError("xformers SwiGLU not available for this environment")
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
@@ -197,6 +181,49 @@ class FusedAttention(LlamaAttention):
set_module_name(model, name, new_attn)
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(

View File

@@ -10,7 +10,6 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mllama_text_model",
"llama",
"mistral",
"mixtral",
@@ -27,18 +26,15 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
# def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
# elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
elif hasattr(transformers, "modeling_flash_attention_utils"):
if not has_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return

View File

@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
reset_params: List[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: List[str],
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
prune_ratio: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)

View File

@@ -16,7 +16,6 @@
# This code is based off the following work:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# pylint: disable=duplicate-code
""" PyTorch StableLM Epoch model. """
import importlib
import math

View File

@@ -16,6 +16,26 @@ from transformers.models.llama.modeling_llama import (
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
"""
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
"""
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@@ -60,6 +80,12 @@ def get_forward_code() -> str:
return forward
def check_cel_is_patchable() -> bool:
forward = get_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_CEL_CODE in forward
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
@@ -72,31 +98,48 @@ def check_self_attn_is_patchable() -> bool:
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
def UnslothForCausalLMLoss( # pylint: disable=invalid-name
logits,
labels,
vocab_size: int, # pylint: disable=unused-argument
num_items_in_batch: int = None,
ignore_index: int = -100, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
)
return loss
if model_type == "llama":
from transformers.loss import loss_utils
forward = get_forward_code()
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
)
forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"",
)
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace(
"def forward(",
"def fast_cross_entropy_loss_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
else:
raise ValueError("Unsupported model type")

View File

@@ -1,51 +0,0 @@
"""
Fused MLP layer for incrementally improved training efficiency
"""
import torch
from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import set_module_name
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)

View File

@@ -9,12 +9,8 @@ from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies")
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
def load(strategy, tokenizer, cfg, ds_cfg):
try:
if strategy == "messages":
from .messages import load as messages_load
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
@@ -28,12 +24,9 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None
return None

View File

@@ -1,10 +0,0 @@
### example yaml
```yaml
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
```

View File

@@ -1,35 +0,0 @@
"""Module to load prompt strategies."""
import importlib
import inspect
import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
def load(strategy, tokenizer, cfg, ds_cfg):
# pylint: disable=duplicate-code
try:
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(
f".{strategy}", "axolotl.prompt_strategies.bradley_terry"
)
func = getattr(mod, load_fn)
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
else:
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
return None

View File

@@ -1,102 +0,0 @@
"""
Bradley-Terry model with chat template prompt strategy.
"""
import logging
from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)
class BTChatTemplateStrategy(ChatTemplateStrategy):
"""
Bradley-Terry reward model pairwise chat template prompt strategy.
"""
def tokenize_prompt(self, prompt):
"""
:param prompt: the actual row of data from the underlying dataset
:return:
"""
self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super().tokenize_prompt(prompt)
return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
"labels_chosen": 1.0,
"input_ids_rejected": rejected_tokenized["input_ids"],
"attention_mask_rejected": rejected_tokenized["attention_mask"],
"labels_rejected": 0.0,
}
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", None
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1
if not cfg.reward_model
else cfg.sequence_len,
}
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
}
strategy = BTChatTemplateStrategy(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy

View File

@@ -1,27 +0,0 @@
"""
chatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template
"""
def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
prompt = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
return sample
return transform_fn

View File

@@ -5,11 +5,9 @@ HF Chat Templates prompt strategy
import logging
from typing import Any, Dict, List, Optional
from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.chat_templates import chat_templates
# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -22,7 +20,6 @@ class ChatTemplatePrompter(Prompter):
def __init__(
self,
tokenizer,
processor=None,
chat_template=None,
max_length=2048,
message_field_role: str = "from",
@@ -47,12 +44,11 @@ class ChatTemplatePrompter(Prompter):
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.tokenizer = tokenizer
self.processor: ProcessorMixin = processor
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
def build_prompt(self, conversation, add_generation_prompt=False):
turns = [
{
"role": self.roles[t[self.message_field_role]],
@@ -65,28 +61,6 @@ class ChatTemplatePrompter(Prompter):
if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
if self.processor:
text = self.processor.apply_chat_template(
turns,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
batch = self.processor(
text=text,
images=images,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
)
# workaround since processor works in batches instead of single examples
for k, val in batch.items():
if k in ["pixel_values"]:
batch[k] = val.tolist()
else:
batch[k] = val.squeeze().tolist()
return batch
return self.tokenizer.apply_chat_template(
turns,
truncation=True,
@@ -217,7 +191,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []
self.train_on_eos = train_on_eos
self.images = "images"
@property
def messages(self):
@@ -236,21 +209,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
and not self.prompter.message_field_training_detail
):
turns = self.get_conversation_thread(prompt)
images = self.get_images(prompt)
prompt_ids = self.prompter.build_prompt(
turns[:-1],
add_generation_prompt=True,
images=images,
turns[:-1], add_generation_prompt=True
)
tokenized_res = self.prompter.build_prompt(turns, images=images)
tokenized_prompt = {}
if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
tokenized_prompt["input_ids"] = input_ids
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
else:
input_ids = tokenized_res["input_ids"]
tokenized_prompt = tokenized_res
input_ids = self.prompter.build_prompt(turns)
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
@@ -258,9 +220,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
else:
labels = input_ids
tokenized_prompt["labels"] = labels
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
return tokenized_prompt
LOG.info(self.roles_to_train)
LOG.info(self.train_on_eos)
LOG.info(self.prompter.message_field_training)
LOG.info(self.prompter.message_field_training_detail)
turns = prompt[self.messages]
input_ids = self.prompter.build_prompt(turns)
@@ -398,23 +368,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
return prompt[self.messages]
def get_images(self, prompt):
return prompt.get(self.images, None)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail",
@@ -424,7 +386,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
}
strategy_params = {

View File

@@ -2,16 +2,15 @@
DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
from axolotl.utils.chat_templates import chat_templates
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
chat_template_str = chat_templates(cfg.chat_template)
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
@@ -31,12 +30,6 @@ def default(
role_map[source] = target
def transform_fn(sample, tokenizer=None):
chat_template_string = get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
messages = sample[field_messages]
messages = [
{
@@ -53,29 +46,28 @@ def default(
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
[dummy_user_message, chosen],
[chosen],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[dummy_user_message, rejected],
[rejected],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])

View File

@@ -1,34 +0,0 @@
"""Module to load message prompt strategies."""
import importlib
import inspect
import logging
LOG = logging.getLogger("axolotl.prompt_strategies.messages")
def load(tokenizer, cfg, ds_cfg, processor=None):
try:
strategy = ds_cfg.get("input_transform", "chat")
# pylint: disable=duplicate-code
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(
f".{strategy}", "axolotl.prompt_strategies.messages"
)
func = getattr(mod, load_fn)
load_kwargs = {}
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None

View File

@@ -1,84 +0,0 @@
"""
Chat dataset wrapping strategy for new internal messages representations
"""
from typing import Any, Callable, Dict, Optional
from axolotl.core.datasets.chat import TokenizedChatDataset
from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy):
"""
Chat dataset wrapping strategy for new internal messages representations
"""
def __init__(
self,
processor,
message_transform=None,
formatter=None,
**kwargs, # pylint: disable=unused-argument
):
"""
:param processor: tokenizer or image processor
:param kwargs:
"""
self.processor = processor
self.dataset = None
self.message_transform = message_transform
self.formatter = formatter
def wrap_dataset(
self,
dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs, # pylint: disable=unused-argument
):
self.dataset = TokenizedChatDataset(
dataset,
message_transform=self.message_transform,
model_transform=self.processor,
formatter=self.formatter,
process_count=process_count,
keep_in_memory=keep_in_memory,
)
return self.dataset
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
field_messages = ds_cfg.get("field_messages")
message_field_role = ds_cfg.get("message_field_role")
message_field_content = ds_cfg.get("message_field_content")
message_field_training = ds_cfg.get("message_field_training")
builder_kwargs = {}
if field_messages:
builder_kwargs["conversations_field"] = field_messages
if message_field_role:
builder_kwargs["message_field_role"] = message_field_role
if message_field_content:
builder_kwargs["message_field_content"] = message_field_content
if message_field_training:
builder_kwargs["message_field_training"] = message_field_training
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
format_message = (
lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
)
if chat_template == "chatml":
from axolotl.core.chat.format.chatml import format_message # noqa F811
if chat_template.startswith("llama3"):
from axolotl.core.chat.format.llama3x import format_message # noqa F811
message_transform: Callable = chat_message_transform_builder(
train_on_inputs=ds_cfg.get("train_on_inputs", False),
**builder_kwargs,
)
strategy = ChatMessageDatasetWrappingStrategy(
tokenizer, message_transform=message_transform, formatter=format_message
)
return strategy

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.chat_templates import chat_templates
class Message(BaseModel):
@@ -28,13 +28,18 @@ def load(
"""
chatml transforms for datasets with system, input, chosen, rejected
"""
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
tokenizer.chat_template = chat_template_string
chat_template = chat_templates("chatml")
if ds_cfg and "chat_template" in ds_cfg:
chat_template = ds_cfg["chat_template"]
try:
chat_template = chat_templates(chat_template)
except ValueError:
pass
tokenizer.chat_template = chat_template
return ORPOTokenizingStrategy(
ORPOPrompter(chat_template_string, tokenizer),
ORPOPrompter(chat_template, tokenizer),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -243,30 +248,28 @@ class ORPOPrompter(Prompter):
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()
chat_template_str = chat_templates(cfg.chat_template)
def transform_fn(sample, tokenizer=None):
res = {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, tokenizer=tokenizer
)
res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]

View File

@@ -61,9 +61,6 @@ def build_loader(
default_conversation: Optional[str] = None,
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
)
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg

View File

@@ -30,12 +30,6 @@ class InvalidDataException(Exception):
"""
class DatasetWrappingStrategy(abc.ABC):
"""
Abstract class for wrapping datasets for Chat Messages
"""
class PromptTokenizingStrategy(abc.ABC):
"""
Abstract class for tokenizing strategies

View File

@@ -10,6 +10,7 @@ from typing import Optional, Tuple, Union
import torch
import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
@@ -23,7 +24,7 @@ from axolotl.core.tokenizer_utils import fix_untrained_tokens
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
try:
@@ -68,9 +69,6 @@ def train(
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
processor = None
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
@@ -96,11 +94,10 @@ def train(
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
if model.generation_config is not None:
model.generation_config.do_sample = True
# we wait unitl the last possible moment to setup Accelerator
Accelerator()
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model.generation_config.do_sample = True
model_ref = None
if cfg.rl and cfg.rl != "orpo":
@@ -125,7 +122,6 @@ def train(
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)
@@ -260,10 +256,8 @@ def train(
if not cfg.hub_model_id:
try:
trainer.create_model_card(
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
)
except (AttributeError, UnicodeDecodeError):
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
except AttributeError:
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated

View File

@@ -1,12 +1,8 @@
"""
Basic utils for Axolotl
"""
import importlib.util
import importlib
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None

View File

@@ -29,7 +29,7 @@ from transformers import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
references=[[r] for r in references],
predictions=predictions,
)
scores["eval_" + metric_name] = score
scores[metric_name] = score
return scores
def predict_with_generate():
@@ -747,15 +747,6 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
elif logger == "comet_ml" and is_comet_available():
import comet_ml
experiment = comet_ml.get_running_experiment()
if experiment:
experiment.log_table(
f"{name} - Predictions vs Ground Truth.csv",
pd.DataFrame(table_data),
)
if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)

View File

@@ -1,43 +0,0 @@
"""Comet module for trainer callbacks"""
import logging
from typing import TYPE_CHECKING
import comet_ml
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
"""Callback to save axolotl config to comet"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
comet_experiment = comet_ml.start(source="axolotl")
comet_experiment.log_other("Created from", "axolotl")
comet_experiment.log_asset(
self.axolotl_config_path,
file_name="axolotl-config",
)
LOG.info(
"The Axolotl config has been saved to the Comet Experiment under assets."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
return control

File diff suppressed because one or more lines are too long

View File

@@ -1,14 +1,17 @@
"""
DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union
import numpy as np
import torch
import transformers
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
IGNORE_INDEX = -100
@dataclass
class DataCollatorForSeq2Seq:
@@ -180,6 +183,34 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
return super().__call__(out_features, return_tensors=return_tensors)
@dataclass
class MambaDataCollator:
"""
Collator for State Space Models (Mamba)
"""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[torch.LongTensor(instance[key]) for instance in instances]
for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return {
"input_ids": input_ids,
"labels": labels,
}
@dataclass
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""

View File

@@ -1,10 +0,0 @@
"""
shared axolotl collators for multipack, mamba, multimodal
"""
from .batching import ( # noqa: F401
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
PretrainingBatchSamplerDataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from .mamba import MambaDataCollator # noqa: F401

View File

@@ -1,4 +0,0 @@
"""
basic shared collator constants
"""
IGNORE_INDEX = -100

View File

@@ -1,38 +0,0 @@
"""
collators for Mamba
"""
from dataclasses import dataclass
from typing import Dict, Sequence
import torch
import transformers
from axolotl.utils.collators.core import IGNORE_INDEX
@dataclass
class MambaDataCollator:
"""
Collator for State Space Models (Mamba)
"""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[torch.LongTensor(instance[key]) for instance in instances]
for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return {
"input_ids": input_ids,
"labels": labels,
}

View File

@@ -1,83 +0,0 @@
"""
Collators for multi-modal chat messages and packing
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin
from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy
@dataclass
class MultiModalChatDataCollator(DataCollatorMixin):
"""
Collator for multi-modal chat messages
"""
tokenizer: PreTrainedTokenizerBase
processor: ProcessorMixin
return_tensors: str = "pt"
chat_template: Optional[str] = None
packing: bool = False
max_images: int = -1
padding: Union[bool, str, PaddingStrategy] = True
pad_to_multiple_of: Optional[int] = None
def __post_init__(self):
if self.packing:
raise ValueError("Packing is currently not supported.")
def torch_call(
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
) -> Dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
return self.__class__.process_rows(
examples, self.processor, self.chat_template, self.max_images
)
@staticmethod
def process_rows(examples, processor, chat_template, max_images, length_only=False):
# HINT: use `_torch_collate_batch` to stack and pad tensors
# see also DataCollatorWithFlattening and DefaultDataCollator
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False
)
for example in examples
]
images = [
Image.open(example["images"])
if isinstance(example["images"], str)
else example["images"]
for example in examples
]
if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100 #
# Ignore the image token index in the loss computation (model specific)
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
labels[labels == image_token_id] = -100
batch["labels"] = labels
if length_only:
return {
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
}
return batch

View File

@@ -1,93 +0,0 @@
"""Module for wandb utilities"""
import logging
import os
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.utils.comet_")
COMET_ENV_MAPPING_OVERRIDE = {
"comet_mode": "COMET_START_MODE",
"comet_online": "COMET_START_ONLINE",
}
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
"auto_log_co2": "COMET_AUTO_LOG_CO2",
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
"log_code": "COMET_AUTO_LOG_CODE",
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
"log_graph": "COMET_AUTO_LOG_GRAPH",
"name": "COMET_START_EXPERIMENT_NAME",
"offline_directory": "COMET_OFFLINE_DIRECTORY",
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
"tags": "COMET_START_EXPERIMENT_TAGS",
}
def python_value_to_environ_value(python_value):
if isinstance(python_value, bool):
if python_value is True:
return "true"
return "false"
if isinstance(python_value, int):
return str(python_value)
if isinstance(python_value, list): # Comet only have one list of string parameter
return ",".join(map(str, python_value))
return python_value
def setup_comet_env_vars(cfg: DictDefault):
# TODO, we need to convert Axolotl configuration to environment variables
# as Transformers integration are call first and would create an
# Experiment first
for key in cfg.keys():
if key.startswith("comet_") and key != "comet_experiment_config":
value = cfg.get(key, "")
if value is not None and value != "":
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
final_value = python_value_to_environ_value(value)
os.environ[env_variable_name] = final_value
if cfg.comet_experiment_config:
for key, value in cfg.comet_experiment_config.items():
if value is not None and value != "":
config_env_variable_name = (
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
)
if config_env_variable_name is None:
LOG.warning(
f"Unknown Comet Experiment Config name {key}, ignoring it"
)
continue
final_value = python_value_to_environ_value(value)
os.environ[config_env_variable_name] = final_value
# Enable comet if project name is present
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
cfg.use_comet = True

View File

@@ -121,36 +121,15 @@ def normalize_config(cfg):
cfg.base_model_config = cfg.base_model
model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type
cfg.tokenizer_config = (
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
)
cfg.is_multimodal = (
hasattr(model_config, "model_type")
and model_config.model_type in ["llava", "mllama"]
or any(
multimodal_name in cfg.base_model.lower()
for multimodal_name in [
"pixtral",
]
)
or cfg.is_multimodal
)
if cfg.is_multimodal:
cfg.processor_config = (
cfg.processor_config or cfg.base_model_config or cfg.base_model
)
model_config = model_config.text_config
cfg.model_config_type = model_config.model_type
# figure out if the model is llama
cfg.is_llama_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type == ["llama", "mllama_text_model"]
)
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
@@ -228,7 +207,6 @@ def normalize_cfg_datasets(cfg):
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
)
cfg.datasets[idx].chat_template = cfg.chat_template
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):

View File

@@ -8,16 +8,9 @@ import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import (
BaseModel,
Field,
StringConstraints,
conlist,
field_validator,
model_validator,
)
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
@@ -28,37 +21,6 @@ LOG = logging.getLogger("axolotl.utils.config.models.input")
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -140,22 +102,14 @@ class SFTDataset(BaseModel):
path: Optional[str] = None
split: Optional[str] = None
type: Optional[Union[str, UserDefinedPrompterType]] = None
input_transform: Optional[str] = None
shards: Optional[int] = None
conversation: Optional[str] = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: Optional[
Union[
ChatTemplate,
str,
]
] = None
chat_template_jinja: Optional[str] = None
chat_template: Optional[str] = None
data_files: Optional[Union[str, List[str]]] = None
input_format: Optional[str] = None
name: Optional[str] = None
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
field: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
@@ -166,31 +120,11 @@ class SFTDataset(BaseModel):
message_field_training_detail: Optional[str] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# Set chat_template to tokenizer_default if not set
if data.get("type") == "chat_template" and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.tokenizer_default
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
class UserDefinedDPOType(BaseModel):
@@ -212,7 +146,6 @@ class DPODataset(BaseModel):
split: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None
revision: Optional[str] = None
class UserDefinedKTOType(BaseModel):
@@ -234,7 +167,31 @@ class KTODataset(BaseModel):
type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
inst = "inst" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
@@ -271,12 +228,11 @@ class LoraConfig(BaseModel):
lora_r: Optional[int] = None
lora_alpha: Optional[int] = None
lora_fan_in_fan_out: Optional[bool] = None
lora_target_modules: Optional[Union[str, List[str]]] = None
lora_target_modules: Optional[List[str]] = None
lora_target_linear: Optional[bool] = None
lora_modules_to_save: Optional[List[str]] = None
lora_dropout: Optional[float] = 0.0
peft_layers_to_transform: Optional[List[int]] = None
peft_layers_pattern: Optional[List[str]] = None
peft: Optional[PeftConfig] = None
peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None
@@ -342,13 +298,6 @@ class LoraConfig(BaseModel):
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
@field_validator("loraplus_lr_embedding")
@classmethod
def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
loraplus_lr_embedding = float(loraplus_lr_embedding)
return loraplus_lr_embedding
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
@@ -372,9 +321,6 @@ class ModelInputConfig(BaseModel):
tokenizer_type: Optional[str] = Field(
default=None, metadata={"help": "transformers tokenizer class"}
)
processor_type: Optional[str] = Field(
default=None, metadata={"help": "transformers processor class"}
)
trust_remote_code: Optional[bool] = None
model_kwargs: Optional[Dict[str, Any]] = None
@@ -426,6 +372,7 @@ class HyperparametersConfig(BaseModel):
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"shampoo",
],
]
] = OptimizerNames.ADAMW_HF.value
@@ -438,6 +385,12 @@ class HyperparametersConfig(BaseModel):
"help": "The target modules to optimize, i.e. the module names that you would like to train."
},
)
optim_shampoo_grafting_config_type: Optional[
Literal["adam", "sgd", "adagrad"]
] = None
optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
optim_shampoo_betas: Optional[Tuple[float, float]] = None
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
@@ -486,7 +439,6 @@ class MLFlowConfig(BaseModel):
use_mlflow: Optional[bool] = None
mlflow_tracking_uri: Optional[str] = None
mlflow_experiment_name: Optional[str] = None
mlflow_run_name: Optional[str] = None
hf_mlflow_log_artifacts: Optional[bool] = None
@@ -532,19 +484,6 @@ class WandbConfig(BaseModel):
return data
class CometConfig(BaseModel):
"""Comet configuration subset"""
use_comet: Optional[bool] = None
comet_api_key: Optional[str] = None
comet_workspace: Optional[str] = None
comet_project_name: Optional[str] = None
comet_experiment_key: Optional[str] = None
comet_mode: Optional[str] = None
comet_online: Optional[bool] = None
comet_experiment_config: Optional[Dict[str, Any]] = None
class GradioConfig(BaseModel):
"""Gradio configuration subset"""
@@ -565,7 +504,6 @@ class AxolotlInputConfig(
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
CometConfig,
LISAConfig,
GradioConfig,
RemappedParameters,
@@ -583,10 +521,8 @@ class AxolotlInputConfig(
resume_from_checkpoint: Optional[str] = None
auto_resume_from_checkpoints: Optional[bool] = None
resize_token_embeddings_to_32x: Optional[bool] = None
mean_resizing_embeddings: Optional[bool] = False
rl: Optional[RLType] = None
reward_model: Optional[bool] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
@@ -594,7 +530,6 @@ class AxolotlInputConfig(
dataset_prepared_path: Optional[str] = None
dataset_shard_num: Optional[int] = None
dataset_shard_idx: Optional[int] = None
skip_prepare_dataset: Optional[bool] = False
pretraining_dataset: Optional[ # type: ignore
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
@@ -753,13 +688,7 @@ class AxolotlInputConfig(
gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None
chat_template: Optional[
Union[
ChatTemplate,
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
]
] = None
chat_template_jinja: Optional[str] = None
chat_template: Optional[ChatTemplate] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
@@ -868,23 +797,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):
@@ -915,17 +827,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def hint_reward_model_pad(cls, data):
if data.get("reward_model") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using reward_model"
)
if data.get("pad_to_sequence_len") is None:
data["pad_to_sequence_len"] = True
return data
@model_validator(mode="before")
@classmethod
def check_gas_bsz(cls, data):
@@ -1059,26 +960,6 @@ class AxolotlInputConfig(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if data.get("do_bench_eval") and not (
data.get("evals_per_epoch") or data.get("eval_steps")
):
raise ValueError(
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
)
return data
@model_validator(mode="before")
@classmethod
def check_test_datasets_bench(cls, data):
if (
data.get("do_bench_eval")
and not data.get("test_datasets")
and not data.get("val_set_size")
):
LOG.warning(
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
)
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
return data
@model_validator(mode="before")
@@ -1116,18 +997,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_mm_prepare(cls, data):
if data.get("skip_prepare_dataset"):
if data.get("remove_unused_columns") is None:
LOG.info(
"setting `remove_unused_columns: false` for skip_prepare_dataset"
)
data["remove_unused_columns"] = False
return data
@model_validator(mode="before")
@classmethod
def check_warmup(cls, data):
@@ -1155,20 +1024,12 @@ class AxolotlInputConfig(
return neftune_noise_alpha
@model_validator(mode="after")
def check_rl_beta(self):
def check(self):
if self.dpo_beta and not self.rl_beta:
self.rl_beta = self.dpo_beta
del self.dpo_beta
return self
@model_validator(mode="after")
def check_simpo_warmup(self):
if self.rl == "simpo" and self.warmup_ratio:
raise ValueError(
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
)
return self
@model_validator(mode="before")
@classmethod
def check_frozen(cls, data):
@@ -1183,15 +1044,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_peft_layers_pattern(cls, data):
if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"):
raise ValueError(
"peft_layers_pattern requires peft_layers_to_transform to be set"
)
return data
@model_validator(mode="after")
def check_fft_possible_bad_config(self):
if (

View File

@@ -90,7 +90,6 @@ def load_prepare_dpo_datasets(cfg):
ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"],
split=ds_cfg["split"],
revision=ds_cfg.get("revision", None),
)
split_datasets.insert(i, ds)

View File

@@ -19,12 +19,10 @@ from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
DatasetWrappingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
@@ -53,31 +51,20 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def prepare_dataset(cfg, tokenizer, processor=None):
def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_local_main_process()):
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="train",
processor=processor,
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
)
_, eval_dataset, _ = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
)
else:
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
processor=processor,
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
path = cfg.pretraining_dataset
@@ -136,7 +123,6 @@ def load_tokenized_prepared_datasets(
cfg,
default_dataset_prepared_path,
split="train",
processor=None,
) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = cfg.tokenizer_config
@@ -194,7 +180,6 @@ def load_tokenized_prepared_datasets(
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
and not cfg.skip_prepare_dataset
):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
@@ -244,7 +229,6 @@ def load_tokenized_prepared_datasets(
name=config_dataset.name,
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -349,7 +333,6 @@ def load_tokenized_prepared_datasets(
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
@@ -384,7 +367,6 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
@@ -394,7 +376,6 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
@@ -439,19 +420,15 @@ def load_tokenized_prepared_datasets(
config_dataset=config_dataset,
tokenizer=tokenizer,
cfg=cfg,
d_base_type=d_base_type,
dataset=ds,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
processor=processor,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
if len(datasets) == 1:
dataset = datasets[0]
else:
LOG.info("merging datasets")
dataset = concatenate_datasets(datasets)
LOG.info("merging datasets")
dataset = concatenate_datasets(datasets)
if len(datasets) > 1:
if cfg.shuffle_merged_datasets:
@@ -460,10 +437,9 @@ def load_tokenized_prepared_datasets(
else:
LOG.debug("NOT shuffling merged datasets")
if cfg.sample_packing and not cfg.skip_prepare_dataset:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
@@ -502,14 +478,9 @@ def load_prepare_datasets(
cfg,
default_dataset_prepared_path,
split="train",
processor=None,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer,
cfg,
default_dataset_prepared_path,
split=split,
processor=processor,
tokenizer, cfg, default_dataset_prepared_path, split=split
)
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
@@ -575,7 +546,6 @@ def get_dataset_wrapper(
d_base_type,
dataset,
d_prompt_style=None,
processor=None, # pylint: disable=unused-argument
):
dataset_wrapper = None
dataset_prompter = None
@@ -608,31 +578,13 @@ def get_dataset_wrapper(
dataset,
**ds_kwargs,
)
elif cfg.skip_prepare_dataset:
dataset_wrapper = dataset
elif ds_strategy := config_dataset.type.startswith(
"bradley_terry"
) and bradley_terry_load(
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
):
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy,
dataset,
**ds_kwargs,
)
elif ds_strategy := load(
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
):
if isinstance(ds_strategy, DatasetWrappingStrategy):
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
else:
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy,
dataset,
**ds_kwargs,
)
elif d_base_type == "alpaca":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy(

View File

@@ -16,7 +16,3 @@ def setup_mlflow_env_vars(cfg: DictDefault):
# Enable mlflow if experiment name is present
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
cfg.use_mlflow = True
# Enable logging hf artifacts in mlflow if value is truthy
if cfg.hf_mlflow_log_artifacts is True:
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true"

File diff suppressed because it is too large Load Diff

View File

@@ -133,8 +133,6 @@ class MultipackBatchSampler(BatchSampler):
self.eff_total_used = 0
self.eff_total_slots = 0
self.len_across_ranks = None
def set_epoch(self, epoch: int):
self.epoch = epoch
@@ -197,14 +195,15 @@ class MultipackBatchSampler(BatchSampler):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(0.998 * min(estimates))
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
min_len_batches = reduce_and_broadcast(
lambda: num,
calc_min_len,
)
return min_len_batches
def __len__(self):
if not self.len_across_ranks:
len_batches = self.num_batches()
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks
len_batches = self.num_batches()
return self.gather_len_batches(len_batches)
def _len_est(self):
efficiency = (

View File

@@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.cuda
from accelerate.logging import get_logger
from datasets import disable_caching, enable_caching
from datasets import set_caching_enabled
from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
@@ -87,10 +87,10 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
@contextmanager
def disable_datasets_caching():
try:
disable_caching()
set_caching_enabled(False)
yield
finally:
enable_caching()
set_caching_enabled(True)
def add_position_ids(sample):
@@ -306,11 +306,7 @@ def process_pretraining_datasets_for_packing(
def calculate_total_num_steps(cfg, train_dataset, update=True):
if (
not cfg.total_num_tokens
and not cfg.skip_prepare_dataset
and not cfg.reward_model
):
if not cfg.total_num_tokens:
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
@@ -323,12 +319,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
skip_estimates = cfg.model_config_type == "mamba"
if (
not skip_estimates
and not cfg.total_supervised_tokens
and not cfg.skip_prepare_dataset
and not cfg.reward_model
):
if not skip_estimates and not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
@@ -487,15 +478,13 @@ def prepare_opinionated_env(cfg):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]
else:
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset

View File

@@ -1,197 +0,0 @@
"""
Tests for the chat messages module
"""
import unittest
import pytest
from transformers import AddedToken, AutoTokenizer
from axolotl.core.chat.format.chatml import format_message
from axolotl.core.chat.messages import ChatFormattedChats, Chats
@pytest.fixture(scope="session", name="llama_tokenizer")
def llama_tokenizer_fixture():
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
@pytest.fixture(scope="session", name="chatml_tokenizer")
def llama_tokenizer_w_chatml(llama_tokenizer):
llama_tokenizer.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
llama_tokenizer.add_tokens(
[
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
]
)
return llama_tokenizer
@pytest.fixture(scope="session", name="chat_msgs")
def chat_msgs_fixture():
return {
"conversation": [
{
"role": "system",
"content": [
{"type": "text", "value": "You are a helpful assistant."},
],
},
{
"role": "user",
"content": [
{"type": "text", "value": "What is today's stock price of Apple?"},
],
},
{
"role": "assistant",
"content": [
{
"type": "tool_call",
"value": {
"name": "get_date",
"arguments": {},
},
},
{
"type": "tool_call",
"value": {
"name": "get_stock_price",
"arguments": {"symbol": "AAPL"},
},
},
],
"weight": 1,
},
{
"role": "tool",
"content": [
{
"type": "tool_response",
"value": {
"name": "get_date",
"content": {"date": "2024-09-09"},
},
},
{
"type": "tool_response",
"value": {
"name": "get_stock_price",
"content": {"symbol": "AAPL", "price": 123.45},
},
},
],
},
{
"role": "assistant",
"content": [
{
"type": "text",
"value": "The stock price of Apple is $123.45.\n",
"weight": 0,
},
{
"type": "text",
"value": "<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>",
},
{
"type": "text",
"value": "The stock price of Apple on September 9, 2024 is $123.45.",
},
],
"weight": 1,
},
]
}
class TestMessagesCase:
"""
Test cases for the chat messages module
"""
def test_tool_call_stringify(self, chat_msgs):
chat_msgs_as_obj = Chats(**chat_msgs)
assert '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}' == str(
chat_msgs_as_obj.conversation[2].content[1].value
)
def test_chatml_formatted_wrapper(self, chat_msgs):
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
target_chatml = """<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is today's stock price of Apple?<|im_end|>
<|im_start|>assistant
<tool_call>
{"name": "get_date", "arguments": {}}
</tool_call>
<tool_call>
{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}
</tool_call>
<|im_end|>
<|im_start|>tool
<tool_response>
{"name": "get_date", "content": {"date": "2024-09-09"}}
</tool_response>
<tool_response>
{"name": "get_stock_price", "content": {"symbol": "AAPL", "price": 123.45}}
</tool_response>
<|im_end|>
<|im_start|>assistant
The stock price of Apple is $123.45.
<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>The stock price of Apple on September 9, 2024 is $123.45.<|im_end|>\n"""
assert target_chatml == str(chat_msg_formatted)
def test_chatml_formatting_tool_call(self, chat_msgs):
chat_msgs_as_obj = Chats(**chat_msgs)
target_chatml_turn2 = """<|im_start|>assistant\n<tool_call>\n{"name": "get_date", "arguments": {}}\n</tool_call>\n<tool_call>\n{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}\n</tool_call>\n<|im_end|>\n"""
assert target_chatml_turn2 == str(
format_message(chat_msgs_as_obj.conversation[2])
)
def test_train_labels(self, chatml_tokenizer, chat_msgs):
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
tokenized = chat_msg_formatted.conversation[2].tokenized(chatml_tokenizer)
# fmt: off
target_labels = [
-100, -100, -100, # role
27, 14506, 13735, 397, 5018, 609, 794,
330, 456, 4257, 498, 330, 16774, 794, 4792, 534, 524,
14506, 13735, 397, 27, 14506, 13735, 397, 5018, 609, 794,
330, 456, 31641, 9217, 498, 330, 16774, 794, 5324, 19314,
794, 330, 84016, 43, 96742, 524, 14506, 13735, 397,
128256, # <|im_end|>
-100 # trailing newline
]
# fmt: on
assert tokenized["labels"] == target_labels
def test_train_labels_2(self, chatml_tokenizer, chat_msgs):
# also test if indivudal contents are set not to train
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
tokenized = chat_msg_formatted.conversation[4].tokenized(chatml_tokenizer)
# fmt: off
target_labels = [
-100, -100, -100, # role
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # initial response
27, 78098, 16761, 4113, 3319, 4691, 369, 3432, 596, 5708, 3430,
315, 8325, 13, 1115, 24897, 814, 1101, 4934, 279, 2457,
5343, 304, 279, 2077, 4005, 78098, 16761, 5708, 3430, 315,
8325, 389, 6250, 220, 24, 11, 220, 2366, 19, 374, 400,
4513, 13, 1774, 13,
128256, # <|im_end|>
-100, # trailing newline
]
# fmt: on
assert tokenized["labels"] == target_labels
if __name__ == "__main__":
unittest.main()

View File

@@ -1,155 +0,0 @@
"""
E2E tests for multigpu eval
"""
import logging
import os
import unittest
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
class TestMultiGPUEval(unittest.TestCase):
"""
Test case for MultiGPU Eval Sample Packing
"""
@with_temp_dir
def test_eval_sample_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": True,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
"warmup_steps": 1,
"evals_per_epoch": 2,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_eval(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
"warmup_steps": 1,
"evals_per_epoch": 2,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -19,8 +19,6 @@ from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@pytest.fixture(scope="session", autouse=True)
def download_model():
@@ -348,115 +346,3 @@ class TestMultiGPULlama(unittest.TestCase):
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_ds_zero3_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_ds_zero3_qlora_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -1,12 +1,22 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
from axolotl.monkeypatch.unsloth_ import (
check_cel_is_patchable,
check_self_attn_is_patchable,
)
class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""
def test_is_cel_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_cel_is_patchable(),
"HF transformers loss code has changed and isn't patchable",
)
def test_is_self_attn_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(

View File

@@ -1,95 +0,0 @@
"""Module for testing ModelLoader."""
import shutil
import tempfile
import pytest
import torch
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
@pytest.fixture(name="temp_dir")
def fixture_temp_dir():
temp_dir = tempfile.mkdtemp()
yield temp_dir
shutil.rmtree(temp_dir)
class TestLoadModelUtils:
"""
Testing module testing ModelLoader.
"""
def setup_method(self):
# load config
self.cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"tokenizer_config": "JackFram/llama-68m",
"sequence_len": 1024,
"load_in_8bit": False,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
ModelLoader(
cfg=self.cfg,
tokenizer="",
)
)
@pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
@pytest.mark.parametrize(
"dist_dtype", [torch.bfloat16, torch.float16, torch.float32]
)
@pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False])
def test_convert_embedding_modules_dtype(
self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
):
self.cfg.output_dir = temp_dir
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
self.model_loader.model, _ = load_model(
self.cfg,
self.model_loader.tokenizer,
inference=False,
reference_model=True,
)
self.model_loader.convert_embedding_modules_dtype(
embedding_modules, dist_dtype, before_kbit_train_or_finetune
)
for name, module in self.model_loader.model.named_modules():
if (
"norm" in name
or (before_kbit_train_or_finetune and name.endswith(".gate"))
or (
any(m in name for m in embedding_modules)
and hasattr(module, "weight")
)
):
for _, param in module.named_parameters():
assert param.dtype == dist_dtype

View File

@@ -1,74 +0,0 @@
"""
E2E tests for packed training
"""
import logging
import os
import unittest
from tbparse import SummaryReader
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPackedLlama(unittest.TestCase):
"""
Test case for Packed training of llama models
"""
@with_temp_dir
def test_loss_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"

View File

@@ -1,74 +0,0 @@
"""
E2E tests for reward model lora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestRewardModelLoraLlama(unittest.TestCase):
"""
Test case for Llama reward models using LoRA
"""
@with_temp_dir
def test_rm_fft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"model_type": "AutoModelForSequenceClassification",
"tokenizer_type": "LlamaTokenizer",
"chat_template": "alpaca",
"reward_model": True,
"sequence_len": 1024,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "argilla/distilabel-intel-orca-dpo-pairs",
"type": "bradley_terry.chat_template",
},
],
"remove_unused_columns": False,
"max_steps": 10,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"gradient_checkpointing": True,
"warmup_ratio": 0.1,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

Some files were not shown because too many files have changed in this diff Show More