Compare commits

save_only_...sppo (36 commits)
| SHA1 |
|---|
| 6a9ac4ad27 |
| 027f7d54f0 |
| 0554105baa |
| f58fcd09ec |
| 60fecac367 |
| b301068098 |
| df645906eb |
| 7fea5822f0 |
| 3367fca732 |
| 1ac899800b |
| 70185763f6 |
| 120b809465 |
| 29cf15a28c |
| dde02fcb94 |
| b9bb169602 |
| 601c08b4c2 |
| cc5d31e0d9 |
| 1aeece6e24 |
| 5294653a2d |
| 98c25e15cb |
| 68601ec6ad |
| 60f5ce0569 |
| 7477a53287 |
| 7d1d22f72f |
| 0e8f340945 |
| 59ef25470c |
| c10563c444 |
| 37c037c69d |
| 15f7910d33 |
| d28ba2e405 |
| 0eadfc8c86 |
| bcaa92325d |
| 7d9bafcb88 |
| e07dcb288c |
| 6319da1f9b |
| 132eb740f0 |
.github/workflows/base.yml (vendored, 5 additions)

```diff
@@ -32,6 +32,11 @@ jobs:
           python_version: "3.11"
           pytorch: 2.2.1
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+        - cuda: "121"
+          cuda_version: 12.1.0
+          python_version: "3.11"
+          pytorch: 2.3.0
+          torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
     steps:
       - name: Checkout
         uses: actions/checkout@v3
```
.github/workflows/lint.yml (vendored, 1 addition)

```diff
@@ -7,6 +7,7 @@ on:
       - 'requirements.txt'
       - '.github/workflows/*.yml'
       - "*.md"
+      - "examples/**/*.y[a]?ml"
   workflow_dispatch:

 jobs:
```
.github/workflows/main.yml (vendored, 10 additions)

```diff
@@ -30,6 +30,11 @@ jobs:
           python_version: "3.11"
           pytorch: 2.2.1
           axolotl_extras:
+        - cuda: 121
+          cuda_version: 12.1.0
+          python_version: "3.11"
+          pytorch: 2.3.0
+          axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -86,6 +91,11 @@ jobs:
           python_version: "3.11"
           pytorch: 2.2.1
           axolotl_extras:
+        - cuda: 121
+          cuda_version: 12.1.0
+          python_version: "3.11"
+          pytorch: 2.3.0
+          axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
```
.github/workflows/nightlies.yml (vendored, 10 additions)

```diff
@@ -29,6 +29,11 @@ jobs:
           python_version: "3.11"
           pytorch: 2.2.1
           axolotl_extras:
+        - cuda: 121
+          cuda_version: 12.1.0
+          python_version: "3.11"
+          pytorch: 2.3.0
+          axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -86,6 +91,11 @@ jobs:
           python_version: "3.11"
           pytorch: 2.2.1
           axolotl_extras:
+        - cuda: 121
+          cuda_version: 12.1.0
+          python_version: "3.11"
+          pytorch: 2.3.0
+          axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
```
.gitignore (vendored, 1 addition)

```diff
@@ -133,6 +133,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+venv3.10/

 # Spyder project settings
 .spyderproject
```
README.md

````diff
@@ -44,6 +44,7 @@ Features:
 - Advanced Topics
   - [Multipack](./docs/multipack.qmd)
   - [RLHF & DPO](./docs/rlhf.qmd)
+  - [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)
 - [Common Errors](#common-errors-)
   - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
 - [Debugging Axolotl](#debugging-axolotl)
@@ -81,6 +82,7 @@ Features:
 | llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
+| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
 | Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
@@ -425,7 +427,7 @@ deepspeed: deepspeed_configs/zero1.json
 ```

 ```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
+accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
 ```

 ##### FSDP
````
```diff
@@ -1,4 +1,6 @@
 {
+    "zero_force_ds_cpu_optimizer": false,
+    "zero_allow_untested_optimizer": true,
     "zero_optimization": {
         "stage": 3,
         "offload_optimizer": {
```

```diff
@@ -1,4 +1,6 @@
 {
+    "zero_force_ds_cpu_optimizer": false,
+    "zero_allow_untested_optimizer": true,
     "zero_optimization": {
         "stage": 3,
         "offload_param": {
```
docs/config.qmd

```diff
@@ -138,7 +138,7 @@ test_datasets:
     data_files:
       - /workspace/data/eval.jsonl

-# use RL training: 'dpo', 'ipo', 'kto_pair'
+# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard'
 rl:

 # Saves the desired chat template to the tokenizer_config.json for easier inferencing
```
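To make the expanded `rl:` options above concrete, a minimal DPO run might look like the sketch below. The dataset path and `type:` mapping are illustrative assumptions, not taken from this diff; point them at your own preference data.

```yaml
# Minimal sketch (assumed dataset and prompt-strategy type)
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs   # hypothetical preference dataset
    split: train
    type: chatml.intel           # assumed type; must match your data's schema
```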
```diff
@@ -227,6 +227,12 @@ lora_modules_to_save:

 lora_fan_in_fan_out: false

+# LoRA+ hyperparameters
+# For more details about the following options, see:
+# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
+loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
+loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
+
 peft:
   # Configuration options for loftq initialization for LoRA
   # https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
```
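A hedged sketch of how the new LoRA+ keys combine with a standard LoRA adapter section. The numeric values simply follow the recommendations quoted in the comments above (ratio 2^4 = 16, embedding LR 1e-6); treat them as starting points, not tuned settings.

```yaml
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

# LoRA+ settings per the recommendations documented above
loraplus_lr_ratio: 16          # lr_B / lr_A = 2^4
loraplus_lr_embedding: 1.0e-6  # learning rate for LoRA embedding layers
```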
```diff
@@ -268,6 +274,7 @@ torch_compile_backend: # Optional[str]
 # If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
 gradient_accumulation_steps: 1
 # The number of samples to include in each batch. This is the number of samples sent to each GPU.
+# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
 micro_batch_size: 2
 eval_batch_size:
 num_epochs: 4
```
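A quick worked example of the batch-size relationship documented in the comment above; the 4-GPU count is an assumption for illustration:

```yaml
micro_batch_size: 2
gradient_accumulation_steps: 8
# batch size per GPU  = micro_batch_size * gradient_accumulation_steps = 2 * 8 = 16
# with 4 GPUs, the effective global batch is 16 * 4 = 64 samples per optimizer step
```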
```diff
@@ -412,6 +419,7 @@ special_tokens:
   # bos_token: "<s>"
   # eos_token: "</s>"
   # unk_token: "<unk>"
+  # pad_token: "[PAD]"

 # Add extra tokens.
 tokens:
```
docs/dataset_preprocessing.qmd (new file, 35 lines)

```markdown
---
title: Dataset Preprocessing
description: How datasets are processed
---

Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
the [dataset format](../dataset-formats/) and prompt strategies to:
- parse the dataset based on the *dataset format*
- transform the dataset to how you would interact with the model based on the *prompt strategy*
- tokenize the dataset based on the configured model & tokenizer
- shuffle and merge multiple datasets together if using more than one

The processing of the datasets can happen one of two ways:

1. Before kicking off training by calling `python -m axolotl.cli.preprocess /path/to/your.yaml --debug`
2. When training is started

What are the benefits of pre-processing? When training interactively or for sweeps
(e.g. you are restarting the trainer often), processing the datasets can oftentimes be frustratingly
slow. Pre-processing will cache the tokenized/formatted datasets according to a hash of dependent
training parameters so that it will intelligently pull from its cache when possible.

The path of the cache is controlled by `dataset_prepared_path:` and is often left blank in example
YAMLs as this leads to a more robust solution that prevents unexpectedly reusing cached data.

If `dataset_prepared_path:` is left empty, when training, the processed dataset will be cached in a
default path of `./last_run_prepared/`, but will ignore anything already cached there. By explicitly
setting `dataset_prepared_path: ./last_run_prepared`, the trainer will use whatever pre-processed
data is in the cache.

What are the edge cases? Let's say you are writing a custom prompt strategy or using a user-defined
prompt template. Because the trainer cannot readily detect these changes, we cannot change the
calculated hash value for the pre-processed dataset. If you have `dataset_prepared_path: ...` set
and change your prompt templating logic, it may not pick up the changes you made and you will be
training over the old prompt.
```
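A minimal sketch of the caching workflow the new doc describes, assuming an alpaca-style dataset: set `dataset_prepared_path` explicitly so repeated runs reuse the cache.

```yaml
datasets:
  - path: tatsu-lab/alpaca   # example dataset
    type: alpaca
# Explicit path: the trainer reuses whatever pre-processed data is cached here.
# Left blank, Axolotl processes into ./last_run_prepared/ but ignores the existing cache.
dataset_prepared_path: ./last_run_prepared
```

Running `python -m axolotl.cli.preprocess your.yaml --debug` then populates the cache, and a subsequent training run picks it up; per the edge case above, the hash cannot see custom prompt-template changes, so clear or change the path when templating logic changes.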
````diff
@@ -49,7 +49,7 @@ remove_unused_columns: false
 chat_template: chatml
 datasets:
   - path: argilla/ultrafeedback-binarized-preferences-cleaned
-    type: orpo.chat_template
+    type: chat_template.argilla
 ```

 #### Using local dataset files
````
examples/dbrx/16bit-lora.yaml (new file, 81 lines)

```yaml
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
  - q_proj # attn
  - k_proj # attn
  - v_proj # attn
  - out_proj # attn
  - layer # router
#  - w1
#  - w2
#  - v1

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: DbrxBlock
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_activation_checkpointing: true
```
examples/dbrx/8bit-lora.yaml (new file, 81 lines)

```yaml
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
  - q_proj # attn
  - k_proj # attn
  - v_proj # attn
  - out_proj # attn
  - layer # router
#  - w1
#  - w2
#  - v1

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: DbrxBlock
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_activation_checkpointing: true
```
examples/dbrx/README.md (new file, 26 lines)

```markdown
# DBRX MoE

Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj`, `out_proj` and `layer` Linear layers are trainable.

We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
where the Experts are fused as an `nn.Parameter` rather than a `nn.Linear` layer. However, the implementation
is still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers
results in the trainer hanging.

### FSDP
We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.

The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.

- 16-bit LoRA w/ FSDP
  - ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
  - ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
- ✅ 8-bit LoRA w/ FSDP
- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)

### Deepspeed

WIP
```
examples/dbrx/fft-ds-zero3.yaml (new file, 56 lines)

```yaml
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

unfrozen_parameters:
  - transformer.blocks.[0-7].

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
deepspeed: deepspeed_configs/zero3_bf16.json
```
```diff
@@ -65,12 +65,14 @@ deepspeed:
 weight_decay: 0.0
 fsdp:
   - full_shard
+  - auto_wrap
 fsdp_config:
   fsdp_limit_all_gathers: true
   fsdp_sync_module_states: true
   fsdp_offload_params: true
   fsdp_use_orig_params: false
   fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
-  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_state_dict_type: FULL_STATE_DICT
 special_tokens:
```
examples/llama-3/README.md (new file, 13 lines)

```markdown
# Llama-3

https://llama.meta.com/llama3/

[8B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
- [Full Fine Tune](./fft-8b.yaml)
  - Single GPU @ 48GB VRAM
- [LoRA](./lora-8b.yml)
  - Single GPU @ 11GB VRAM

[70B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
- [QLORA+FSDP](./qlora-fsdp-70b.yaml)
  - Dual GPU @ 21GB VRAM
```
examples/llama-3/fft-8b.yaml (new file, 58 lines)

```yaml
base_model: meta-llama/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out

sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
```
examples/llama-3/lora-8b.yml (new file, 67 lines)

```yaml
base_model: meta-llama/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
```
examples/llama-3/qlora-fsdp-70b.yaml (new file, 80 lines)

```yaml
base_model: casperhansen/llama-3-70b-fp16
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer # PreTrainedTokenizerFast

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out/qlora-llama3-70b

adapter: qlora
lora_model_dir:

sequence_len: 512
sample_packing: false
pad_to_sequence_len: true

lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
special_tokens:
  pad_token: <|end_of_text|>
```
examples/llama-3/qlora.yml (new file, 67 lines)

```yaml
base_model: meta-llama/Meta-Llama-3-8B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: aaditya/alpaca_subset_1
    type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"
```
examples/mistral/bigstral-ds-zero3.yaml (new file, 63 lines)

```yaml
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

unfrozen_parameters:
  - ^lm_head.weight$
  - ^model.embed_tokens.weight$
  - model.layers.4[4-9]+.block_sparse_moe.gate
  - model.layers.4[4-9]+.block_sparse_moe.experts
  - model.layers.5[0-5]+.block_sparse_moe.gate
  - model.layers.5[0-5]+.block_sparse_moe.experts

model_config:
  output_router_logits: true

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  eos_token: "<|im_end|>"
tokens:
  - "<|im_start|>"
```
examples/mistral/mistral-qlora-fsdp.yml (new file, 82 lines)

```yaml
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out

model_config:
  output_router_logits: true

adapter: qlora
lora_model_dir:

sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: false
  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
```
examples/mistral/mistral-qlora-orpo.yml (new file, 82 lines)

```yaml
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./mistral-qlora-orpo-out

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
```
examples/mistral/mixtral-8x22b-qlora-fsdp.yml (new file, 81 lines)

```yaml
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out

model_config:
  output_router_logits: true

adapter: qlora
lora_model_dir:

sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
```
```diff
@@ -39,7 +39,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
 num_epochs: 1
-optimizer: paged_adamw_8bit
+optimizer: adamw_torch
 lr_scheduler: cosine
 learning_rate: 0.0002

@@ -47,7 +47,7 @@ train_on_inputs: false
 group_by_length: false
 bf16: auto
 fp16:
-tf32: false
+tf32: true

 gradient_checkpointing: true
 early_stopping_patience:
@@ -69,6 +69,17 @@ debug:
 weight_decay: 0.0
 fsdp:
   - full_shard
+  - auto_wrap
 fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: true
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
   fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_forward_prefetch: false
+  fsdp_backward_prefetch: BACKWARD_PRE
 special_tokens:
```
examples/mistral/mixtral_22.yml (new file, 61 lines)

```yaml
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

unfrozen_parameters:
  - ^lm_head.weight$
  - ^model.embed_tokens.weight$
  - model.layers.4[4-9]+.block_sparse_moe.gate
  - model.layers.4[4-9]+.block_sparse_moe.experts
  - model.layers.5[0-5]+.block_sparse_moe.gate
  - model.layers.5[0-5]+.block_sparse_moe.experts

model_config:
  output_router_logits: true

datasets:
  - path: yahma/alpaca-cleaned
    type: alpaca
output_dir: ./out

sequence_len: 8000
sample_packing: true
pad_to_sequence_len: true

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  eos_token: "<|im_end|>"
tokens:
  - "<|im_start|>"
```
requirements.txt

```diff
@@ -11,7 +11,7 @@ addict
 fire
 PyYAML>=6.0
 requests
-datasets>=2.15.0
+datasets==2.15.0
 flash-attn==2.5.5
 sentencepiece
 wandb
@@ -28,7 +28,7 @@ scipy
 scikit-learn==1.2.2
 pynvml
 art
-fschat==0.2.36
+fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
 gradio==3.50.2
 tensorboard

@@ -39,5 +39,6 @@ s3fs
 gcsfs
 # adlfs

-trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
+trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
 zstandard==0.22.0
+fastcore
```
```diff
@@ -33,7 +33,7 @@ fi

 if [ "$JUPYTER_DISABLE" != "1" ]; then
     # Run Jupyter Lab in the background
-    jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
+    jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
 fi

 # Execute the passed arguments (CMD)
```
```diff
@@ -264,8 +264,8 @@ def do_inference_gradio(
     with torch.no_grad():
         generation_config = GenerationConfig(
             repetition_penalty=1.1,
-            max_new_tokens=1024,
-            temperature=0.9,
+            max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
+            temperature=cfg.get("gradio_temperature", 0.9),
             top_p=0.95,
             top_k=40,
             bos_token_id=tokenizer.bos_token_id,
@@ -300,7 +300,13 @@ def do_inference_gradio(
         outputs="text",
         title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
     )
-    demo.queue().launch(show_api=False, share=True)
+    demo.queue().launch(
+        show_api=False,
+        share=cfg.get("gradio_share", True),
+        server_name=cfg.get("gradio_server_name", "127.0.0.1"),
+        server_port=cfg.get("gradio_server_port", None),
+    )


 def choose_config(path: Path):
```
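The `cfg.get(...)` lookups above introduce several optional Gradio config keys. A hedged sketch of how they could appear in a config file, with values matching the fallback defaults in the diff:

```yaml
gradio_title: "Axolotl Gradio Interface"
gradio_max_new_tokens: 1024
gradio_temperature: 0.9
gradio_share: true             # set false to keep the UI off a public share link
gradio_server_name: 127.0.0.1  # bind address for the local server
gradio_server_port:            # leave empty to let Gradio choose a port
```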
```diff
@@ -433,6 +439,23 @@ def load_rl_datasets(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )

+    if cli_args.debug or cfg.debug:
+        LOG.info("check_dataset_labels...")
+
+        tokenizer = load_tokenizer(cfg)
+        check_dataset_labels(
+            train_dataset.select(
+                [
+                    random.randrange(0, len(train_dataset) - 1)  # nosec
+                    for _ in range(cli_args.debug_num_examples)
+                ]
+            ),
+            tokenizer,
+            num_examples=cli_args.debug_num_examples,
+            text_only=cli_args.debug_text_only,
+            rl_mode=True,
+        )
+
     return TrainDatasetMeta(
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
```
```diff
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

-    if parsed_cfg.rl and parsed_cfg.rl != "orpo":
+    if parsed_cfg.rl:  # and parsed_cfg.rl != "orpo":
         load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
```
```diff
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()

-    if cfg.rl and cfg.rl != "orpo":
+    if cfg.rl:  # and cfg.rl != "orpo":
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
```
```diff
@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer
+from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length

 from axolotl.loraplus import create_loraplus_optimizer
```
```diff
@@ -43,6 +43,7 @@ from axolotl.utils.callbacks import (
     LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
+    SaveModelOnTrainEndCallback,
     bench_eval_callback_factory,
     causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
```
```diff
@@ -54,6 +55,7 @@ from axolotl.utils.collators import (
     MambaDataCollator,
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
+from axolotl.utils.models import ensure_dtype
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
     get_cosine_schedule_with_min_lr,
```
```diff
@@ -211,6 +213,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "path under the model to access the layers"},
     )
+    curriculum_sampling: Optional[bool] = field(
+        default=None,
+        metadata={"help": "whether to use sequential sampling for curriculum learning"},
+    )


 class AxolotlTrainer(Trainer):
```
```diff
@@ -346,6 +352,8 @@ class AxolotlTrainer(Trainer):
                 lengths=get_dataset_lengths(self.train_dataset),
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
+        if self.args.curriculum_sampling:
+            return SequentialSampler(self.train_dataset)
         return super()._get_train_sampler()

     def _get_eval_sampler(
```
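Usage sketch for the `curriculum_sampling` flag wired in above: when the flag is set, the trainer swaps its usual random sampler for a `SequentialSampler`, so a dataset pre-sorted from easy to hard is consumed in order. The dataset path below is a placeholder, not from this diff.

```yaml
datasets:
  - path: ./data/easy_to_hard.jsonl   # hypothetical dataset, pre-sorted easy -> hard
    type: alpaca
curriculum_sampling: true   # walk the dataset sequentially instead of shuffling
```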
```diff
@@ -810,6 +818,14 @@ class AxolotlDPOTrainer(DPOTrainer):
         return res


+class AxolotlORPOTrainer(ORPOTrainer):
+    """
+    Extend the base ORPOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "orpo"]
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
```
```diff
@@ -873,6 +889,14 @@ class TrainerBuilderBase(abc.ABC):
             callbacks.append(
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
             )
+        if self.cfg.use_mlflow and is_mlflow_available():
+            from axolotl.utils.callbacks.mlflow_ import (
+                SaveAxolotlConfigtoMlflowCallback,
+            )
+
+            callbacks.append(
+                SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
+            )

         return callbacks
```
```diff
@@ -918,22 +942,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ):
             callbacks.append(SaveBetterTransformerModelCallback())

-        if self.cfg.use_wandb:
-            callbacks.append(
-                SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
-            )
-        if self.cfg.use_mlflow and is_mlflow_available():
-            from axolotl.utils.callbacks.mlflow_ import (
-                SaveAxolotlConfigtoMlflowCallback,
-            )
-
-            callbacks.append(
-                SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
-            )
-
         if self.cfg.loss_watchdog_threshold is not None:
             callbacks.append(LossWatchDogCallback(self.cfg))

+        callbacks.append(SaveModelOnTrainEndCallback())
+
         return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
```
```diff
@@ -1058,9 +1071,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.save_safetensors is not None:
             training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors

-        if self.cfg.save_only_model is not None:
-            training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
-
         if self.cfg.sample_packing_eff_est:
             training_arguments_kwargs[
                 "sample_packing_efficiency"
```
```diff
@@ -1191,6 +1201,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             False if self.cfg.ddp else None
         )
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
+        training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
         report_to = None
         if self.cfg.use_wandb:
             report_to = "wandb"
```
```diff
@@ -1411,13 +1422,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )


-class HFDPOTrainerBuilder(TrainerBuilderBase):
+class HFRLTrainerBuilder(TrainerBuilderBase):
     """
     Trainer factory class for DPO Trainer
     """

     def get_callbacks(self):
         callbacks = super().get_callbacks()
+        callbacks.append(SaveModelOnTrainEndCallback())
+
         return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
```
```diff
@@ -1453,6 +1466,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             training_args_kwargs["eval_steps"] = self.cfg.eval_steps
         else:
             training_args_kwargs["evaluation_strategy"] = "no"
+
         if self.cfg.bf16 or self.cfg.bfloat16:
             training_args_kwargs["bf16"] = True
```
@@ -1504,7 +1518,19 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
training_args = TrainingArguments(
|
if self.cfg.orpo_alpha:
|
||||||
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
|
|
||||||
|
training_args_cls = TrainingArguments
|
||||||
|
if self.cfg.rl == "orpo":
|
||||||
|
training_args_cls = ORPOConfig
|
||||||
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
|
||||||
|
training_args_cls = DPOConfig
|
||||||
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
|
training_args = training_args_cls(
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
max_steps=self.cfg.max_steps or total_num_steps,
|
max_steps=self.cfg.max_steps or total_num_steps,
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||||
@@ -1529,6 +1555,8 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
elif self.cfg.rl == "kto_pair":
|
elif self.cfg.rl == "kto_pair":
|
||||||
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
||||||
|
elif self.cfg.rl == "sppo_hard":
|
||||||
|
dpo_trainer_kwargs["loss_type"] = "sppo_hard"
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
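Note on the `sppo_hard` wiring above: in trl's DPOTrainer this loss type is a squared-error objective on the policy/reference log-ratios, pushing the chosen ratio toward +1/(2β) and the rejected ratio toward −1/(2β). A minimal sketch of that objective (illustrative names, not code from this diff):

    import torch

    def sppo_hard_loss(chosen_logratio, rejected_logratio, beta=0.1):
        # each logratio = policy log-probs minus frozen-reference log-probs
        return (chosen_logratio - 0.5 / beta) ** 2 + (rejected_logratio + 0.5 / beta) ** 2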
@@ -1537,20 +1565,34 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        dpo_trainer = AxolotlDPOTrainer(
-            self.model,
-            self.model_ref,
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
+            trainer_cls = AxolotlDPOTrainer
+            dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
+            trainer_cls_args = [self.model, self.model_ref]
+
+            # these aren't used for the ORPO trainer
+            dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["max_target_length"] = None
+            dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["generate_during_eval"] = True
+            if self.cfg.rl == "dpo":
+                dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+        elif self.cfg.rl == "orpo":
+            trainer_cls = AxolotlORPOTrainer
+            trainer_cls_args = [self.model]
+        else:
+            raise ValueError(f"Unsupported RL: {self.cfg.rl}")
+        dpo_trainer = trainer_cls(
+            *trainer_cls_args,
             args=training_args,
-            beta=self.cfg.dpo_beta or 0.1,
             train_dataset=self.train_dataset,
             tokenizer=self.tokenizer,
-            max_length=self.cfg.sequence_len,
-            max_target_length=None,
-            max_prompt_length=self.cfg.sequence_len,
-            generate_during_eval=True,
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
+        if self.cfg.fsdp:
+            ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
+
         dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
         for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
             dpo_trainer.add_callback(callback)
@@ -123,6 +123,14 @@ def get_turns(  # pylint: disable=too-many-return-statements
             else:
                 yield role, ""
         return
+    if self.sep_style == SeparatorStyle.GEMMA:
+        if self.system_message:
+            raise ValueError("Gemma chat template does not support system messages")
+        for i, (role, message) in enumerate(self.messages):
+            prefix = "<bos>" if i == 0 else ""
+            message_str = message if message else ""
+            yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
+        return
     if self.sep_style == SeparatorStyle.CHATGLM:
         # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
         # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
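Note: the GEMMA branch above renders turns with Gemma's control tokens. A standalone illustration of the same yield logic:

    messages = [("user", "Hello!"), ("model", "Hi, how can I help?")]

    def gemma_turns(messages):
        for i, (role, message) in enumerate(messages):
            prefix = "<bos>" if i == 0 else ""
            yield prefix + "<start_of_turn>" + role + "\n", (message or "") + "<end_of_turn>\n"

    for role_part, msg_part in gemma_turns(messages):
        print(role_part + msg_part, end="")
    # <bos><start_of_turn>user
    # Hello!<end_of_turn>
    # <start_of_turn>model
    # Hi, how can I help?<end_of_turn>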
@@ -516,24 +516,18 @@ def mistral_model_forward(
         past_key_value = past_key_values[idx] if past_key_values is not None else None
 
         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs)
-
-                return custom_forward
-
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(decoder_layer),
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                None,
-                cu_seqlens,
-                max_seqlen,
+            layer_outputs = (
+                self._gradient_checkpointing_func(  # pylint: disable=protected-access
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_value,
+                    output_attentions,
+                    None,
+                    cu_seqlens,
+                    max_seqlen,
+                )
             )
         else:
             layer_outputs = decoder_layer(
src/axolotl/prompt_strategies/dpo/mistral.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+"""
+DPO strategies for mistral instruct
+"""
+
+
+def prompt_pairs(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"
+        sample["chosen"] = f"{sample['chosen']}"
+        sample["rejected"] = f"{sample['rejected']}"
+        return sample
+
+    return transform_fn
+
+
+def argilla_chat(
+    cfg,
+    **kwargs,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for argilla/dpo-mix-7k conversations
+    """
+
+    def transform_fn(sample):
+        sample["prompt"] = f"[INST] {sample['chosen'][0]['content']} [/INST]"
+        sample["chosen"] = f"{sample['chosen'][1]['content']}</s>"
+        sample["rejected"] = f"{sample['rejected'][1]['content']}</s>"
+        return sample
+
+    return transform_fn
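Note: the new `prompt_pairs` strategy only wraps the raw prompt in Mistral's instruction tags; applied to one sample it does no more than:

    sample = {"prompt": "What is 2+2?", "chosen": "4", "rejected": "5"}
    sample["prompt"] = f"[INST]{sample['prompt']}[/INST]"  # same f-string as transform_fn
    assert sample["prompt"] == "[INST]What is 2+2?[/INST]"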
@@ -6,4 +6,4 @@ from functools import partial
 
 from ..base import load as load_base
 
-load = partial(load_base, module="axolotl.prompt_strategies.orpo")
+load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
@@ -78,6 +78,57 @@ class ORPODatasetParsingStrategy:
             )
         return MessageList(messages=messages)
 
+    def get_prompt(self, prompt) -> MessageList:
+        """Map the data to extract everything up to the last turn"""
+        total_msg_len = len(prompt["chosen"])
+        total_msg_turns, remainder = divmod(total_msg_len, 2)
+        assert remainder == 0, "invalid number of turns"
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        for i in range(total_msg_turns):
+            if "prompt" in prompt:
+                messages.append(
+                    Message(role="user", content=prompt["prompt"], label=False)
+                )
+            else:
+                messages.append(
+                    Message(
+                        role="user",
+                        content=prompt["chosen"][i * 2]["content"],
+                        label=False,
+                    )
+                )
+            if i < total_msg_turns - 1:
+                messages.append(
+                    Message(
+                        role="assistant",
+                        content=prompt["chosen"][i * 2 + 1]["content"],
+                        label=False,
+                    )
+                )
+
+        return MessageList(messages=messages)
+
+    def get_chosen(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["chosen"][-1]["content"], label=True
+            )
+        )
+        return res
+
+    def get_rejected(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["rejected"][-1]["content"], label=True
+            )
+        )
+        return res
+
+
 class ORPOTokenizingStrategy(PromptTokenizingStrategy):
     """
@@ -186,3 +237,36 @@ class ORPOPrompter(Prompter):
                 chat_template=self.chat_template,
                 tokenize=False,
             ), True
+
+
+def argilla(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    dataset_parser = ORPODatasetParsingStrategy()
+
+    chat_template_str = chat_templates(cfg.chat_template)
+
+    def transform_fn(sample, tokenizer=None):
+        res = {}
+
+        res["prompt"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
+            add_generation_prompt=True,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )
+        prompt_str_len = len(res["prompt"])
+        res["chosen"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+        res["rejected"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+
+        return res
+
+    return transform_fn
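Note: the `argilla` transform above recovers the chosen/rejected completions by rendering the full conversation with the chat template and slicing off the rendered prompt prefix. With an illustrative (made-up) template, the trick reduces to:

    prompt = "<|user|>\nHi\n<|assistant|>\n"            # rendered with add_generation_prompt=True
    full = "<|user|>\nHi\n<|assistant|>\nHello there!"  # full conversation, add_generation_prompt=False
    completion = full[len(prompt):]
    assert completion == "Hello there!"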
@@ -3,12 +3,14 @@
 import os
 import signal
 import sys
+import weakref
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Tuple, Union
 
 import torch
 import transformers.modelcard
+from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import Dataset
 from peft import PeftModel

@@ -81,6 +83,8 @@ def train(
     if cfg.adapter:
         msg += " and peft_config..."
     LOG.debug(msg)
+    # we wait until the last possible moment to setup Accelerator
+    Accelerator()
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
     model.generation_config.do_sample = True
 
@@ -124,14 +128,20 @@ def train(
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
 
-        def terminate_handler(_, __, model):
-            if cfg.flash_optimum and BetterTransformer:
-                model = BetterTransformer.reverse(model)
-            model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+        def terminate_handler(_, __, model_weakref):
+            if model_weakref() is not None:
+                _model = model_weakref()
+                if cfg.flash_optimum and BetterTransformer:
+                    _model = BetterTransformer.reverse(_model)
+                _model.save_pretrained(
+                    cfg.output_dir, safe_serialization=safe_serialization
+                )
             sys.exit(0)
 
+        _model_weakref = weakref.ref(model)
         signal.signal(
-            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
+            signal.SIGINT,
+            lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
         )
 
     badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
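Note: routing the model through `weakref.ref` keeps the SIGINT handler's closure from holding a strong reference, so the handler only saves if the model object is still alive. The core pattern:

    import weakref

    class Dummy:
        def save_pretrained(self, path):
            print(f"saved to {path}")

    model = Dummy()
    model_ref = weakref.ref(model)
    if model_ref() is not None:          # returns None once the model is garbage collected
        model_ref().save_pretrained("out/")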
@@ -773,3 +773,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
         except (FileNotFoundError, ConnectionError) as err:
             LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
         return control
+
+
+class SaveModelOnTrainEndCallback(TrainerCallback):
+    """Callback to save model on train end"""
+
+    def on_train_end(  # pylint: disable=unused-argument
+        self, args, state, control, **kwargs
+    ):
+        control.should_save = True
+        return control
@@ -383,9 +383,9 @@ def legacy_validate_config(cfg):
             "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
         )
 
-    if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
+    if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
         LOG.warning(
-            "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
+            "hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
         )
 
     if cfg.gptq and cfg.revision_of_model:

@@ -448,10 +448,14 @@ def legacy_validate_config(cfg):
         raise ValueError(
             "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
         )
-    if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
+    if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
         raise ValueError(
             "save_strategy must be empty or set to `steps` when used with saves_per_epoch."
         )
+    if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
+        raise ValueError(
+            "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
+        )
     if cfg.evals_per_epoch and cfg.eval_steps:
         raise ValueError(
             "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."

@@ -464,11 +468,6 @@ def legacy_validate_config(cfg):
         raise ValueError(
             "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
         )
-    if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
-        raise ValueError(
-            "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
-        )
-
     if (
         cfg.evaluation_strategy
         and cfg.eval_steps
@@ -133,6 +133,7 @@ class RLType(str, Enum):
     ipo = "ipo"  # pylint: disable=invalid-name
     kto_pair = "kto_pair"  # pylint: disable=invalid-name
     orpo = "orpo"  # pylint: disable=invalid-name
+    sppo_hard = "sppo_hard"  # pylint: disable=invalid-name
 
 
 class ChatTemplate(str, Enum):

@@ -259,6 +260,7 @@ class ModelInputConfig(BaseModel):
 
     base_model: str
     base_model_config: Optional[str] = None
+    cls_model_config: Optional[str] = None
     tokenizer_config: Optional[str] = None
     tokenizer_use_fast: Optional[bool] = None
     tokenizer_legacy: Optional[bool] = None

@@ -355,7 +357,6 @@ class ModelOutputConfig(BaseModel):
     hub_model_id: Optional[str] = None
     hub_strategy: Optional[str] = None
     save_safetensors: Optional[bool] = None
-    save_only_model: Optional[bool] = None
 
 
 class MLFlowConfig(BaseModel):
@@ -409,6 +410,17 @@ class WandbConfig(BaseModel):
         return data
 
 
+class GradioConfig(BaseModel):
+    """Gradio configuration subset"""
+
+    gradio_title: Optional[str] = None
+    gradio_share: Optional[bool] = None
+    gradio_server_name: Optional[str] = None
+    gradio_server_port: Optional[int] = None
+    gradio_max_new_tokens: Optional[int] = None
+    gradio_temperature: Optional[float] = None
+
+
 # pylint: disable=too-many-public-methods,too-many-ancestors
 class AxolotlInputConfig(
     ModelInputConfig,
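Note: because `GradioConfig` is mixed into `AxolotlInputConfig` in the next hunk, the new keys validate like any other top-level config field. For instance, using the class added above:

    cfg = GradioConfig(gradio_title="axolotl demo", gradio_server_port=7860)
    assert cfg.gradio_share is None   # unset optional fields default to None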
@@ -419,6 +431,7 @@ class AxolotlInputConfig(
     WandbConfig,
     MLFlowConfig,
     LISAConfig,
+    GradioConfig,
     RemappedParameters,
     DeprecatedParameters,
     BaseModel,
@@ -479,6 +492,7 @@ class AxolotlInputConfig(
     eval_causal_lm_metrics: Optional[List[str]] = None
     do_bench_eval: Optional[bool] = None
     bench_dataset: Optional[str] = None
+    bench_split: Optional[str] = None
     metric_for_best_model: Optional[str] = None
     greater_is_better: Optional[bool] = None
 
@@ -494,15 +508,25 @@ class AxolotlInputConfig(
 
     # torch_dtype: Optional[torch.dtype]
 
-    gradient_checkpointing: Optional[bool] = Field(default=False)
+    gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
+        default=False
+    )
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
 
     unfrozen_parameters: Optional[List[str]] = None
 
     sequence_len: int = Field(default=512)
+    min_sample_len: Optional[int] = None
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
+    curriculum_sampling: Optional[bool] = None
 
+    # for PoSE context length extension
+    use_pose: Optional[bool] = None
+    pose_split_on_token_ids: Optional[List[int]] = None
+    pose_max_context_len: Optional[int] = None
+    pose_num_chunks: Optional[int] = None
+
     pretrain_multipack_buffer_size: Optional[int] = 10_000
     pretrain_multipack_attn: Optional[bool] = Field(
@@ -551,6 +575,7 @@ class AxolotlInputConfig(
     neftune_noise_alpha: Optional[float] = None
 
     orpo_alpha: Optional[float] = None
+    dpo_beta: Optional[float] = None
 
     max_memory: Optional[
         Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]

@@ -769,11 +794,11 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_push_save(cls, data):
-        if data.get("hub_model_id") and not (
-            data.get("save_steps") or data.get("saves_per_epoch")
+        if data.get("hub_model_id") and (
+            data.get("save_strategy") not in ["steps", "epoch", None]
         ):
             LOG.warning(
-                "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
+                "hub_model_id is set without any models being saved. To save a model, set save_strategy."
             )
         return data
 
@@ -972,9 +997,16 @@ class AxolotlInputConfig(
 
     @model_validator(mode="before")
     @classmethod
-    def check_fsdp_w_8bit_optimizer(cls, data):
-        if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
-            raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
+    def check_fsdp_offload_w_8bit_optimizer(cls, data):
+        if (
+            data.get("fsdp")
+            and "8bit" in data.get("optimizer", "")
+            and data.get("fsdp_config")
+            and data["fsdp_config"].get("fsdp_offload_params")
+        ):
+            raise ValueError(
+                f"FSDP Offload not compatible with {data.get('optimizer')}"
+            )
         return data
 
     @model_validator(mode="before")
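Note: the validator above is now narrower; an 8-bit optimizer is only rejected when FSDP parameter offload is also enabled. An illustrative config shape that trips it:

    # rejected: 8-bit optimizer states can't live on CPU under FSDP offload
    data = {
        "fsdp": ["full_shard"],
        "optimizer": "adamw_bnb_8bit",
        "fsdp_config": {"fsdp_offload_params": True},
    }
    # -> ValueError: FSDP Offload not compatible with adamw_bnb_8bit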
@@ -1,11 +1,11 @@
 """
 Data processing modules
 """
-from axolotl.utils.data.dpo import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.pretraining import (  # noqa: F401
     encode_pretraining,
     wrap_pretraining_dataset,
 )
+from axolotl.utils.data.rl import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.sft import (  # noqa: F401
     get_dataset_wrapper,
     load_prepare_datasets,

@@ -1,17 +1,20 @@
 """data handling specific to DPO"""
+import inspect
 import logging
+from functools import partial
 from pathlib import Path
 from typing import Any, List
 
 import yaml
-from datasets import concatenate_datasets, load_dataset, load_from_disk
+from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.prompt_strategies.orpo import load as load_orpo
 from axolotl.utils.data.utils import md5
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.models import load_tokenizer
 
 LOG = logging.getLogger("axolotl")
 
@@ -72,16 +75,29 @@ def load_prepare_dpo_datasets(cfg):
             )
             split_datasets.insert(i, ds)
 
+    tokenizer = None
     for i, data_set in enumerate(split_datasets):
         _type = dataset_cfgs[i]["type"]
         if _type:
             if isinstance(_type, DictDefault):
                 _type = "user_defined.default"
-            ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
-            split_datasets[i] = data_set.map(
+            if _cfg.rl == "orpo":
+                ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
+            else:
+                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
+            sig = inspect.signature(ds_transform_fn)
+            if "tokenizer" in sig.parameters:
+                if not tokenizer:
+                    tokenizer = load_tokenizer(_cfg)
+                ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
+
+            data_set = data_set.map(
                 ds_transform_fn,
                 desc="Mapping RL Dataset",
             )
+            if isinstance(data_set, DatasetDict):
+                data_set = data_set["train"]
+            split_datasets[i] = data_set
         else:
             # If no `type` is provided, assume the dataset is already in the expected format with
             # "prompt", "chosen" and "rejected" already preprocessed
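Note: transforms opt in to receiving a tokenizer simply by declaring it in their signature; the gate is plain `inspect.signature`, e.g.:

    import inspect
    from functools import partial

    def transform_fn(sample, tokenizer=None):   # strategies that need the tokenizer declare it
        return sample

    if "tokenizer" in inspect.signature(transform_fn).parameters:
        transform_fn = partial(transform_fn, tokenizer=object())  # stand-in tokenizer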
@@ -421,7 +421,7 @@ def load_tokenized_prepared_datasets(
 
     if cfg.local_rank == 0:
         LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
-        dataset.save_to_disk(prepared_ds_path)
+        dataset.save_to_disk(str(prepared_ds_path))
         if cfg.push_dataset_to_hub:
             LOG.info(
                 f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
@@ -4,27 +4,25 @@ utility helpers for distributed checks
 import os
 import pickle  # nosec
 from contextlib import contextmanager
+from datetime import timedelta
 
 import torch
 import torch.distributed as dist
-from accelerate import Accelerator
+from accelerate import PartialState
 
-accelerate = None  # pylint: disable=invalid-name
+distributed_state = None  # pylint: disable=invalid-name
 
 
-def load_accelerate():
-    global accelerate  # pylint: disable=global-statement
-    accelerate = Accelerator()
-
-
 def is_distributed():
     """
     Check if distributed training is initialized.
     """
-    global accelerate  # pylint: disable=global-statement
-    if not accelerate:
-        accelerate = Accelerator()
-    return dist.is_available() and dist.is_initialized()
+    global distributed_state  # pylint: disable=global-statement
+    if not distributed_state:
+        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
+        distributed_state = PartialState(timeout=timedelta(seconds=timeout))
+
+    return distributed_state.use_distributed and distributed_state.initialized
 
 
 def barrier():
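Note: `is_distributed` now lazily builds a single `accelerate.PartialState`, and the collective timeout becomes tunable through an environment variable (the code above defaults to 1800 seconds):

    import os
    # Allow long rank-0 dataset preprocessing without NCCL watchdog kills
    os.environ["AXOLOTL_NCCL_TIMEOUT"] = "7200"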
src/axolotl/utils/gradient_checkpointing/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+"""custom checkpointing utils"""
+from axolotl.utils.gradient_checkpointing.unsloth import (
+    Unsloth_Offloaded_Gradient_Checkpointer,
+)
+
+
+def hf_grad_checkpoint_unsloth_wrapper(
+    decoder_layer, *args, use_reentrant=None
+):  # pylint: disable=unused-argument
+    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+        decoder_layer.__self__,
+        *args,
+    )
src/axolotl/utils/gradient_checkpointing/unsloth.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+"""Unsloth checkpointing"""
+
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
+    torch.autograd.Function
+):
+    """
+    Saves VRAM by smartly offloading to RAM.
+    Tiny hit to performance, since we mask the movement via non blocking calls.
+    """
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd
+    def forward(ctx, forward_function, hidden_states, *args):
+        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+        with torch.no_grad():
+            output = forward_function(hidden_states, *args)
+        ctx.save_for_backward(saved_hidden_states)
+        ctx.forward_function = forward_function
+        ctx.args = args
+        return output
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx, dY):
+        (hidden_states,) = ctx.saved_tensors
+        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+        hidden_states.requires_grad = True
+        with torch.enable_grad():
+            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+        torch.autograd.backward(output, dY)
+        return (
+            None,
+            hidden_states.grad,
+        ) + (
+            None,
+        ) * len(ctx.args)
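Note: the autograd Function above trades a host/device copy for activation memory: the layer input goes to CPU during forward and comes back for the recompute in backward. A toy usage, assuming a CUDA device and any callable that returns a 1-tuple the way a decoder layer does:

    import torch

    def layer_forward(hidden_states, weight):
        # stand-in for a decoder layer; must return a tuple, as the
        # checkpointer's backward unpacks a single-element tuple
        return (hidden_states @ weight,)

    x = torch.randn(4, 8, device="cuda", requires_grad=True)
    w = torch.randn(8, 8, device="cuda")
    (out,) = Unsloth_Offloaded_Gradient_Checkpointer.apply(layer_forward, x, w)
    out.sum().backward()  # x is reloaded from its CPU copy for the recompute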
src/axolotl/utils/model_shard_quant.py (new file, 259 lines)
@@ -0,0 +1,259 @@
+"""
+module to handle loading model on cpu/meta device for FSDP
+"""
+import os
+import time
+from typing import List, Optional, Type, Union
+
+import safetensors
+import torch
+from accelerate import init_empty_weights
+from bitsandbytes.nn import Linear4bit, Params4bit
+from fastcore.parallel import parallel
+from torch import Tensor, nn
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
+
+
+def _replace_linear(
+    model: nn.Module,
+    linear_replacement: Type[nn.Module],
+    quant_config: Union[dict, None] = None,
+    skip_modules=None,
+    **kwargs,
+):
+    """
+    Replace linear modules with a new Linear module.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        linear_replacement (`torch.nn.Module`):
+            The linear module that replaces the old one. Only expects standard arguments.
+            If other arguments need to be passed, use a lambda.
+        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
+            List of modules names not to convert. Defaults to `lm_head`.
+    """
+    if skip_modules is None:
+        skip_modules = ["lm_head"]
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            _replace_linear(
+                module, linear_replacement, quant_config, skip_modules, **kwargs
+            )
+
+        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
+            if issubclass(linear_replacement, Linear4bit):
+                model._modules[  # pylint: disable=protected-access
+                    name
+                ] = linear_replacement(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    **kwargs,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported linear replacement: {type(linear_replacement)}"
+                )
+    return model
+
+
+def load_and_quantize(
+    module: nn.Module,
+    name: str,
+    value: Tensor,
+    device: torch.device = None,
+    dtype: torch.dtype = None,
+    skip_names: Optional[List[str]] = None,
+    to_cpu: bool = False,
+    to_meta: bool = False,
+    verbose: bool = False,
+    quant_method: str = "bnb",
+):
+    """
+    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
+
+    Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.
+    """
+
+    if not skip_names:
+        skip_names = []
+
+    def place_on_device(value):
+        if to_meta:
+            device = "meta"
+        elif to_cpu:
+            device = "cpu"
+        return value.to(device=device, dtype=dtype)
+
+    if any(skip_name in name for skip_name in skip_names):
+        if verbose:
+            print(f"Skipping {name} because it is in skip_names")
+        return
+
+    module_key, _, value_key = name.rpartition(".")
+    try:
+        submodule = module.get_submodule(module_key)
+    except AttributeError as exc:
+        print(f"Module {module_key} not found:\n{exc}")
+        return
+
+    try:
+        if quant_method == "bnb":
+            param = submodule.get_parameter(value_key)
+            if isinstance(param, Params4bit):
+                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
+                # shape as the quantized Params4bit with an initialized quant_state. However,
+                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
+                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
+                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
+                value = type(param)(
+                    value.to(device=device, dtype=dtype).data, **param.__dict__
+                ).cuda(device)
+                if to_meta:
+                    value = type(param)(value.data.to("meta"), **value.__dict__)
+                elif to_cpu:
+                    value = type(param)(value.data.to("cpu"), **value.__dict__)
+            else:
+                value = type(param)(place_on_device(value).data)
+
+    except AttributeError:
+        # it's a buffer
+        value = place_on_device(value)
+
+    setattr(submodule, value_key, value)
+
+
+def n_loading_workers(quant_method: str, param_count: float):
+    devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
+    left = int(os.cpu_count() / torch.cuda.device_count())
+    model_params_b = 70
+    right = int(
+        (4 if quant_method == "hqq" else 8)
+        * (devprops.total_memory / 1e9 / 40)
+        * (model_params_b / (param_count / 1e9))
+    )
+    return min(left, right)
+
+
+def load_sharded_model(
+    model_name,
+    model_config,
+    cfg,
+    torch_dtype=torch.bfloat16,
+    low_memory=True,
+):
+    if (low_memory and cfg.local_rank == 0) or not low_memory:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            use_cache=False,
+            torch_dtype=torch.float32,
+            _attn_implementation=model_config._attn_implementation,  # pylint: disable=protected-access
+            trust_remote_code=cfg.trust_remote_code,
+        )
+        dtype = torch_dtype if not cfg.float32 else None
+        model.to(dtype=dtype, device="cpu" if low_memory else cfg.local_rank)
+    else:
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(
+                model_config,
+                torch_dtype=torch_dtype,
+                trust_remote_code=cfg.trust_remote_code,
+            )
+    return model
+
+
+def load_sharded_model_quant(
+    model_name,
+    model_config,
+    cfg,
+    compute_dtype=torch.bfloat16,
+    quant_storage=torch.float32,
+    low_memory=True,
+    verbose=False,
+    loading_workers=2,
+):
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(
+            model_config,
+            trust_remote_code=cfg.trust_remote_code,
+        )
+    if hasattr(model, "transformer"):
+        model.transformer = _replace_linear(
+            model.transformer,
+            Linear4bit,
+            compute_dtype=compute_dtype,
+            quant_type="nf4",
+            quant_storage=quant_storage,
+        )
+    else:
+        # this is the more common case with HF transformers
+        model.model = _replace_linear(
+            model.model,
+            Linear4bit,
+            compute_dtype=compute_dtype,
+            quant_type="nf4",
+            quant_storage=quant_storage,
+        )
+    model.is_loaded_in_4bit = True
+
+    # Grab the safetensors files that hold the weights
+    try:
+        idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
+        files, _ = hub.get_checkpoint_shard_files(model_name, idx)
+    except OSError:
+        try:
+            # This means the model doesn't have a model.safetensors.index.json because it is not sharded
+            files = []
+            files.append(hub.cached_file(model_name, SAFE_WEIGHTS_NAME))
+        except OSError as exc:
+            # This means the model probably doesn't have a safetensors file
+            raise exc
+
+    # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
+    # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
+    def load_and_quantize_parallel(name_param, model, **kwargs):
+        name, param = name_param
+        load_and_quantize(model, name, param, **kwargs)
+
+    quant_method = "bnb"
+    param_count = sum((p.numel() for n, p in model.named_parameters()))
+
+    n_workers = (
+        n_loading_workers(quant_method, param_count)
+        if loading_workers == -1
+        else loading_workers
+    )
+    if cfg.local_rank == 0 and verbose:
+        print(f"Using n_workers: {n_workers} for loading")
+
+    start = time.time()
+    for filename in tqdm(
+        files,
+        desc="Loading & Quantizing Model Shards",
+        disable=cfg.local_rank != 0,
+        position=0,
+    ):
+        weights = safetensors.torch.load_file(filename)
+        parallel(
+            load_and_quantize_parallel,
+            iter(weights.items()),
+            n_workers=n_workers,
+            threadpool=True,
+            model=model,
+            dtype=quant_storage,
+            device=cfg.local_rank,
+            skip_names=[],
+            to_cpu=(low_memory and cfg.local_rank == 0),
+            to_meta=(low_memory and cfg.local_rank != 0),
+            verbose=verbose,
+            quant_method=quant_method,
+        )
+
+    if cfg.local_rank == 0 and verbose:
+        print(f"Loaded model weights in {time.time()-start:.3f} seconds")
+    # cleanup any extra memory usage from parallel loading
+    torch.cuda.empty_cache()
+
+    return model
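Note: `n_loading_workers` above caps the loading thread pool by both the CPU share per GPU and a VRAM-derived bound scaled from a 70B model on a 40 GB card. Worked numbers for an assumed 8-GPU, 64-CPU node with 80 GB cards loading an 8B-parameter model:

    left = int(64 / 8)                        # CPU share per GPU -> 8
    right = int(8 * (80 / 40) * (70 / 8))     # VRAM-scaled bound for "bnb" -> 140
    assert min(left, right) == 8              # CPU count is the binding constraint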
@@ -11,6 +11,7 @@ import addict
 import bitsandbytes as bnb
 import torch
 import transformers
+import transformers.modeling_utils
 from accelerate import init_empty_weights
 from bitsandbytes.nn import Params4bit
 from peft import (

@@ -44,11 +45,37 @@ from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import zero_only
+from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
+from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
 
 LOG = logging.getLogger("axolotl")
 
 
+# copied from accelerator.FullyShardedDataParallelPlugin
+def get_module_class_from_name(module, name):
+    """
+    Gets a class from a module by its name.
+
+    Args:
+        module (`torch.nn.Module`): The module to get the class from.
+        name (`str`): The name of the class.
+    """
+    modules_children = list(module.children())
+    if module.__class__.__name__ == name:
+        return module.__class__
+
+    if len(modules_children) == 0:
+        return None
+
+    for child_module in modules_children:
+        module_class = get_module_class_from_name(child_module, name)
+        if module_class is not None:
+            return module_class
+
+    return None
+
+
 def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
     quant_config_exists = (
         hasattr(model_config, "quantization_config")

@@ -285,6 +312,9 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
 
+    if cfg.gradient_checkpointing == "unsloth":
+        transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
+
     if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
         if cfg.flash_attention:
             from axolotl.monkeypatch.btlm_attn_hijack_flash import (

@@ -459,7 +489,7 @@ def load_model(
             "bnb_4bit_quant_type": "nf4",
             "bnb_4bit_quant_storage": torch.bfloat16,
         }
-        if not cfg.deepspeed:
+        if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
             # for some reason, this causes the loss to be off by an order of magnitude
             # but deepspeed needs this still in bfloat16
             bnb_config["bnb_4bit_quant_storage"] = torch.float32

@@ -470,6 +500,13 @@ def load_model(
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
             **bnb_config,
         )
+    elif cfg.adapter == "lora" and cfg.load_in_8bit:
+        bnb_config = {
+            "load_in_8bit": True,
+        }
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
+            **bnb_config,
+        )
 
     if cfg.load_in_8bit and cfg.adapter is not None:
         model_kwargs["load_in_8bit"] = True

@@ -517,7 +554,31 @@ def load_model(
     qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
 
     try:
+        skip_move_to_device = False
         if (
+            cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        ) and not qlora_fsdp:
+            model = load_sharded_model(
+                base_model,
+                model_config,
+                cfg,
+                torch_dtype=cfg.torch_dtype,
+            )
+            skip_move_to_device = True
+        elif (
+            qlora_fsdp
+            and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+            and cfg.model_config_type == "dbrx"
+        ):
+            quant_storage = cfg.torch_dtype
+            model = load_sharded_model_quant(
+                base_model,
+                model_config,
+                cfg,
+                quant_storage=quant_storage,
+            )
+            skip_move_to_device = True
+        elif (
             model_config.model_type == "llama"
             and not cfg.trust_remote_code
             and not cfg.gptq

@@ -597,6 +658,11 @@ def load_model(
                 **model_kwargs,
             )
         else:
+            if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+                skip_move_to_device = True
+                if "device_map" in model_kwargs:
+                    del model_kwargs["device_map"]
+
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 config=model_config,

@@ -670,13 +736,17 @@ def load_model(
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
 
-    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
+    if is_deepspeed_zero3_enabled():
         from deepspeed.utils import (  # pylint: disable=no-name-in-module
             set_z3_leaf_modules,
         )
-        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
-        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+        if cfg.model_config_type == "mixtral":
+            moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
+            set_z3_leaf_modules(model, [moe_block])
+        elif cfg.model_config_type == "dbrx":
+            moe_block = get_module_class_from_name(model, "DbrxFFN")
+            set_z3_leaf_modules(model, [moe_block])
 
     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
         # Qwen doesn't play nicely with LoRA if this is enabled

@@ -686,7 +756,8 @@ def load_model(
     if cfg.adapter == "lora" and loftq_bits:
         skip_prepare_model_for_kbit_training = True
 
-    if qlora_fsdp:
+    if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
+        # make sure everything is in the same dtype
         skip_prepare_model_for_kbit_training = True
 
     if cfg.adapter in ["lora", "qlora"]:

@@ -718,7 +789,11 @@ def load_model(
     if not reference_model or cfg.lora_model_dir:
         # if we're not loading the reference model, then we're loading the model for training
         # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
-        if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
+        if (
+            cfg.adapter
+            and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]
+            and not cfg.merge_lora
+        ):
             _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
         else:
             model, lora_config = load_adapter(model, cfg, cfg.adapter)

@@ -727,7 +802,7 @@ def load_model(
         cfg.ddp
         and not load_in_8bit
        and not (cfg.rl and cfg.load_in_4bit)
-        and not qlora_fsdp
+        and not skip_move_to_device
     ):
         # TODO revalidate this conditional
         model.to(f"cuda:{cfg.local_rank}")

@@ -883,7 +958,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
 
     rank = int(os.environ.get("LOCAL_RANK", 0))
 
-    if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
+    if (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
         setup_quantized_meta_for_peft(model)
 
     if cfg.lora_model_dir:

@@ -908,7 +988,22 @@ def load_lora(model, cfg, inference=False, config_only=False):
             LOG.warning(
                 "Exception caught during model.print_trainable_parameters(): %s", exc
             )
-    elif cfg.fsdp and cfg.adapter == "qlora":
+    elif (
+        cfg.fsdp
+        and cfg.adapter
+        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+        and rank != 0
+    ):
         setup_quantized_peft_meta_for_training(model)
 
     return model, lora_config
+
+
+def ensure_dtype(model, dtype=torch.bfloat16):
+    for name, module in model.named_modules():
+        try:
+            if module.weight.dtype != dtype:
+                print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
+                module.to(dtype)
+        except AttributeError:
+            pass
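Note: `ensure_dtype` (added above) force-casts any module whose weights drifted from the target dtype; the FSDP branch of the RL trainer builder relies on it. Usage with the helper as defined above:

    import torch
    from torch import nn

    model = nn.Linear(8, 8)                    # float32 by default
    ensure_dtype(model, dtype=torch.bfloat16)  # prints the cast and converts in place
    assert model.weight.dtype == torch.bfloat16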
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Module for tokenization utilities"""
|
"""Module for tokenization utilities"""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
@@ -10,10 +9,19 @@ from termcolor import colored
 LOG = logging.getLogger("axolotl")


-def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
+def check_dataset_labels(
+    dataset,
+    tokenizer,
+    num_examples=5,
+    text_only=False,
+    rl_mode=False,
+):
     # the dataset is already shuffled, so let's just check the first num_examples elements
     for idx in range(num_examples):
-        check_example_labels(dataset[idx], tokenizer, text_only=text_only)
+        if not rl_mode:
+            check_example_labels(dataset[idx], tokenizer, text_only=text_only)
+        else:
+            check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)


 def check_example_labels(example, tokenizer, text_only=False):
@@ -40,6 +48,53 @@ def check_example_labels(example, tokenizer, text_only=False):
     return " ".join(colored_tokens)


+def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
+    """Helper function to color tokens based on their type."""
+    colored_text = colored(decoded_token, color)
+    return (
+        colored_text
+        if text_only
+        else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
+    )
+
+
+def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
+    """Helper function to process and color tokens."""
+    colored_tokens = [
+        color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
+        for token in tokenizer.encode(tokens)
+    ]
+    return colored_tokens
+
+
+def check_rl_example_labels(example, tokenizer, text_only=False):
+    field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
+
+    input_tokens = example[field_prompt]
+    labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
+
+    # Process and color each type of token
+    colored_tokens = process_tokens_for_rl_debug(
+        input_tokens, "yellow", tokenizer, text_only
+    )
+    colored_chosens = process_tokens_for_rl_debug(
+        labels_chosen, "green", tokenizer, text_only
+    )
+    colored_rejecteds = process_tokens_for_rl_debug(
+        labels_rejected, "red", tokenizer, text_only
+    )
+
+    # Create a delimiter based on text_only flag
+    delimiter = "" if text_only else " "
+
+    # Logging information
+    LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
+    LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
+    LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
+
+    return delimiter.join(colored_tokens)
+
+
 GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
 GLAIVE_TO_SHAREGPT_ROLE = {
     "SYSTEM": "system",
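A sketch of how the new rl_mode path might be exercised, assuming a preference-style record with the "prompt"/"chosen"/"rejected" fields check_rl_example_labels expects; the tokenizer and strings are placeholders.

from transformers import AutoTokenizer

from axolotl.utils.tokenization import check_dataset_labels  # module path assumed

tokenizer = AutoTokenizer.from_pretrained("JackFram/llama-68m")
dataset = [
    {
        "prompt": "What is 2 + 2?",
        "chosen": "2 + 2 = 4.",
        "rejected": "2 + 2 = 5.",
    }
]
# Prompt tokens print yellow, chosen green, rejected red.
check_dataset_labels(dataset, tokenizer, num_examples=1, rl_mode=True)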
@@ -1,9 +1,10 @@
 """Module containing the Trainer class and related functions"""
 import math
 import os
+import random
 from contextlib import contextmanager
 from functools import partial
-from typing import List
+from typing import List, Optional

 import numpy as np
 import torch
@@ -13,7 +14,7 @@ from datasets import set_caching_enabled
 from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

@@ -98,17 +99,89 @@ def add_position_ids(sample):
     return sample


+def add_pose_position_ids(
+    sample,
+    max_context_len=32768,
+    split_on_token_ids: Optional[List[int]] = None,
+    chunks: int = 2,
+):
+    """
+    use the PoSE technique to extend the context length by randomly skipping
+    positions in the context. We only want to skip right before tokens in
+    the split_on_token_ids list. We should attempt to randomly distribute
+    the skips, but we don't need the final position_ids to be the full
+    context_len. There may be multiple turns in the context, so we want to
+    make sure we take into account the maximum possible number of skips
+    remaining in each sample.
+    """
+
+    input_ids = sample["input_ids"]
+    sample_len = len(input_ids)
+    max_skips = max_context_len - sample_len
+
+    if split_on_token_ids is None:
+        split_on_token_ids = []
+
+    if split_on_token_ids:
+        split_indices = [
+            i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
+        ]
+    else:
+        chunk_len = sample_len // chunks
+        split_indices = [i * chunk_len for i in range(1, chunks)]
+    split_indices.append(len(input_ids))  # make sure we go to the end of the sample
+    if split_indices[0] < 2:
+        # drop the first split index if it's too close to the beginning
+        split_indices = split_indices[1:]
+
+    position_ids = []
+    prev_index = 0
+    total_skips = 0
+
+    for split_index in split_indices:
+        num_skips = (
+            random.randint(0, max_skips)  # nosec B311
+            if prev_index != 0 and max_skips
+            else 0
+        )
+        max_skips -= num_skips
+        total_skips += num_skips
+
+        segment_position_ids = list(
+            range(prev_index + total_skips, split_index + total_skips)
+        )
+
+        position_ids.extend(segment_position_ids)
+        prev_index = split_index
+
+    sample["sequence_len"] = position_ids[-1]
+    position_ids = torch.tensor(position_ids)
+
+    sample["position_ids"] = position_ids
+    sample["length"] = len(position_ids)
+    assert len(position_ids) == len(input_ids)
+
+    return sample
+
+
 def add_length(sample):
     sample["length"] = len(sample["input_ids"])
     return sample


-def drop_long_seq(sample, sequence_len=2048):
-    return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
+def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
+    return (
+        len(sample["input_ids"]) <= sequence_len
+        and len(sample["input_ids"]) >= min_sequence_len
+    )


 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
-    drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
+    drop_long = partial(
+        drop_long_seq,
+        sequence_len=cfg.sequence_len,
+        min_sequence_len=cfg.min_sample_len or 2,
+    )
     with zero_first(is_main_process()):
         if cfg.is_preprocess:
             min_input_len = np.min(get_dataset_lengths(train_dataset))
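A toy walk-through of add_pose_position_ids under stated assumptions (10 dummy tokens, a 20-token target context, the default two chunks): the first segment keeps positions 0..4, then a random skip is injected before the second segment, so the position ids span a longer context than the sample itself while len(position_ids) still equals len(input_ids).

import random

from axolotl.utils.trainer import add_pose_position_ids  # module path assumed

random.seed(0)
sample = {"input_ids": list(range(100, 110))}  # 10 dummy token ids
out = add_pose_position_ids(sample, max_context_len=20)
print(out["position_ids"])  # e.g. tensor([0, 1, 2, 3, 4, 5+k, ..., 9+k]) for a skip k in [0, 10]
print(out["sequence_len"])  # the last (stretched) position id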
@@ -153,7 +226,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             desc="Group By Length",
         )

-    if cfg.sample_packing:
+    if cfg.use_pose:
+        pose_kwargs = {}
+        if cfg.pose_num_chunks is not None:
+            pose_kwargs["chunks"] = cfg.pose_num_chunks
+        pose_fn = partial(
+            add_pose_position_ids,
+            max_context_len=cfg.pose_max_context_len,
+            split_on_token_ids=cfg.pose_split_on_token_ids,
+            **pose_kwargs,
+        )
+        train_dataset = train_dataset.map(
+            pose_fn,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+            desc="Add position_id column (PoSE)",
+        )
+        train_dataset = train_dataset.sort("sequence_len")
+        if cfg.eval_sample_packing is not False:
+            if eval_dataset:
+                eval_dataset = eval_dataset.map(
+                    pose_fn,
+                    num_proc=cfg.dataset_processes,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Add position_id column (PoSE)",
+                )
+    elif cfg.sample_packing:
         train_dataset = train_dataset.map(
             add_position_ids,
             num_proc=cfg.dataset_processes,
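For reference, a minimal, hedged config sketch for this PoSE branch, written as the DictDefault mapping the tests in this diff use; the key names (use_pose, pose_max_context_len, pose_split_on_token_ids, pose_num_chunks, min_sample_len) come straight from the code above, while the values are illustrative only.

from axolotl.utils.dict import DictDefault

cfg_pose = DictDefault(
    {
        "sequence_len": 8192,
        "min_sample_len": 2,              # feeds drop_long_seq's new lower bound
        "use_pose": True,
        "pose_max_context_len": 65536,    # target extended context
        "pose_split_on_token_ids": [2],   # e.g. skip right before eos boundaries
        "pose_num_chunks": 2,             # optional; code falls back to 2 chunks
    }
)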
@@ -306,6 +404,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):

 def setup_fsdp_envs(cfg):
     os.environ["ACCELERATE_USE_FSDP"] = "true"
+    if cfg.fsdp_config.fsdp_activation_checkpointing:
+        os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
     if cfg.fsdp_config.fsdp_offload_params:
         os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
     if cfg.fsdp_config.fsdp_sync_module_states:
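A short sketch of the new environment hook, assuming the usual axolotl fsdp_config block shape: with activation checkpointing enabled, setup_fsdp_envs now also exports FSDP_ACTIVATION_CHECKPOINTING for accelerate alongside the existing flags.

from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import setup_fsdp_envs  # module path assumed

cfg_fsdp = DictDefault(
    {
        "fsdp": ["full_shard", "auto_wrap"],
        "fsdp_config": {
            "fsdp_activation_checkpointing": True,
            "fsdp_offload_params": False,
            "fsdp_sync_module_states": True,
        },
    }
)
# Sets ACCELERATE_USE_FSDP=true and FSDP_ACTIVATION_CHECKPOINTING=true.
setup_fsdp_envs(cfg_fsdp)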
@@ -338,8 +438,8 @@ def prepare_optim_env(cfg):


 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair"]:
-        trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard"]:
+        trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
     else:
@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder

 import pytest

-from axolotl.core.trainer_builder import HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFRLTrainerBuilder
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
     return load_model(cfg, tokenizer)


-class TestHFDPOTrainerBuilder:
+class TestHFRLTrainerBuilder:
     """
     TestCase class for DPO trainer builder
     """

     def test_build_training_arguments(self, cfg, model, tokenizer):
-        builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
+        builder = HFRLTrainerBuilder(cfg, model, tokenizer)
         training_arguments = builder.build_training_arguments(100)
         assert training_arguments.adam_beta1 == 0.998
         assert training_arguments.adam_beta2 == 0.9
@@ -30,7 +30,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                 "flash_attention": True,
                 "sample_packing": True,
                 "sequence_len": 2048,
@@ -74,7 +74,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                 "flash_attention": True,
                 "sample_packing": True,
                 "sequence_len": 2048,
@@ -22,7 +22,7 @@ class TestModelPatches(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                 "flash_attention": True,
                 "sample_packing": True,
                 "sequence_len": 2048,
@@ -158,3 +158,50 @@ class TestDPOLlamaLora(unittest.TestCase):

         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
+    @with_temp_dir
+    def test_orpo_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "orpo",
+                "orpo_alpha": 0.1,
+                "remove_unused_columns": False,
+                "chat_template": "chatml",
+                "datasets": [
+                    {
+                        "path": "argilla/ultrafeedback-binarized-preferences-cleaned",
+                        "type": "chat_template.argilla",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@@ -33,7 +33,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                 "flash_attention": True,
                 "sequence_len": 1024,
                 "load_in_4bit": True,
@@ -87,7 +87,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                 "flash_attention": False,
                 "sequence_len": 1024,
                 "load_in_4bit": True,
@@ -141,7 +141,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                 "flash_attention": True,
                 "sequence_len": 1024,
                 "adapter": "lora",
@@ -198,7 +198,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                 "flash_attention": False,
                 "sequence_len": 1024,
                 "adapter": "lora",
@@ -255,7 +255,7 @@ class TestMixtral(unittest.TestCase):
         cfg = DictDefault(
             {
                 "base_model": "hf-internal-testing/Mixtral-tiny",
-                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                 "flash_attention": True,
                 "sequence_len": 1024,
                 "val_set_size": 0.1,
@@ -27,7 +27,9 @@ def fixture_alpaca_dataset():
 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer():
     # pylint: disable=all
-    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+    tokenizer = AutoTokenizer.from_pretrained(
+        "casperhansen/mistral-7b-instruct-v0.1-awq"
+    )
     tokenizer.add_special_tokens(
         {
             "eos_token": AddedToken(
@@ -43,7 +43,9 @@ def fixture_sharegpt_dataset():

 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+    tokenizer = AutoTokenizer.from_pretrained(
+        "casperhansen/mistral-7b-instruct-v0.1-awq"
+    )
     tokenizer.add_tokens(
         [
             AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False),
@@ -96,7 +96,9 @@ def fixture_multi_role_dataset():

 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+    tokenizer = AutoTokenizer.from_pretrained(
+        "casperhansen/mistral-7b-instruct-v0.1-awq"
+    )
    tokenizer.add_special_tokens(
         {
             "eos_token": AddedToken(
@@ -110,7 +110,7 @@ class TestDatasetPreparation(unittest.TestCase):
         """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
-            self.dataset.save_to_disk(tmp_ds_name)
+            self.dataset.save_to_disk(str(tmp_ds_name))

             prepared_path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -454,7 +454,9 @@ class OrpoTokenizationTest(unittest.TestCase):

     def setUp(self) -> None:
         # pylint: disable=duplicate-code
-        tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+        tokenizer = LlamaTokenizer.from_pretrained(
+            "casperhansen/mistral-7b-instruct-v0.1-awq"
+        )
         tokenizer.add_special_tokens(
             {
                 "eos_token": AddedToken(
@@ -1067,17 +1067,51 @@ class TestValidation(BaseValidation):
         ):
             validate_config(cfg)

-    def test_hub_model_id_save_value_warns(self, minimal_cfg):
-        cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
+    def test_hub_model_id_save_value_warns_save_strategy_no(self, minimal_cfg):
+        cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg

         with self._caplog.at_level(logging.WARNING):
             validate_config(cfg)
-            assert (
-                "set without any models being saved" in self._caplog.records[0].message
-            )
+            assert len(self._caplog.records) == 1

-    def test_hub_model_id_save_value(self, minimal_cfg):
-        cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
+    def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
+        cfg = (
+            DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert len(self._caplog.records) == 1
+
+    def test_hub_model_id_save_value_steps(self, minimal_cfg):
+        cfg = (
+            DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
+            | minimal_cfg
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert len(self._caplog.records) == 0
+
+    def test_hub_model_id_save_value_epochs(self, minimal_cfg):
+        cfg = (
+            DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
+            | minimal_cfg
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert len(self._caplog.records) == 0
+
+    def test_hub_model_id_save_value_none(self, minimal_cfg):
+        cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert len(self._caplog.records) == 0
+
+    def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
+        cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg

         with self._caplog.at_level(logging.WARNING):
             validate_config(cfg)
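The new tests pin down the warning contract, roughly: warn when hub_model_id is set but the effective save strategy will never produce a checkpoint to push. A hedged sketch of that rule (a hypothetical helper, not the actual validate_config body):

import logging

LOG = logging.getLogger("axolotl")


def warn_hub_model_id_without_saves(cfg):
    # "steps", "epoch", and an unset strategy all save checkpoints;
    # "no" or any unrecognized value means nothing is ever pushed.
    if cfg.get("hub_model_id") and cfg.get("save_strategy") not in (None, "steps", "epoch"):
        LOG.warning("hub_model_id is set without any models being saved")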