Compare commits
1 Commits
yayi2
...
hamelsmu-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
856f5f6115 |
7
.github/workflows/base.yml
vendored
7
.github/workflows/base.yml
vendored
@@ -28,12 +28,7 @@ jobs:
|
|||||||
- cuda: "118"
|
- cuda: "118"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.1
|
pytorch: 2.1.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
|
||||||
- cuda: "121"
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.1
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
51
.github/workflows/main.yml
vendored
51
.github/workflows/main.yml
vendored
@@ -27,56 +27,38 @@ jobs:
|
|||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.1
|
pytorch: 2.1.0
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.1
|
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v3
|
||||||
- name: Docker metadata
|
- name: Docker metadata
|
||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v3
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl
|
images: winglian/axolotl
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v2
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
|
- name: Set up Docker Buildx
|
||||||
- name: Build and export to Docker
|
uses: docker/setup-buildx-action@v2
|
||||||
uses: docker/build-push-action@v5
|
- name: Build
|
||||||
|
uses: docker/build-push-action@v4
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
load: true
|
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
file: ./docker/Dockerfile
|
file: ./docker/Dockerfile
|
||||||
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: |
|
tags: |
|
||||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
- name: Unit Tests
|
|
||||||
run: |
|
|
||||||
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
|
||||||
- name: Push to Docker Hub
|
|
||||||
if: github.event_name != 'pull_request'
|
|
||||||
run: |
|
|
||||||
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
|
||||||
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
|
||||||
if [ -n "$latest_tag" ]; then
|
|
||||||
docker push "$latest_tag"
|
|
||||||
fi
|
|
||||||
|
|
||||||
build-axolotl-runpod:
|
build-axolotl-runpod:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||||
@@ -98,31 +80,26 @@ jobs:
|
|||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.1
|
pytorch: 2.1.0
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
|
||||||
cuda_version: 12.1.0
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.1
|
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v3
|
||||||
- name: Docker metadata
|
- name: Docker metadata
|
||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v3
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl-runpod
|
images: winglian/axolotl-runpod
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v2
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v2
|
||||||
- name: Build
|
- name: Build
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v4
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
build-args: |
|
build-args: |
|
||||||
|
|||||||
30
README.md
30
README.md
@@ -102,7 +102,7 @@ pip3 install -e '.[flash-attn,deepspeed]'
|
|||||||
```
|
```
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
```bash
|
```bashtet
|
||||||
# finetune lora
|
# finetune lora
|
||||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||||
|
|
||||||
@@ -520,14 +520,6 @@ model_config:
|
|||||||
type: # linear | dynamic
|
type: # linear | dynamic
|
||||||
factor: # float
|
factor: # float
|
||||||
|
|
||||||
# optional overrides to the bnb 4bit quantization configuration
|
|
||||||
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
|
||||||
bnb_config_kwargs:
|
|
||||||
# These are default values
|
|
||||||
llm_int8_has_fp16_weight: false
|
|
||||||
bnb_4bit_quant_type: nf4
|
|
||||||
bnb_4bit_use_double_quant: true
|
|
||||||
|
|
||||||
|
|
||||||
# Whether you are training a 4-bit GPTQ quantized model
|
# Whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq: true
|
||||||
@@ -589,9 +581,6 @@ datasets:
|
|||||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||||
field:
|
field:
|
||||||
|
|
||||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
|
||||||
# Currently supports chatml and inst (mistral/mixtral)
|
|
||||||
chat_template: chatml
|
|
||||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||||
# subsequent training attempts load faster, relative path
|
# subsequent training attempts load faster, relative path
|
||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
@@ -683,7 +672,6 @@ relora_warmup_steps: # Number of per-restart warmup steps
|
|||||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||||
|
|
||||||
# wandb configuration if you're using it
|
# wandb configuration if you're using it
|
||||||
# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
|
||||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||||
wandb_project: # Your wandb project name
|
wandb_project: # Your wandb project name
|
||||||
wandb_entity: # A wandb Team name if using a Team
|
wandb_entity: # A wandb Team name if using a Team
|
||||||
@@ -810,6 +798,11 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
|||||||
# Whether to use scaled-dot-product attention
|
# Whether to use scaled-dot-product attention
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
sdp_attention:
|
sdp_attention:
|
||||||
|
# Landmark attention (only llama)
|
||||||
|
landmark_attention:
|
||||||
|
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||||
|
# LLaMA only
|
||||||
|
xpos_rope:
|
||||||
|
|
||||||
# Resume from a specific checkpoint dir
|
# Resume from a specific checkpoint dir
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
@@ -976,8 +969,6 @@ fsdp_config:
|
|||||||
|
|
||||||
##### Weights & Biases Logging
|
##### Weights & Biases Logging
|
||||||
|
|
||||||
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
|
||||||
|
|
||||||
- wandb options
|
- wandb options
|
||||||
```yaml
|
```yaml
|
||||||
wandb_mode:
|
wandb_mode:
|
||||||
@@ -1004,12 +995,9 @@ tokens: # these are delimiters
|
|||||||
|
|
||||||
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
|
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
|
||||||
|
|
||||||
### Inference Playground
|
### Inference
|
||||||
|
|
||||||
Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
|
Pass the appropriate flag to the train command:
|
||||||
The config file is the same config file used for training.
|
|
||||||
|
|
||||||
Pass the appropriate flag to the inference command, depending upon what kind of model was trained:
|
|
||||||
|
|
||||||
- Pretrained LORA:
|
- Pretrained LORA:
|
||||||
```bash
|
```bash
|
||||||
@@ -1038,7 +1026,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
|
|||||||
Add below flag to train command above
|
Add below flag to train command above
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model"
|
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||||
```
|
```
|
||||||
|
|
||||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||||
|
|||||||
@@ -1,47 +0,0 @@
|
|||||||
{
|
|
||||||
"zero_optimization": {
|
|
||||||
"stage": 3,
|
|
||||||
"offload_optimizer": {
|
|
||||||
"device": "cpu",
|
|
||||||
"pin_memory": true
|
|
||||||
},
|
|
||||||
"offload_param": {
|
|
||||||
"device": "cpu",
|
|
||||||
"pin_memory": true
|
|
||||||
},
|
|
||||||
"overlap_comm": true,
|
|
||||||
"contiguous_gradients": true,
|
|
||||||
"sub_group_size": 0,
|
|
||||||
"reduce_bucket_size": "auto",
|
|
||||||
"stage3_prefetch_bucket_size": "auto",
|
|
||||||
"stage3_param_persistence_threshold": "auto",
|
|
||||||
"stage3_max_live_parameters": 0,
|
|
||||||
"stage3_max_reuse_distance": 0,
|
|
||||||
"stage3_gather_16bit_weights_on_model_save": true
|
|
||||||
},
|
|
||||||
"bf16": {
|
|
||||||
"enabled": "auto"
|
|
||||||
},
|
|
||||||
"fp16": {
|
|
||||||
"enabled": "auto",
|
|
||||||
"auto_cast": false,
|
|
||||||
"loss_scale": 0,
|
|
||||||
"initial_scale_power": 32,
|
|
||||||
"loss_scale_window": 1000,
|
|
||||||
"hysteresis": 2,
|
|
||||||
"min_loss_scale": 1
|
|
||||||
},
|
|
||||||
"optimizer": {
|
|
||||||
"type": "AdamW",
|
|
||||||
"params": {
|
|
||||||
"lr": "auto",
|
|
||||||
"betas": "auto",
|
|
||||||
"eps": "auto",
|
|
||||||
"weight_decay": "auto"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
|
||||||
"train_batch_size": "auto",
|
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
|
||||||
"wall_clock_breakdown": false
|
|
||||||
}
|
|
||||||
@@ -19,15 +19,13 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
|||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
|
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn]; \
|
pip install -e .[deepspeed,flash-attn]; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
|
||||||
RUN pip install pytest
|
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
git config --get remote.origin.fetch
|
git config --get remote.origin.fetch
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ output_dir: ./out
|
|||||||
sequence_len: 8192
|
sequence_len: 8192
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
eval_sample_packing: false
|
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
|
|||||||
@@ -23,9 +23,6 @@ unfrozen_parameters:
|
|||||||
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
||||||
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
||||||
|
|
||||||
model_config:
|
|
||||||
output_router_logits: true
|
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.1
|
val_set_size: 0.05
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
|
|||||||
@@ -1,64 +0,0 @@
|
|||||||
base_model: models/yayi2-30b
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
is_mistral_derived_model: false
|
|
||||||
trust_remote_code: true
|
|
||||||
model_revision: refs/pr/5
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.05
|
|
||||||
output_dir: ./out
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
eval_sample_packing: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.000005
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16: false
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 4
|
|
||||||
eval_table_size:
|
|
||||||
eval_table_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed: deepspeed/zero3_cpu.json
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
bos_token: "<s>"
|
|
||||||
eos_token: "</s>"
|
|
||||||
unk_token: "<unk>"
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
base_model: wenge-research/yayi2-30b
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
is_mistral_derived_model: false
|
|
||||||
trust_remote_code: true
|
|
||||||
model_revision: refs/pr/5
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./qlora-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048 # Fits in 40gb VRAM. Can easily do 4096 in A100 80 or a A6000
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
lora_target_modules:
|
|
||||||
|
|
||||||
wandb_project: yayi2
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0005
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16: false
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: false
|
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
|
||||||
loss_watchdog_patience: 3
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
evals_per_epoch: 4
|
|
||||||
eval_table_size:
|
|
||||||
eval_table_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
bos_token: "<s>"
|
|
||||||
eos_token: "</s>"
|
|
||||||
unk_token: "<unk>"
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
# Overview
|
|
||||||
|
|
||||||
This is an example of a Yi-34B-Chat configuration. It demonstrates that it is possible to finetune a 34B model on a GPU with 24GB of VRAM.
|
|
||||||
|
|
||||||
Tested on an RTX 4090 with `python -m axolotl.cli.train examples/mistral/qlora.yml`, a single epoch of finetuning on the alpaca dataset using qlora runs in 47 mins, using 97% of available memory.
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
base_model: 01-ai/Yi-34B-Chat
|
|
||||||
model_type: LlamaForCausalLM
|
|
||||||
tokenizer_type: LlamaTokenizer
|
|
||||||
is_mistral_derived_model: false
|
|
||||||
is_llama_derived_model: true
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
sequence_len: 1024
|
|
||||||
bf16: true
|
|
||||||
fp16: false
|
|
||||||
tf32: false
|
|
||||||
flash_attention: true
|
|
||||||
special_tokens:
|
|
||||||
bos_token: "<|startoftext|>"
|
|
||||||
eos_token: "<|endoftext|>"
|
|
||||||
unk_token: "<unk>"
|
|
||||||
|
|
||||||
# Data
|
|
||||||
datasets:
|
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
|
||||||
type: alpaca
|
|
||||||
warmup_steps: 10
|
|
||||||
|
|
||||||
# Iterations
|
|
||||||
num_epochs: 1
|
|
||||||
|
|
||||||
# Evaluation
|
|
||||||
val_set_size: 0.1
|
|
||||||
evals_per_epoch: 5
|
|
||||||
eval_table_size:
|
|
||||||
eval_table_max_new_tokens: 128
|
|
||||||
eval_sample_packing: false
|
|
||||||
eval_batch_size: 1
|
|
||||||
|
|
||||||
# LoRA
|
|
||||||
output_dir: ./qlora-out
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
lora_target_modules:
|
|
||||||
|
|
||||||
# Sampling
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
# Batching
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
gradient_checkpointing: true
|
|
||||||
|
|
||||||
# wandb
|
|
||||||
wandb_project:
|
|
||||||
|
|
||||||
# Optimizer
|
|
||||||
optimizer: paged_adamw_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
# Misc
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
auto-gptq==0.5.1
|
auto-gptq==0.5.1
|
||||||
packaging
|
packaging
|
||||||
peft==0.6.0
|
peft==0.6.0
|
||||||
transformers==4.36.2
|
transformers @ git+https://github.com/huggingface/transformers.git@ebfdb9ca62205279d5019ef1403877461b3b2da4
|
||||||
tokenizers==0.15.0
|
tokenizers==0.15.0
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate==0.24.1
|
accelerate==0.24.1
|
||||||
|
|||||||
15
setup.py
15
setup.py
@@ -1,7 +1,5 @@
|
|||||||
"""setup.py for axolotl"""
|
"""setup.py for axolotl"""
|
||||||
|
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
@@ -24,13 +22,12 @@ def parse_requirements():
|
|||||||
# Handle standard packages
|
# Handle standard packages
|
||||||
_install_requires.append(line)
|
_install_requires.append(line)
|
||||||
|
|
||||||
try:
|
# TODO(wing) remove once xformers release supports torch 2.1.0
|
||||||
torch_version = version("torch")
|
if "torch==2.1.0" in _install_requires:
|
||||||
if torch_version.startswith("2.1.1"):
|
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
_install_requires.append(
|
||||||
_install_requires.append("xformers==0.0.23")
|
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
|
||||||
except PackageNotFoundError:
|
)
|
||||||
pass
|
|
||||||
|
|
||||||
return _install_requires, _dependency_links
|
return _install_requires, _dependency_links
|
||||||
|
|
||||||
|
|||||||
@@ -103,7 +103,15 @@ def do_inference(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
if cfg.landmark_attention:
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||||
|
|
||||||
|
set_model_mem_id(model, tokenizer)
|
||||||
|
model.set_mem_cache_args(
|
||||||
|
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||||
|
)
|
||||||
|
|
||||||
|
model = model.to(cfg.device)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -168,7 +176,15 @@ def do_inference_gradio(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
if cfg.landmark_attention:
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||||
|
|
||||||
|
set_model_mem_id(model, tokenizer)
|
||||||
|
model.set_mem_cache_args(
|
||||||
|
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||||
|
)
|
||||||
|
|
||||||
|
model = model.to(cfg.device)
|
||||||
|
|
||||||
def generate(instruction):
|
def generate(instruction):
|
||||||
if not instruction:
|
if not instruction:
|
||||||
|
|||||||
@@ -18,15 +18,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
parsed_cli_args.merge_lora = True
|
parsed_cli_args.merge_lora = True
|
||||||
|
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
|
||||||
parsed_cfg = load_cfg(
|
|
||||||
config,
|
|
||||||
merge_lora=True,
|
|
||||||
load_in_8bit=False,
|
|
||||||
load_in_4bit=False,
|
|
||||||
flash_attention=False,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import math
|
|||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import wraps
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -120,7 +120,6 @@ class AxolotlTrainer(Trainer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
args = None # type: AxolotlTrainingArguments
|
args = None # type: AxolotlTrainingArguments
|
||||||
tag_names = ["axolotl"]
|
|
||||||
|
|
||||||
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
||||||
self.num_epochs = num_epochs
|
self.num_epochs = num_epochs
|
||||||
@@ -291,41 +290,12 @@ class AxolotlTrainer(Trainer):
|
|||||||
# return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
|
|
||||||
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
|
|
||||||
if isinstance(tag_names, str):
|
|
||||||
tag_names = [tag_names]
|
|
||||||
|
|
||||||
if kwargs is not None:
|
|
||||||
if "tags" not in kwargs:
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
|
||||||
kwargs["tags"].extend(tag_names)
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
|
||||||
tag_names.append(kwargs["tags"])
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
@wraps(Trainer.push_to_hub)
|
|
||||||
def push_to_hub(self, *args, **kwargs) -> str:
|
|
||||||
"""
|
|
||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
|
||||||
"""
|
|
||||||
kwargs = self._sanitize_kwargs_for_tagging(
|
|
||||||
tag_names=self.tag_names, kwargs=kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
Mamba specific trainer to handle loss calculation
|
Mamba specific trainer to handle loss calculation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "mamba"]
|
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
@@ -352,8 +322,6 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
|||||||
Trainer subclass that uses the OneCycleLR scheduler
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "onecycle"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
@@ -383,8 +351,6 @@ class ReLoRATrainer(AxolotlTrainer):
|
|||||||
Trainer subclass that uses the OneCycleLR scheduler
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "relora"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
@@ -780,6 +746,26 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||||
|
|
||||||
|
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||||
|
add_mem_tokens,
|
||||||
|
get_mem_id,
|
||||||
|
set_model_mem_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
set_model_mem_id(self.model, self.tokenizer)
|
||||||
|
|
||||||
|
LOG.info("Adding landmark attention tokens to dataset")
|
||||||
|
|
||||||
|
for dataset in [self.train_dataset, self.eval_dataset]:
|
||||||
|
dataset = dataset.map(
|
||||||
|
partial(
|
||||||
|
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
|
||||||
|
),
|
||||||
|
batched=False,
|
||||||
|
num_proc=32,
|
||||||
|
)
|
||||||
|
|
||||||
trainer_cls = self._get_trainer_cls()
|
trainer_cls = self._get_trainer_cls()
|
||||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||||
trainer_kwargs, trainer_cls
|
trainer_kwargs, trainer_cls
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role + ":", ""
|
yield role + ":", ""
|
||||||
return
|
return
|
||||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
|
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
if self.messages:
|
if self.messages:
|
||||||
# For llama, the system message is incorporated into the first human instruction
|
# For llama, the system message is incorporated into the first human instruction
|
||||||
@@ -101,28 +101,6 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
|
|
||||||
contains_sys_msg = False
|
|
||||||
if self.system_message:
|
|
||||||
contains_sys_msg = True
|
|
||||||
if self.messages:
|
|
||||||
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline
|
|
||||||
first_role, first_msg = self.messages[0]
|
|
||||||
if first_role == self.roles[0]:
|
|
||||||
system_prompt = self.system_template.format(
|
|
||||||
system_message=" " + self.system_message
|
|
||||||
)
|
|
||||||
system_prompt += first_msg
|
|
||||||
self.messages.pop(0)
|
|
||||||
yield "", system_prompt
|
|
||||||
for i, (role, message) in enumerate(self.messages):
|
|
||||||
if message and i == 0 and not contains_sys_msg:
|
|
||||||
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
|
|
||||||
elif message:
|
|
||||||
yield role + " ", message
|
|
||||||
else:
|
|
||||||
yield role, ""
|
|
||||||
return
|
|
||||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||||
|
|||||||
1249
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
1249
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
94
src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py
Normal file
94
src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
# pylint: skip-file
|
||||||
|
"""
|
||||||
|
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
class XposRotaryEmbedding(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
base=10000,
|
||||||
|
device=None,
|
||||||
|
scale_base=2048,
|
||||||
|
use_xpos=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.max_seq_len_cached = max_position_embeddings
|
||||||
|
self.scale_base = scale_base
|
||||||
|
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
|
||||||
|
freqs = torch.einsum("i , j -> i j", t, inv_freq)
|
||||||
|
freqs = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
self.register_buffer("freqs_cached", freqs, persistent=False)
|
||||||
|
|
||||||
|
if not use_xpos:
|
||||||
|
self.register_buffer("scale", None)
|
||||||
|
self.register_buffer("scale_cached", torch.ones(1))
|
||||||
|
return
|
||||||
|
|
||||||
|
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
||||||
|
power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
|
||||||
|
scale_cached = scale ** rearrange(power, "n -> n 1")
|
||||||
|
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
|
||||||
|
|
||||||
|
self.register_buffer("scale", scale, persistent=False)
|
||||||
|
self.register_buffer("scale_cached", scale_cached, persistent=False)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
seq_len,
|
||||||
|
):
|
||||||
|
if seq_len > self.max_seq_len_cached:
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
|
||||||
|
self.inv_freq
|
||||||
|
)
|
||||||
|
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
|
||||||
|
freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
|
||||||
|
|
||||||
|
self.register_buffer("freqs_cached", freqs)
|
||||||
|
|
||||||
|
if self.scale is None:
|
||||||
|
self.register_buffer(
|
||||||
|
"scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
|
||||||
|
|
||||||
|
power = (t - (seq_len // 2)) / self.scale_base
|
||||||
|
scale = self.scale ** rearrange(power, "n -> n 1")
|
||||||
|
scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
|
||||||
|
self.register_buffer("scale_cached", scale)
|
||||||
|
|
||||||
|
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
x1, x2 = x.chunk(2, dim=-1)
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
|
||||||
|
freqs = freqs[position_ids, :]
|
||||||
|
if scale.shape[-1] != 1:
|
||||||
|
scale = scale[position_ids, :]
|
||||||
|
|
||||||
|
q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
|
||||||
|
k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
|
||||||
|
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
def replace_llama_rope_with_xpos_rope():
|
||||||
|
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
|
||||||
|
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||||
@@ -39,23 +39,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
return strategy
|
return strategy
|
||||||
|
|
||||||
|
|
||||||
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
||||||
conversation = (
|
|
||||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
|
||||||
)
|
|
||||||
strategy = UltrachatShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(
|
|
||||||
conversation=conversation,
|
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
if ds_cfg and "strict" in ds_cfg:
|
|
||||||
strategy.strict = ds_cfg["strict"]
|
|
||||||
return strategy
|
|
||||||
|
|
||||||
|
|
||||||
def load_role(tokenizer, cfg):
|
def load_role(tokenizer, cfg):
|
||||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompterV2(),
|
ShareGPTPrompterV2(),
|
||||||
@@ -126,17 +109,3 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|||||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||||
]
|
]
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
|
||||||
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
|
||||||
"""
|
|
||||||
sharegpt strategy that remaps ultrachat data to sharegpt format
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
|
||||||
conversations = prompt["messages"]
|
|
||||||
role_map = {"user": "human", "assistant": "gpt"}
|
|
||||||
turns = [
|
|
||||||
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
|
|
||||||
]
|
|
||||||
return turns
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import transformers.modelcard
|
|||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from pkg_resources import get_distribution # type: ignore
|
|
||||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
@@ -116,12 +115,6 @@ def train(
|
|||||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
||||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||||
|
|
||||||
if getattr(cfg, "axolotl_config_path"):
|
|
||||||
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
|
|
||||||
version = get_distribution("axolotl").version
|
|
||||||
if raw_axolotl_cfg.is_file():
|
|
||||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
|
|
||||||
|
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from shutil import copyfile
|
|
||||||
from tempfile import NamedTemporaryFile
|
|
||||||
from typing import TYPE_CHECKING, Dict, List
|
from typing import TYPE_CHECKING, Dict, List
|
||||||
|
|
||||||
import evaluate
|
import evaluate
|
||||||
@@ -563,15 +561,10 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
):
|
):
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
try:
|
try:
|
||||||
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
|
artifact = wandb.Artifact(name="axolotl-config", type="config")
|
||||||
with NamedTemporaryFile(
|
artifact.add_file(local_path=self.axolotl_config_path)
|
||||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
wandb.run.log_artifact(artifact)
|
||||||
) as temp_file:
|
LOG.info("Axolotl config has been saved to WandB as an artifact.")
|
||||||
copyfile(self.axolotl_config_path, temp_file.name)
|
|
||||||
wandb.save(temp_file.name)
|
|
||||||
LOG.info(
|
|
||||||
"The Axolotl config has been saved to the WandB run under files."
|
|
||||||
)
|
|
||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||||
return control
|
return control
|
||||||
|
|||||||
@@ -1,29 +0,0 @@
|
|||||||
"""
|
|
||||||
This module provides functionality for selecting chat templates based on user choices.
|
|
||||||
These templates are used for formatting messages in a conversation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def chat_templates(user_choice: str):
|
|
||||||
"""
|
|
||||||
Finds the correct chat_template for the tokenizer_config.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_choice (str): The user's choice of template.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The chosen template string.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the user_choice is not found in the templates.
|
|
||||||
"""
|
|
||||||
|
|
||||||
templates = {
|
|
||||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
|
||||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
|
||||||
}
|
|
||||||
|
|
||||||
if user_choice in templates:
|
|
||||||
return templates[user_choice]
|
|
||||||
|
|
||||||
raise ValueError(f"Template '{user_choice}' not found.")
|
|
||||||
@@ -448,20 +448,6 @@ def validate_config(cfg):
|
|||||||
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||||
raise ValueError("neftune_noise_alpha must be > 0.0")
|
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||||
|
|
||||||
if (
|
|
||||||
cfg.adapter
|
|
||||||
and cfg.tokens
|
|
||||||
and (
|
|
||||||
not cfg.lora_modules_to_save
|
|
||||||
or not all(
|
|
||||||
x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
|
|||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
@@ -137,23 +136,6 @@ def load_tokenizer(cfg):
|
|||||||
|
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
for k, val in cfg.special_tokens.items():
|
for k, val in cfg.special_tokens.items():
|
||||||
# check if new special token is not already in tokenizer and
|
|
||||||
# is adapter training to make sure lora_modules_to_save is set
|
|
||||||
if (
|
|
||||||
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
|
||||||
and cfg.adapter
|
|
||||||
and (
|
|
||||||
not cfg.lora_modules_to_save
|
|
||||||
or not all(
|
|
||||||
x in cfg.lora_modules_to_save
|
|
||||||
for x in ["embed_tokens", "lm_head"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer.add_special_tokens(
|
tokenizer.add_special_tokens(
|
||||||
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||||
)
|
)
|
||||||
@@ -187,12 +169,6 @@ def load_tokenizer(cfg):
|
|||||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||||
|
|
||||||
if cfg.chat_template:
|
|
||||||
tokenizer.chat_template = chat_templates(cfg.chat_template)
|
|
||||||
else:
|
|
||||||
LOG.info(
|
|
||||||
"No Chat template selected. Consider adding a chat template for easier inference."
|
|
||||||
)
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@@ -254,6 +230,17 @@ def load_model(
|
|||||||
|
|
||||||
LOG.info("patching with sdp attention")
|
LOG.info("patching with sdp attention")
|
||||||
hijack_llama_sdp_attention()
|
hijack_llama_sdp_attention()
|
||||||
|
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||||
|
MEM_TOKEN,
|
||||||
|
patch_llama_with_landmark_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("patching with landmark attention")
|
||||||
|
patch_llama_with_landmark_attn()
|
||||||
|
|
||||||
|
# Note: This might overwrite previous additional_special_tokens
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
||||||
|
|
||||||
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
||||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||||
@@ -275,6 +262,14 @@ def load_model(
|
|||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with flash attention")
|
||||||
replace_mixtral_attn_with_multipack_flash_attn()
|
replace_mixtral_attn_with_multipack_flash_attn()
|
||||||
|
|
||||||
|
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||||
|
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||||
|
replace_llama_rope_with_xpos_rope,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("patching with xpos rope")
|
||||||
|
replace_llama_rope_with_xpos_rope()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.is_llama_derived_model
|
cfg.is_llama_derived_model
|
||||||
and (cfg.max_packed_sequence_len or cfg.sample_packing)
|
and (cfg.max_packed_sequence_len or cfg.sample_packing)
|
||||||
@@ -308,20 +303,13 @@ def load_model(
|
|||||||
**model_config.quantization_config
|
**model_config.quantization_config
|
||||||
)
|
)
|
||||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||||
bnb_config = {
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"llm_int8_threshold": 6.0,
|
|
||||||
"llm_int8_has_fp16_weight": False,
|
|
||||||
"bnb_4bit_compute_dtype": cfg.torch_dtype,
|
|
||||||
"bnb_4bit_use_double_quant": True,
|
|
||||||
"bnb_4bit_quant_type": "nf4",
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.bnb_config_kwargs:
|
|
||||||
bnb_config.update(cfg.bnb_config_kwargs)
|
|
||||||
|
|
||||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
load_in_4bit=True,
|
||||||
|
llm_int8_threshold=6.0,
|
||||||
|
llm_int8_has_fp16_weight=False,
|
||||||
|
bnb_4bit_compute_dtype=cfg.torch_dtype,
|
||||||
|
bnb_4bit_use_double_quant=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
# sample packing uses custom FA2 patch
|
# sample packing uses custom FA2 patch
|
||||||
if cfg.flash_attention:
|
if cfg.flash_attention:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import unittest
|
import unittest
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -26,50 +25,6 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
|||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
test_data = {
|
|
||||||
"multi_turn_sys": {
|
|
||||||
"conversations": [
|
|
||||||
{"from": "system", "value": "lorem"},
|
|
||||||
{"from": "human", "value": "abc"},
|
|
||||||
{"from": "gpt", "value": "ipsum"},
|
|
||||||
{"from": "human", "value": "123"},
|
|
||||||
{"from": "gpt", "value": "sit"},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"single_turn_sys": {
|
|
||||||
"conversations": [
|
|
||||||
{"from": "system", "value": "lorem"},
|
|
||||||
{"from": "human", "value": "abc"},
|
|
||||||
{"from": "gpt", "value": "ipsum"},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"single_turn_no_sys": {
|
|
||||||
"conversations": [
|
|
||||||
{"from": "human", "value": "abc"},
|
|
||||||
{"from": "gpt", "value": "ipsum"},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"multi_turn_no_sys": {
|
|
||||||
"conversations": [
|
|
||||||
{"from": "human", "value": "abc"},
|
|
||||||
{"from": "gpt", "value": "ipsum"},
|
|
||||||
{"from": "human", "value": "123"},
|
|
||||||
{"from": "gpt", "value": "sit"},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_strat(conversation, tokenizer):
|
|
||||||
"Helper function to create a prompt strategy for testing."
|
|
||||||
prompter = ShareGPTPrompterV2(conversation=conversation)
|
|
||||||
return ShareGPTPromptTokenizingStrategy(
|
|
||||||
prompter,
|
|
||||||
tokenizer,
|
|
||||||
False,
|
|
||||||
2048,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestPromptTokenizationStrategies(unittest.TestCase):
|
class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
@@ -161,68 +116,74 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
|
|
||||||
def test_sharegpt_llama(self):
|
def test_sharegpt_llama(self):
|
||||||
"Make sure the sharegpt/llama is tokenized and formatted correctly."
|
"Make sure the sharegpt/llama is tokenized and formatted correctly."
|
||||||
strat = prompt_strat("llama-2", self.tokenizer)
|
prompter = ShareGPTPrompterV2(conversation="llama-2")
|
||||||
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
|
prompter,
|
||||||
|
self.tokenizer,
|
||||||
|
False,
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
|
||||||
def tokenize(conv):
|
def tokenize(conv):
|
||||||
return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
|
return strat.tokenize_prompt(conv)["input_ids"]
|
||||||
|
|
||||||
def decode(ids):
|
def decode(ids):
|
||||||
return strat.tokenizer.decode(ids)
|
return strat.tokenizer.decode(ids)
|
||||||
|
|
||||||
|
# Multi-turn conversations
|
||||||
|
multi_turn_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "123"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
]
|
||||||
|
}
|
||||||
# fmt: off
|
# fmt: off
|
||||||
# System message, multi-turn conversations
|
mt_ids = tokenize(multi_turn_conv)
|
||||||
mt_ids = tokenize(test_data['multi_turn_sys'])
|
|
||||||
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||||
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
|
||||||
# System message, single-turn conversations
|
# Single-turn conversations
|
||||||
st_ids = tokenize(test_data['single_turn_sys'])
|
single_turn_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
st_ids = tokenize(single_turn_conv)
|
||||||
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
|
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
|
||||||
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
# No system message, single-turn
|
# No system message, single-turn
|
||||||
ns_ids = tokenize(test_data['single_turn_no_sys'])
|
no_sys_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
ns_ids = tokenize(no_sys_conv)
|
||||||
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
||||||
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
# No system message, multi-turn
|
# No system message, multi-turn
|
||||||
ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
|
no_sys_mt_conv = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "123"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
ns_mt_ids = tokenize(no_sys_mt_conv)
|
||||||
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||||
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def test_sharegpt_mistral(self):
|
|
||||||
"Make sure the sharegpt/mistral is tokenized and formatted correctly."
|
|
||||||
strat = prompt_strat("mistral", self.tokenizer)
|
|
||||||
|
|
||||||
def tokenize(conv):
|
|
||||||
return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
|
|
||||||
|
|
||||||
def decode(ids):
|
|
||||||
return strat.tokenizer.decode(ids)
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
# System message, multi-turn conversations
|
|
||||||
mt_ids = tokenize(test_data['multi_turn_sys'])
|
|
||||||
assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
|
|
||||||
assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
|
||||||
|
|
||||||
# System message, single-turn conversations
|
|
||||||
st_ids = tokenize(test_data['single_turn_sys'])
|
|
||||||
assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
|
|
||||||
assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
|
||||||
|
|
||||||
# No system message, single-turn
|
|
||||||
ns_ids = tokenize(test_data['single_turn_no_sys'])
|
|
||||||
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
|
||||||
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
|
||||||
|
|
||||||
# No system message, multi-turn
|
|
||||||
ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
|
|
||||||
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
|
|
||||||
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
def test_sharegpt_changes_roles(self):
|
def test_sharegpt_changes_roles(self):
|
||||||
conversation = {
|
conversation = {
|
||||||
"roles": ["USER", "CHARACTER"],
|
"roles": ["USER", "CHARACTER"],
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ Test cases for the tokenizer loading
|
|||||||
"""
|
"""
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
|
|
||||||
@@ -33,40 +31,6 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert "Fast" not in tokenizer.__class__.__name__
|
assert "Fast" not in tokenizer.__class__.__name__
|
||||||
|
|
||||||
def test_special_tokens_modules_to_save(self):
|
|
||||||
# setting special_tokens to new token
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
|
||||||
"adapter": "lora",
|
|
||||||
"special_tokens": {"bos_token": "[INST]"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=r".*Please set lora_modules_to_save*",
|
|
||||||
):
|
|
||||||
load_tokenizer(cfg)
|
|
||||||
|
|
||||||
# setting special_tokens but not changing from default
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
|
||||||
"adapter": "lora",
|
|
||||||
"special_tokens": {"bos_token": "<s>"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
load_tokenizer(cfg)
|
|
||||||
|
|
||||||
# non-adapter setting special_tokens
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
|
||||||
"special_tokens": {"bos_token": "[INST]"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
load_tokenizer(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -682,43 +682,6 @@ class ValidationTest(unittest.TestCase):
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
def test_add_tokens_adapter(self):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
|
|
||||||
):
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"adapter": "qlora",
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"tokens": ["<|imstart|>"],
|
|
||||||
"lora_modules_to_save": ["embed_tokens"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
|
|
||||||
):
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"adapter": "qlora",
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"tokens": ["<|imstart|>"],
|
|
||||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationWandbTest(ValidationTest):
|
class ValidationWandbTest(ValidationTest):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user