Compare commits

yayi2...NanoCode01 (17 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 7ecc3a408c |  |
|  | 9ca358b671 |  |
|  | 553c80f79a |  |
|  | eb4c99431b |  |
|  | cbdbf9e6e5 |  |
|  | bdfefaf054 |  |
|  | 63fb3eb426 |  |
|  | 31d23504a5 |  |
|  | f243c2186d |  |
|  | 59b2d302c8 |  |
|  | bcc78d8fa3 |  |
|  | 74532ddc45 |  |
|  | 8ba27f3bde |  |
|  | a3e8783328 |  |
|  | b31038aae9 |  |
|  | c75f916745 |  |
|  | 4d2e842e46 |  |
.github/workflows/tests-docker.yml (new file, +46)

```diff
@@ -0,0 +1,46 @@
+name: e2e-docker-tests
+
+on:
+  pull_request:
+    paths:
+      - '**.py'
+      - 'requirements.txt'
+      - '.github/workflows/*.yml'
+  workflow_dispatch:
+
+jobs:
+  build-axolotl:
+    if: github.repository_owner == 'OpenAccess-AI-Collective'
+    # this job needs to be run on self-hosted GPU runners...
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - cuda: 118
+            cuda_version: 11.8.0
+            python_version: "3.10"
+            pytorch: 2.0.1
+          - cuda: 121
+            cuda_version: 12.1.0
+            python_version: "3.10"
+            pytorch: 2.1.1
+    runs-on: [self-hosted, gpu, docker]
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Build Docker image
+        run: |
+          # Set up build arguments
+          BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
+          CUDA="${{ matrix.cuda }}"
+          PYTORCH_VERSION="${{ matrix.pytorch }}"
+          # Build the Docker image
+          docker build . \
+            --file ./docker/Dockerfile \
+            --build-arg BASE_TAG=$BASE_TAG \
+            --build-arg CUDA=$CUDA \
+            --build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
+            --tag test-axolotl
+      - name: Unit Tests w docker image
+        run: |
+          docker run --rm test-axolotl pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
```
README.md (28 changes)

````diff
@@ -550,6 +550,11 @@ tf32: true # require >=ampere
 bfloat16: true # require >=ampere
 float16: true
 
+# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
+gpu_memory_limit: 20GiB
+# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
+lora_on_cpu: true
+
 # A list of one or more datasets to finetune the model with
 datasets:
   # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
@@ -643,7 +648,8 @@ max_memory:
 # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
 adapter: lora
 # If you already have a lora model trained that you want to load, put that here.
-# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
+# This means after training, if you want to test the model, you should set this to the value of `output_dir`.
+# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`.
 lora_model_dir:
 
 # LoRA hyperparameters
@@ -670,10 +676,6 @@ lora_modules_to_save:
 # - embed_tokens
 # - lm_head
 
-# Once you complete training, the model will be saved to the following directory.
-# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
-# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
-lora_out_dir:
 lora_fan_in_fan_out: false
 
 # ReLoRA configuration
@@ -741,6 +743,9 @@ group_by_length: false
 
 # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
 gradient_checkpointing: false
+# additional kwargs to pass to the trainer for gradient checkpointing
+# gradient_checkpointing_kwargs:
+#   use_reentrant: false
 
 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
@@ -932,8 +937,9 @@ accelerate launch -m axolotl.cli.train your_config.yml
 You can optionally pre-tokenize dataset with the following before finetuning.
 This is recommended for large datasets.
 
-- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
-- Use `--debug` to see preprocessed examples.
+- Set `dataset_prepared_path:` to a local folder for saving and loading pre-tokenized dataset.
+- (Optional): Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
+- (Optional): Use `--debug` to see preprocessed examples.
 
 ```bash
 python -m axolotl.cli.preprocess your_config.yml
@@ -1035,18 +1041,20 @@ Please use `--sample_packing False` if you have it on and receive the error simi
 
 ### Merge LORA to base
 
-Add below flag to train command above
+The following command will merge your LORA adapter with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwise this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
 
 ```bash
-python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model"
+python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
 ```
 
-If you run out of CUDA memory, you can try to merge in system RAM with
+You may need to use the `gpu_memory_limit` and/or `lora_on_cpu` config options to avoid running out of memory. If you still run out of CUDA memory, you can try to merge in system RAM with
 
 ```bash
 CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
 ```
 
+although this will be very slow, and using the config options above is recommended instead.
+
 ## Common Errors 🧰
 
 See also the [FAQ's](./docs/faq.md).
````
docs/rlhf.md (new file, +35)

````diff
@@ -0,0 +1,35 @@
+# RLHF (Beta)
+
+### Overview
+
+Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
+feedback. Various methods include, but are not limited to:
+
+- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
+- Direct Preference Optimization (DPO)
+- Identity Preference Optimization (IPO)
+
+
+### RLHF using Axolotl
+
+[!IMPORTANT]
+This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
+
+The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples of how you can use various preference datasets to train models that use ChatML.
+
+#### DPO
+```yaml
+rl: true
+datasets:
+  - path: Intel/orca_dpo_pairs
+    split: train
+    type: intel_apply_chatml
+  - path: argilla/ultrafeedback-binarized-preferences
+    split: train
+    type: argilla_apply_chatml
+```
+
+#### IPO
+```yaml
+rl: ipo
+```
````
examples/tiny-llama/README.md (new file, +17)

````diff
@@ -0,0 +1,17 @@
+# Overview
+
+This is a simple example of how to finetune TinyLlama 1.1B using either LoRA or QLoRA:
+
+LoRA:
+
+```
+accelerate launch -m axolotl.cli.train examples/tiny-llama/lora.yml
+```
+
+QLoRA:
+
+```
+accelerate launch -m axolotl.cli.train examples/tiny-llama/qlora.yml
+```
+
+Both take about 10 minutes to complete on a 4090.
````
examples/tiny-llama/lora.yml

```diff
@@ -1,5 +1,4 @@
-base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
-
+base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 is_llama_derived_model: true
@@ -17,6 +16,7 @@ output_dir: ./lora-out
 
 sequence_len: 4096
 sample_packing: true
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:
@@ -55,7 +55,6 @@ flash_attention: true
 
 warmup_steps: 10
 evals_per_epoch: 4
-eval_table_size:
 saves_per_epoch: 1
 debug:
 deepspeed:
@@ -63,6 +62,3 @@ weight_decay: 0.0
 fsdp:
 fsdp_config:
 special_tokens:
-  bos_token: "<s>"
-  eos_token: "</s>"
-  unk_token: "<unk>"
```
examples/tiny-llama/pretrain.yml (new file, +58)

```diff
@@ -0,0 +1,58 @@
+base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+max_steps: 200
+pretraining_dataset:
+  path: c4
+  name: en
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./model-out
+
+sequence_len: 2048
+sample_packing: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch:
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
```
examples/tiny-llama/qlora.yml (new file, +66)

```diff
@@ -0,0 +1,66 @@
+base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./qlora-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: paged_adamw_32bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch: 4
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
```
requirements.txt

```diff
@@ -2,7 +2,7 @@
 auto-gptq==0.5.1
 packaging
 peft==0.6.0
-transformers==4.36.2
+transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.24.1
@@ -37,3 +37,5 @@ tensorboard
 s3fs
 gcsfs
 # adlfs
+
+trl @ git+https://github.com/huggingface/trl.git@main
```
src/axolotl/cli/__init__.py

```diff
@@ -2,6 +2,7 @@
 
 import importlib
 import logging
+import math
 import os
 import random
 import sys
@@ -16,6 +17,7 @@ import yaml
 # add src to the pythonpath so we don't need to pip install this
 from accelerate.commands.config import config_args
 from art import text2art
+from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
 from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
@@ -71,7 +73,7 @@ def do_merge_lora(
     safe_serialization = cfg.save_safetensors is True
 
     LOG.info("running merge of LoRA with base model")
-    model = model.merge_and_unload()
+    model = model.merge_and_unload(progressbar=True)
     model.to(dtype=cfg.torch_dtype)
 
     if cfg.local_rank == 0:
@@ -79,6 +81,7 @@ def do_merge_lora(
         model.save_pretrained(
             str(Path(cfg.output_dir) / "merged"),
             safe_serialization=safe_serialization,
+            progressbar=True,
         )
         tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
 
@@ -325,6 +328,94 @@ def load_datasets(
     )
 
 
+def load_rl_datasets(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,  # pylint: disable=unused-argument
+) -> TrainDatasetMeta:
+    train_datasets: List[Any] = []
+    for i, ds_cfg in enumerate(cfg.datasets):
+        train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
+    # eval_dataset = load_dataset(
+    #     cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
+    # )
+    eval_dataset = None
+
+    def argilla_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
+        return sample
+
+    def intel_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
+        return sample
+
+    def apply_chatml(sample):  # pylint: disable=possibly-unused-variable
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
+        return sample
+
+    def ultra_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
+        return sample
+
+    for i, data_set in enumerate(train_datasets):
+        _type = cfg.datasets[i]["type"]
+        ds_type_fn = locals()[_type]
+        train_datasets[i] = data_set.map(ds_type_fn)
+    train_dataset = concatenate_datasets(train_datasets)
+
+    # eval_dataset = eval_dataset.map(intel_apply_chatml)
+
+    total_num_steps = int(
+        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+    )
+
+    return TrainDatasetMeta(
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        total_num_steps=total_num_steps,
+    )
+
+
 def check_accelerate_default_config():
     if Path(config_args.default_yaml_config_file).exists():
         LOG.warning(
```
src/axolotl/cli/merge_lora.py

```diff
@@ -25,9 +25,16 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         load_in_8bit=False,
         load_in_4bit=False,
         flash_attention=False,
-        **kwargs
+        **kwargs,
     )
 
+    if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
+        parsed_cfg.lora_model_dir = parsed_cfg.output_dir
+    if not Path(parsed_cfg.lora_model_dir).exists():
+        raise ValueError(
+            f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
+        )
+
     do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
 
 
```
src/axolotl/cli/train.py

```diff
@@ -12,6 +12,7 @@ from axolotl.cli import (
     check_user_token,
     load_cfg,
     load_datasets,
+    load_rl_datasets,
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
@@ -30,7 +31,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
    )
-    dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    if parsed_cfg.rl:
+        dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    else:
+        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
     train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
 
 
```
src/axolotl/core/trainer_builder.py

```diff
@@ -20,6 +20,7 @@ from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_utils import seed_worker
+from trl import DPOTrainer
 
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
@@ -59,6 +60,12 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "Use quadratic warmup for cosine scheduling."},
     )
+    pretraining: bool = field(
+        default=False,
+        metadata={
+            "help": "Indicates to trainer whether we are doing continued pretraining."
+        },
+    )
     sample_packing: bool = field(
         default=False,
         metadata={"help": "Use sample packing for efficient training."},
@@ -156,7 +163,7 @@ class AxolotlTrainer(Trainer):
         return self.lr_scheduler
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.args.sample_packing:
+        if self.args.sample_packing and not self.args.pretraining:
             return MultipackBatchSampler(
                 RandomSampler(self.train_dataset),
                 self.args.train_batch_size,
@@ -192,7 +199,7 @@ class AxolotlTrainer(Trainer):
         return super()._get_eval_sampler(eval_dataset)
 
     def get_train_dataloader(self) -> DataLoader:
-        if self.args.sample_packing:
+        if self.args.sample_packing and not self.args.pretraining:
             train_dataset = self.train_dataset
             train_dataset = train_dataset.remove_columns(["length"])
             data_collator = self.data_collator
@@ -420,12 +427,21 @@ class TrainerBuilderBase(abc.ABC):
 
     _train_dataset = None
     _eval_dataset = None
+    _model_ref = None
 
     def __init__(self, cfg, model, tokenizer):
         self.cfg = cfg
         self.model = model
         self.tokenizer = tokenizer
 
+    @property
+    def model_ref(self):
+        return self._model_ref
+
+    @model_ref.setter
+    def model_ref(self, model):
+        self._model_ref = model
+
     @property
     def train_dataset(self):
         return self._train_dataset
@@ -566,6 +582,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "gradient_checkpointing"
             ] = self.cfg.gradient_checkpointing
+            if self.cfg.gradient_checkpointing_kwargs:
+                training_arguments_kwargs[
+                    "gradient_checkpointing_kwargs"
+                ] = self.cfg.gradient_checkpointing_kwargs
+            else:
+                training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
+                    "use_reentrant": False
+                }
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
             if self.cfg.fsdp_config:
@@ -593,6 +617,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
             training_arguments_kwargs["push_to_hub"] = True
             training_arguments_kwargs["hub_private_repo"] = True
+            training_arguments_kwargs["hub_always_push"] = True
 
             if self.cfg.hub_strategy:
                 training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
@@ -749,6 +774,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs
         )
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
+        training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
 
         if self.cfg.neftune_noise_alpha is not None:
             training_arguments_kwargs[
@@ -789,7 +815,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=self.build_collator(**data_collator_kwargs),
+            data_collator=self.build_collator(training_args, **data_collator_kwargs),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -810,7 +836,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         return trainer
 
-    def build_collator(self, **kwargs):
+    def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
+        if training_args.pretraining:
+            return None
+
         if self.cfg.model_config_type == "mamba":
             return MambaDataCollator(tokenizer=self.tokenizer)
 
@@ -819,3 +848,96 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             return_tensors="pt",
             **kwargs,
         )
+
+
+class HFDPOTrainerBuilder(TrainerBuilderBase):
+    """
+    Trainer factory class for DPO Trainer
+    """
+
+    def get_callbacks(self):
+        callbacks = []
+        return callbacks
+
+    def get_post_trainer_create_callbacks(self, trainer):
+        callbacks = []
+        return callbacks
+
+    def build_training_arguments(self, total_num_steps):
+        training_args_kwargs = {}
+        for arg in [
+            "adam_beta1",
+            "adam_beta2",
+            "adam_epsilon",
+            "dataloader_num_workers",
+            "dataloader_pin_memory",
+        ]:
+            if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
+                training_args_kwargs[arg] = getattr(self.cfg, arg)
+        training_args = TrainingArguments(
+            per_device_train_batch_size=self.cfg.micro_batch_size,
+            max_steps=total_num_steps,
+            remove_unused_columns=False,
+            gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
+            learning_rate=self.cfg.learning_rate,
+            evaluation_strategy="no",
+            # eval_steps=self.cfg.eval_steps,
+            save_strategy="steps",
+            save_steps=self.cfg.save_steps,
+            output_dir=self.cfg.output_dir,
+            warmup_steps=self.cfg.warmup_steps,
+            bf16=True,
+            gradient_checkpointing=self.cfg.gradient_checkpointing,
+            gradient_checkpointing_kwargs={"use_reentrant": False},
+            logging_first_step=True,
+            logging_steps=1,
+            optim=self.cfg.optimizer,
+            save_total_limit=self.cfg.save_total_limit or 5,
+            **training_args_kwargs,
+        )
+
+        return training_args
+
+    def build(self, total_num_steps):
+        training_args = self.build_training_arguments(total_num_steps)
+        dpo_trainer_kwargs = {}
+        if self.cfg.rl == "ipo":
+            dpo_trainer_kwargs["loss_type"] = "ipo"
+            if self.cfg.dpo_label_smoothing:
+                dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
+
+        dpo_trainer = DPOTrainer(
+            self.model,
+            self.model_ref,
+            args=training_args,
+            beta=self.cfg.dpo_beta or 0.1,
+            train_dataset=self.train_dataset,
+            # eval_dataset=self.eval_dataset,
+            eval_dataset=None,
+            tokenizer=self.tokenizer,
+            max_length=self.cfg.sequence_len,
+            max_target_length=None,
+            max_prompt_length=self.cfg.sequence_len,
+            generate_during_eval=True,
+            **dpo_trainer_kwargs,
+        )
+
+        return dpo_trainer
+
+
+class HFPPOTrainerBuilder(TrainerBuilderBase):
+    """
+    HF Factory class for PPO Trainer
+    """
+
+    def get_callbacks(self):
+        callbacks = []
+        return callbacks
+
+    def get_post_trainer_create_callbacks(self, trainer):
+        callbacks = []
+        return callbacks
+
+    def build(self, total_num_steps):
+        # build PPOConfig
+        pass
```
src/axolotl/core/trainers/__init__.py (new empty file)

src/axolotl/core/trainers/trl.py (new file, +66)

```diff
@@ -0,0 +1,66 @@
+"""
+module for TRL PPO training
+"""
+import torch
+from tqdm import tqdm
+from trl import PPOTrainer
+
+
+class TRLPPOTrainer(PPOTrainer):
+    """
+    wrapper for ppo trainer to handle customizations
+    """
+
+    def train(
+        self,
+        reward_pipe,
+        resume_from_checkpoint=None,  # pylint: disable=unused-argument
+    ):
+        generation_kwargs = {
+            "min_length": -1,
+            "top_k": 0.0,
+            "top_p": 1.0,
+            "do_sample": True,
+            "pad_token_id": self.tokenizer.eos_token_id,
+            "max_new_tokens": 32,
+        }
+        sent_kwargs = {
+            "return_all_scores": True,
+            "function_to_apply": "none",
+            "batch_size": 16,
+        }
+
+        for epoch, batch in tqdm(  # pylint: disable=unused-variable
+            enumerate(self.dataloader)
+        ):
+            query_tensors = batch["input_ids"]
+
+            # generate model response
+            response_tensors, ref_response_tensors = self.generate(
+                query_tensors,
+                return_prompt=False,
+                generate_ref_response=True,
+                **generation_kwargs
+            )
+            batch["response"] = self.tokenizer.batch_decode(response_tensors)
+            batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
+
+            # Compute sentiment score
+            texts = [q + r for q, r in zip(batch["query"], batch["response"])]
+            pipe_outputs = reward_pipe(texts, **sent_kwargs)
+            rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
+            ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
+            ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
+            ref_rewards = [
+                torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
+            ]
+            batch["ref_rewards"] = ref_rewards
+
+            # Run PPO step
+            stats = self.step(query_tensors, response_tensors, rewards)
+            self.log_stats(
+                stats,
+                batch,
+                rewards,
+                columns_to_log=["query", "response", "ref_response", "ref_rewards"],
+            )
```
```diff
@@ -147,6 +147,15 @@ def get_turns(  # pylint: disable=too-many-return-statements
                 else:
                     yield role + "\n", ""
             return
+        if self.sep_style == SeparatorStyle.CHATGLM3:
+            if self.system_message:
+                yield "", system_prompt
+            for role, message in self.messages:
+                if message:
+                    yield role + "\n", " " + message
+                else:
+                    yield role
+            return
         if self.sep_style == SeparatorStyle.CHATINTERN:
             # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
             seps = [self.sep, self.sep2]
```
```diff
@@ -17,6 +17,6 @@ def replace_mixtral_attn_with_multipack_flash_attn():
     transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
         mixtral_model_forward
     )
-    transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
+    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
         "flash_attention_2"
     ] = MixtralMultipackFlashAttention2
```
```diff
@@ -261,7 +261,11 @@ def mixtral_model_forward(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
 
-    if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+    if (
+        attention_mask is not None
+        and self._attn_implementation == "flash_attention_2"
+        and use_cache
+    ):
         is_padding_right = attention_mask[:, -1].sum().item() != batch_size
         if is_padding_right:
             raise ValueError(
@@ -270,7 +274,7 @@ def mixtral_model_forward(
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
             )
 
-    if self._use_flash_attention_2:
+    if self._attn_implementation == "flash_attention_2":
         # 2d mask is passed through the layers
         attention_mask = (
             attention_mask
```
```diff
@@ -61,6 +61,12 @@ def train(
         msg += " and peft_config..."
     LOG.debug(msg)
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
+    model_ref = None
+    if cfg.rl:
+        # load the model again for model_ref/baseline
+        model_ref, _ = load_model(
+            cfg, tokenizer, inference=cli_args.inference, reference_model=True
+        )
 
     safe_serialization = cfg.save_safetensors is True
 
@@ -83,7 +89,7 @@ def train(
         freeze_parameters_except(model, cfg.unfrozen_parameters)
 
     trainer = setup_trainer(
-        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
+        cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps
     )
 
     if hasattr(model, "config"):
@@ -182,6 +188,9 @@ def train(
 
     if not cfg.hub_model_id:
         trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
+    elif cfg.hub_model_id:
+        # defensively push to the hub to ensure the model card is updated
+        trainer.push_to_hub()
 
     return model, tokenizer
 
```
src/axolotl/utils/collators.py

```diff
@@ -178,3 +178,24 @@ class MambaDataCollator:
             "input_ids": input_ids,
             "labels": labels,
         }
+
+
+@dataclass
+class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    """
+    Collator for multipack specific to the using the BatchSampler
+    """
+
+    def __call__(self, features, return_tensors=None):
+        chunked_data = {}
+        for feature in features.keys():
+            if feature == "length":
+                continue
+            if feature == "attention_mask":
+                arrays = [(1) * np.array(item) for item in features[feature]]
+                chunked_data[feature] = np.concatenate(arrays)
+            else:
+                arrays = [np.array(item) for item in features[feature]]
+                chunked_data[feature] = np.concatenate(arrays)
+        features = [chunked_data]
+        return super().__call__(features, return_tensors=return_tensors)
```
```diff
@@ -422,11 +422,6 @@ def validate_config(cfg):
     if cfg.warmup_steps and cfg.warmup_ratio:
         raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
 
-    if cfg.is_qwen_derived_model and cfg.gradient_checkpointing:
-        LOG.warning(
-            "Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
-        )
-
     if cfg.wandb_run_id and not cfg.wandb_name:
         cfg.wandb_name = cfg.wandb_run_id
 
@@ -462,6 +457,11 @@ def validate_config(cfg):
             "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
         )
 
+    if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
+        raise ValueError(
+            "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
```
```diff
@@ -2,6 +2,7 @@
 import functools
 import hashlib
 import logging
+from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple, Union
 
@@ -14,6 +15,7 @@ from datasets import (
     load_from_disk,
 )
 from huggingface_hub import hf_hub_download
+from torch.utils.data import RandomSampler
 from transformers import PreTrainedTokenizerBase
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -39,11 +41,14 @@ from axolotl.prompters import (
     SummarizeTLDRPrompter,
     UnsupportedPrompter,
 )
+from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.samplers.multipack import MultipackBatchSampler
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
     process_datasets_for_packing,
+    process_pretraining_datasets_for_packing,
 )
 
 LOG = logging.getLogger("axolotl")
@@ -64,9 +69,17 @@ def prepare_dataset(cfg, tokenizer):
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
     else:
+        path = cfg.pretraining_dataset
+        name = None
+        if isinstance(cfg.pretraining_dataset, dict):
+            path = cfg.pretraining_dataset["path"]
+            name = cfg.pretraining_dataset["name"]
+
         train_dataset = load_pretraining_dataset(
-            cfg.pretraining_dataset,
+            path,
             tokenizer,
+            cfg,
+            name=name,
             max_tokens=cfg.sequence_len,
             seed=cfg.seed or 42,
         )
@@ -806,9 +819,27 @@ def encode_pretraining(
     return ret
 
 
-def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
-    encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
-    dataset = load_dataset(path, streaming=True, split="train")
+def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
+    if cfg.sample_packing:
+        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
+            tokenizer,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
+        )
+        encode = functools.partial(
+            encode_packed_pretraining,
+            tokenizer,
+            collate_fn,
+            max_seq_length=max_tokens,
+            batch_size=cfg.micro_batch_size,
+        )
+        # set this to 1 so downstream data_loader doesn't try to increase the batch again
+        cfg.micro_batch_size = 1
+    else:
+        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+
+    dataset = load_dataset(path, streaming=True, split="train", name=name)
     dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
     dataset = dataset.map(
         encode,
@@ -819,3 +850,63 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
         remove_columns=dataset.features.keys(),
     )
     return dataset
+
+
+def encode_packed_pretraining(
+    tokenizer: PreTrainedTokenizerBase,
+    collate_fn,
+    examples: List[str],
+    max_seq_length: int = 2048,
+    batch_size: int = 4,
+) -> Dict[str, List]:
+    # pylint: disable=duplicate-code
+    # tokenize all the examples
+    # rows get split with stride (overlap)
+    res = tokenizer(
+        examples,
+        truncation=True,
+        max_length=max_seq_length - 1,
+        add_special_tokens=True,
+        return_overflowing_tokens=True,
+        stride=256,
+    )
+
+    input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
+    attention_mask = [seq + [1] for seq in res["attention_mask"]]
+
+    tokenized_examples = {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+    }
+
+    train_dataset = Dataset.from_dict(tokenized_examples)
+    train_dataset = process_pretraining_datasets_for_packing(
+        train_dataset, max_seq_length
+    )
+
+    sampler = MultipackBatchSampler(
+        RandomSampler(train_dataset),
+        batch_size=batch_size,
+        drop_last=True,
+        batch_max_len=batch_size * max_seq_length,
+        lengths=(
+            train_dataset.data.column("position_ids")
+            .to_pandas()
+            .apply(lambda x: x[-1] + 1)
+            .values
+        ),
+    )
+
+    chunked_data = defaultdict(list)
+
+    for data in sampler:
+        features = train_dataset[data]
+        features["labels"] = features["input_ids"].copy()
+        collated_features = collate_fn(features)
+
+        for feature in features.keys():
+            if feature == "length":
+                continue
+            chunked_data[feature].append(collated_features[feature].squeeze(0))
+
+    return chunked_data
```
src/axolotl/utils/models.py

```diff
@@ -2,7 +2,7 @@
 import logging
 import math
 import os
-from typing import Optional, Tuple  # noqa: F401
+from typing import Any, Optional, Tuple  # noqa: F401
 
 import addict
 import bitsandbytes as bnb
@@ -200,6 +200,7 @@ def load_model(
     cfg: DictDefault,
     tokenizer: PreTrainedTokenizerBase,
     inference: bool = False,
+    reference_model: bool = False,
 ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
     """
     Load a model for a given configuration and tokenizer.
@@ -287,9 +288,47 @@ def load_model(
 
     model_kwargs = {}
 
-    model_kwargs["device_map"] = cfg.device_map
-    model_kwargs["max_memory"] = cfg.max_memory
+    max_memory = cfg.max_memory
+    device_map = cfg.device_map
+
+    if cfg.gpu_memory_limit:
+        gpu_memory_limit = (
+            str(cfg.gpu_memory_limit) + "GiB"
+            if isinstance(cfg.gpu_memory_limit, int)
+            else cfg.gpu_memory_limit
+        )
+
+        max_memory = {}
+        for i in range(torch.cuda.device_count()):
+            max_memory[i] = gpu_memory_limit
+        max_memory["cpu"] = "256GiB"  # something sufficiently large to fit anything
+
+    if max_memory is not None:
+        # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
+        from accelerate import infer_auto_device_map, init_empty_weights
+
+        with init_empty_weights():
+            model_canvas = AutoModelForCausalLM.from_config(model_config)
+        model_canvas.tie_weights()
+        device_map = infer_auto_device_map(
+            model_canvas,
+            max_memory=max_memory,
+            dtype=cfg.torch_dtype,
+        )
+        # We can discard max_memory now as we have a device map set up for us
+        max_memory = None
+
+    model_kwargs["device_map"] = device_map
     model_kwargs["torch_dtype"] = cfg.torch_dtype
+    # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
+    # if cfg.rl:
+    #     if torch.cuda.device_count() > 1:
+    #         if reference_model:
+    #             model_kwargs["device_map"] = "cuda:" + str(
+    #                 torch.cuda.current_device() + 1
+    #             )
+    #         else:
+    #             model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())
 
     if is_deepspeed_zero3_enabled():
         del model_kwargs["device_map"]
@@ -332,15 +371,18 @@ def load_model(
         or cfg.is_mistral_derived_model
         or model_config.model_type == "mixtral"
     ):
+        model_kwargs["attn_implementation"] = "flash_attention_2"
         model_config._attn_implementation = (  # pylint: disable=protected-access
             "flash_attention_2"
         )
     else:
         if model_config.model_type == "mixtral":
+            model_kwargs["attn_implementation"] = "flash_attention_2"
             model_config._attn_implementation = (  # pylint: disable=protected-access
                 "flash_attention_2"
             )
         else:
+            model_kwargs["attn_implementation"] = "eager"
             model_config._attn_implementation = (  # pylint: disable=protected-access
                 "eager"
             )
@@ -413,7 +455,6 @@ def load_model(
         model_kwargs["device"] = torch.cuda.current_device()
         del model_kwargs["torch_dtype"]
         del model_kwargs["device_map"]
-        del model_kwargs["max_memory"]
 
         model = MambaLMHeadModel.from_pretrained(
             base_model,
@@ -557,9 +598,11 @@ def load_model(
         if hasattr(module, "weight"):
             module.to(cfg.torch_dtype)
 
-    model, lora_config = load_adapter(model, cfg, cfg.adapter)
+    lora_config = None
+    if not reference_model or cfg.lora_model_dir:
+        model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
-    if cfg.ddp and not load_in_8bit:
+    if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
         model.to(f"cuda:{cfg.local_rank}")
 
     if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
@@ -668,10 +711,15 @@ def load_lora(model, cfg, inference=False):
 
     if cfg.lora_model_dir:
         LOG.debug("Loading pretained PEFT - LoRA")
+        model_kwargs: Any = {}
+        if cfg.lora_on_cpu:
+            model_kwargs["max_memory"] = {"cpu": "256GiB"}
+            model_kwargs["device_map"] = {"": "cpu"}
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
             is_trainable=(not inference),
+            **model_kwargs,
         )
     else:
         model = get_peft_model(model, lora_config)
```
```diff
@@ -31,8 +31,8 @@ def check_example_labels(example, tokenizer, text_only=False):
         )
         colored_tokens.append(colored_token)
 
-    delimiter = "" if text_only else " "
-    LOG.info(delimiter.join(colored_tokens))
+    output = " ".join(colored_tokens)
+    LOG.info(output)
     LOG.info("\n\n\n")
 
-    return " ".join(colored_tokens)
+    return output
```
@@ -12,7 +12,7 @@ from accelerate.logging import get_logger
from datasets import set_caching_enabled
from torch.utils.data import DataLoader, RandomSampler

-from axolotl.core.trainer_builder import HFCausalTrainerBuilder
+from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.samplers import MultipackBatchSampler

@@ -143,6 +143,16 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
    return train_dataset, eval_dataset


+def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
+    drop_long = partial(drop_long_seq, sequence_len=sequence_len)
+
+    train_dataset = train_dataset.filter(drop_long)
+    train_dataset = train_dataset.map(
+        add_position_ids,
+    )
+    return train_dataset
+
+
def calculate_total_num_steps(cfg, train_dataset, update=True):
    if not cfg.total_num_tokens:
        total_num_tokens = np.sum(
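The drop_long_seq and add_position_ids helpers used above are defined elsewhere in this module; roughly, they behave like the following sketches (the real implementations may differ in detail):

def drop_long_seq(sample, sequence_len=2048):
    # Keep only samples that still fit in the packing window.
    return len(sample["input_ids"]) <= sequence_len


def add_position_ids(sample):
    # Packed batches need explicit position ids so each original
    # sequence restarts its positions after concatenation.
    sample["position_ids"] = list(range(len(sample["input_ids"])))
    return sample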
@@ -280,7 +290,12 @@ def prepare_optim_env(cfg):


def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-   trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
+   if cfg.rl:
+       trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
+       trainer_builder.model_ref = model[1]
+   else:
+       trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)

    trainer_builder.train_dataset = train_dataset
    trainer_builder.eval_dataset = eval_dataset

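Since setup_trainer now indexes into model, callers are expected to pass a pair rather than a bare model. A sketch of the updated call site (variable names are illustrative):

trainer = setup_trainer(
    cfg,
    train_dataset,
    eval_dataset,
    (model, model_ref),  # model_ref is only consumed when cfg.rl is set
    tokenizer,
    total_num_steps,
)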
59
tests/core/test_trainer_builder.py
Normal file
@@ -0,0 +1,59 @@
"""
unit tests for axolotl.core.trainer_builder
"""
import pytest

from axolotl.core.trainer_builder import HFDPOTrainerBuilder
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer


@pytest.fixture(name="cfg")
def fixture_cfg():
    return DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
            "model_type": "AutoModelForCausalLM",
            "tokenizer_type": "LlamaTokenizer",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "learning_rate": 0.00005,
            "save_steps": 100,
            "output_dir": "./model-out",
            "warmup_steps": 10,
            "gradient_checkpointing": False,
            "optimizer": "adamw_torch",
            "sequence_len": 2048,
            "rl": True,
            "adam_beta1": 0.998,
            "adam_beta2": 0.9,
            "adam_epsilon": 0.00001,
            "dataloader_num_workers": 1,
            "dataloader_pin_memory": True,
        }
    )


@pytest.fixture(name="tokenizer")
def fixture_tokenizer(cfg):
    return load_tokenizer(cfg)


@pytest.fixture(name="model")
def fixture_model(cfg, tokenizer):
    return load_model(cfg, tokenizer)


class TestHFDPOTrainerBuilder:
    """
    TestCase class for DPO trainer builder
    """

    def test_build_training_arguments(self, cfg, model, tokenizer):
        builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
        training_arguments = builder.build_training_arguments(100)
        assert training_arguments.adam_beta1 == 0.998
        assert training_arguments.adam_beta2 == 0.9
        assert training_arguments.adam_epsilon == 0.00001
        assert training_arguments.dataloader_num_workers == 1
        assert training_arguments.dataloader_pin_memory is True
109
tests/e2e/test_mixtral.py
Normal file
@@ -0,0 +1,109 @@
"""
E2E tests for mixtral
"""

import logging
import os
import unittest
from pathlib import Path

from transformers.utils import is_torch_bf16_gpu_available

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestMixtral(unittest.TestCase):
    """
    Test case for Mixtral models
    """

    @with_temp_dir
    def test_qlora(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                "flash_attention": True,
                "sequence_len": 1024,
                "load_in_4bit": True,
                "adapter": "qlora",
                "lora_r": 16,
                "lora_alpha": 32,
                "lora_dropout": 0.1,
                "lora_target_linear": True,
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.bin").exists()

    @with_temp_dir
    def test_ft(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                "flash_attention": True,
                "sequence_len": 1024,
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
            }
        )
        if is_torch_bf16_gpu_available():
            cfg.bf16 = True
        else:
            cfg.fp16 = True
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "pytorch_model.bin").exists()
123
tests/e2e/test_mixtral_samplepack.py
Normal file
@@ -0,0 +1,123 @@
"""
E2E tests for mixtral
"""

import logging
import os
import unittest
from pathlib import Path

from transformers.utils import is_torch_bf16_gpu_available

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestMixtral(unittest.TestCase):
    """
    Test case for Mixtral models using sample packing
    """

    @with_temp_dir
    def test_qlora(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                "flash_attention": True,
                "sequence_len": 2048,
                "load_in_4bit": True,
                "adapter": "qlora",
                "lora_r": 16,
                "lora_alpha": 32,
                "lora_dropout": 0.1,
                "lora_target_linear": True,
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "sample_packing": True,
            }
        )
        if is_torch_bf16_gpu_available():
            cfg.bf16 = True
        else:
            cfg.fp16 = True
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.bin").exists()

    @with_temp_dir
    def test_ft(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                "flash_attention": True,
                "sequence_len": 2048,
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
                "sample_packing": True,
            }
        )
        if is_torch_bf16_gpu_available():
            cfg.bf16 = True
        else:
            cfg.fp16 = True
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (
            "axolotl.monkeypatch.mixtral.modeling_mixtral"
            in model.model.layers[0].self_attn.__class__.__module__
        )
        assert (
            "MixtralMultipackFlashAttention2"
            in model.model.layers[0].self_attn.__class__.__name__
        )
        assert (Path(temp_dir) / "pytorch_model.bin").exists()
99
tests/e2e/test_model_patches.py
Normal file
@@ -0,0 +1,99 @@
"""
E2E smoke tests to check that the monkeypatches are in place for certain configurations
"""

import unittest

from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer

from .utils import with_temp_dir


class TestModelPatches(unittest.TestCase):
    """
    TestCases for the multipack monkey patches
    """

    @with_temp_dir
    def test_mixtral_multipack(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                "flash_attention": True,
                "sample_packing": True,
                "sequence_len": 2048,
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        tokenizer = load_tokenizer(cfg)
        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)

        assert (
            "axolotl.monkeypatch.mixtral.modeling_mixtral"
            in model.model.layers[0].self_attn.__class__.__module__
        )
        assert (
            "MixtralMultipackFlashAttention2"
            in model.model.layers[0].self_attn.__class__.__name__
        )

    @with_temp_dir
    def test_mistral_multipack(self, temp_dir):
        cfg = DictDefault(
            {
                "base_model": "openaccess-ai-collective/tiny-mistral",
                "flash_attention": True,
                "sample_packing": True,
                "sequence_len": 2048,
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        tokenizer = load_tokenizer(cfg)
        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)

        assert (
            "axolotl.monkeypatch.mistral_attn_hijack_flash"
            in model.model.layers[0].self_attn.forward.__module__
        )
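These assertions rely on a simple property: after monkeypatching, the attention module's class (or its bound forward method) reports the patch's module path instead of transformers' own. A generic way to inspect this on any loaded model, not specific to axolotl:

attn = model.model.layers[0].self_attn
# Class-replacing patches (the mixtral case): check the class identity.
print(attn.__class__.__module__, attn.__class__.__name__)
# Method-hijacking patches (the mistral case): check the bound forward.
print(attn.forward.__module__)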
82
tests/test_packed_pretraining.py
Normal file
@@ -0,0 +1,82 @@
"""Module for testing streaming dataset sequence packing"""
import unittest
from functools import partial

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data import encode_packed_pretraining


class TestPacking(unittest.TestCase):
    """
    Test class for packing streaming dataset sequences
    """

    def setUp(self) -> None:
        # pylint: disable=duplicate-code
        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        self.tokenizer.pad_token = "</s>"
        self.max_seq_length = 2048
        self.batch_size = 2

    def test_packing_stream_dataset(self):
        # pylint: disable=duplicate-code
        dataset = load_dataset(
            "c4",
            "en",
            streaming=True,
        )["train"]

        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
            self.tokenizer,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=self.max_seq_length,
        )

        encode = partial(
            encode_packed_pretraining,
            self.tokenizer,
            collate_fn,
            max_seq_length=self.max_seq_length,
            batch_size=self.batch_size,
        )

        dataset = dataset.map(
            encode,
            batched=True,
            input_columns="text",
            remove_columns=dataset.features.keys(),
        )

        trainer_loader = DataLoader(
            dataset,
            batch_size=1,
            collate_fn=None,
            drop_last=True,
        )
        idx = 0
        for data in trainer_loader:
            if idx > 10:
                break
            assert data["input_ids"].shape == torch.Size(
                [1, self.batch_size * self.max_seq_length]
            )
            assert data["position_ids"].shape == torch.Size(
                [1, self.batch_size * self.max_seq_length]
            )
            assert data["labels"].shape == torch.Size(
                [1, self.batch_size * self.max_seq_length]
            )
            assert data["attention_mask"].shape == torch.Size(
                [1, self.batch_size * self.max_seq_length]
            )
            idx += 1


if __name__ == "__main__":
    unittest.main()
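As a quick sanity check on the shape assertions in this test: with batch_size = 2 and max_seq_length = 2048, packing collapses each batch into a single row of 4096 tokens, so every tensor comes back as [1, 4096]:

batch_size, max_seq_length = 2, 2048
packed_shape = (1, batch_size * max_seq_length)
assert packed_shape == (1, 4096)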