Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
f8bb4185bc skip s2 attention test due to timeout 2024-04-08 18:33:33 -04:00
68 changed files with 249 additions and 2153 deletions

View File

@@ -32,11 +32,6 @@ jobs:
python_version: "3.11"
pytorch: 2.2.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3

View File

@@ -7,7 +7,6 @@ on:
- 'requirements.txt'
- '.github/workflows/*.yml'
- "*.md"
- "examples/**/*.y[a]?ml"
workflow_dispatch:
jobs:

View File

@@ -30,11 +30,6 @@ jobs:
python_version: "3.11"
pytorch: 2.2.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -91,11 +86,6 @@ jobs:
python_version: "3.11"
pytorch: 2.2.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -29,11 +29,6 @@ jobs:
python_version: "3.11"
pytorch: 2.2.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -91,11 +86,6 @@ jobs:
python_version: "3.11"
pytorch: 2.2.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

1
.gitignore vendored
View File

@@ -133,7 +133,6 @@ venv/
ENV/
env.bak/
venv.bak/
venv3.10/
# Spyder project settings
.spyderproject

View File

@@ -44,7 +44,6 @@ Features:
- Advanced Topics
- [Multipack](./docs/multipack.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [RLHF & DPO](./docs/rlhf.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
@@ -82,7 +81,6 @@ Features:
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
@@ -427,7 +425,7 @@ deepspeed: deepspeed_configs/zero1.json
```
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
```
##### FSDP

View File

@@ -1,6 +1,4 @@
{
"zero_force_ds_cpu_optimizer": false,
"zero_allow_untested_optimizer": true,
"zero_optimization": {
"stage": 3,
"offload_optimizer": {

View File

@@ -1,6 +1,4 @@
{
"zero_force_ds_cpu_optimizer": false,
"zero_allow_untested_optimizer": true,
"zero_optimization": {
"stage": 3,
"offload_param": {

View File

@@ -227,12 +227,6 @@ lora_modules_to_save:
lora_fan_in_fan_out: false
# LoRA+ hyperparameters
# For more details about the following options, see:
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
peft:
# Configuration options for loftq initialization for LoRA
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
@@ -274,7 +268,6 @@ torch_compile_backend: # Optional[str]
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
gradient_accumulation_steps: 1
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
micro_batch_size: 2
eval_batch_size:
num_epochs: 4
@@ -419,7 +412,6 @@ special_tokens:
# bos_token: "<s>"
# eos_token: "</s>"
# unk_token: "<unk>"
# pad_token: "[PAD]"
# Add extra tokens.
tokens:

View File

@@ -1,35 +0,0 @@
---
title: Dataset Preprocessing
description: How datasets are processed
---
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
the (dataset format)[../dataset-formats/] and prompt strategies to:
- parse the dataset based on the *dataset format*
- transform the dataset to how you would interact with the model based on the *prompt strategy*
- tokenize the dataset based on the configured model & tokenizer
- shuffle and merge multiple datasets together if using more than one
The processing of the datasets can happen one of two ways:
1. Before kicking off training by calling `python -m axolotl.cli.preprocess /path/to/your.yaml --debug`
2. When training is started
What are the benefits of pre-processing? When training interactively or for sweeps
(e.g. you are restarting the trainer often), processing the datasets can oftentimes be frustratingly
slow. Pre-processing will cache the tokenized/formatted datasets according to a hash of dependent
training parameters so that it will intelligently pull from its cache when possible.
The path of the cache is controlled by `dataset_prepared_path:` and is often left blank in example
YAMLs as this leads to a more robust solution that prevents unexpectedly reusing cached data.
If `dataset_prepared_path:` is left empty, when training, the processed dataset will be cached in a
default path of `./last_run_prepared/`, but will ignore anything already cached there. By explicitly
setting `dataset_prepared_path: ./last_run_prepared`, the trainer will use whatever pre-processed
data is in the cache.
What are the edge cases? Let's say you are writing a custom prompt strategy or using a user-defined
prompt template. Because the trainer cannot readily detect these changes, we cannot change the
calculated hash value for the pre-processed dataset. If you have `dataset_prepared_path: ...` set
and change your prompt templating logic, it may not pick up the changes you made and you will be
training over the old prompt.

View File

@@ -49,7 +49,7 @@ remove_unused_columns: false
chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: chat_template.argilla
type: orpo.chat_template
```
#### Using local dataset files

View File

@@ -1,81 +0,0 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
- q_proj # attn
- k_proj # attn
- v_proj # attn
- out_proj # attn
- layer # router
# - w1
# - w2
# - v1
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_activation_checkpointing: true

View File

@@ -1,81 +0,0 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
- q_proj # attn
- k_proj # attn
- v_proj # attn
- out_proj # attn
- layer # router
# - w1
# - w2
# - v1
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_activation_checkpointing: true

View File

@@ -1,26 +0,0 @@
# DBRX MoE
Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj` `out_proj` and `layer` Linear layers are trainable.
We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
where the Experts are fused as an `nn.Parameter` rather than a `nn.Linear` layer. However, the implementation
is still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers
results in the trainer hanging.
### FSDP
We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.
The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.
- 16-bit LoRA w/ FSDP
- ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
- ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
- ✅ 8-bit LoRA w/ FSDP
- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)
### Deepspeed
WIP

View File

@@ -1,56 +0,0 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
unfrozen_parameters:
- transformer.blocks.[0-7].
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
deepspeed: deepspeed_configs/zero3_bf16.json

View File

@@ -65,14 +65,12 @@ deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_state_dict_type: SHARDED_STATE_DICT
special_tokens:

View File

@@ -1,13 +0,0 @@
# Llama-3
https://llama.meta.com/llama3/
[8B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
- [Full Fine Tune](./fft-8b.yaml)
- Single GPU @ 48GB VRAM
- [LoRA](./lora-8b.yml)
- Single GPU @ 11GB VRAM
[70B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
- [QLORA+FSDP](./qlora-fsdp-70b.yaml)
- Dual GPU @ 21GB VRAM

View File

@@ -1,58 +0,0 @@
base_model: meta-llama/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,67 +0,0 @@
base_model: meta-llama/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,80 +0,0 @@
base_model: casperhansen/llama-3-70b-fp16
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer # PreTrainedTokenizerFast
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out/qlora-llama3-70b
adapter: qlora
lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,67 +0,0 @@
base_model: meta-llama/Meta-Llama-3-8B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: aaditya/alpaca_subset_1
type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -1,63 +0,0 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
- model.layers.4[4-9]+.block_sparse_moe.gate
- model.layers.4[4-9]+.block_sparse_moe.experts
- model.layers.5[0-5]+.block_sparse_moe.gate
- model.layers.5[0-5]+.block_sparse_moe.experts
model_config:
output_router_logits: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"

View File

@@ -1,82 +0,0 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: false
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:

View File

@@ -1,82 +0,0 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false
chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: chat_template.argilla
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./mistral-qlora-orpo-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,81 +0,0 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:

View File

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
@@ -47,7 +47,7 @@ train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
tf32: false
gradient_checkpointing: true
early_stopping_patience:
@@ -69,17 +69,6 @@ debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_sharding_strategy: FULL_SHARD
fsdp_forward_prefetch: false
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:

View File

@@ -1,61 +0,0 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
- model.layers.4[4-9]+.block_sparse_moe.gate
- model.layers.4[4-9]+.block_sparse_moe.experts
- model.layers.5[0-5]+.block_sparse_moe.gate
- model.layers.5[0-5]+.block_sparse_moe.experts
model_config:
output_router_logits: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
output_dir: ./out
sequence_len: 8000
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"

View File

@@ -16,7 +16,7 @@ sequence_len: 1024 # supports up to 32k
sample_packing: false
pad_to_sequence_len: false
adapter: qlora
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16

View File

@@ -11,7 +11,7 @@ addict
fire
PyYAML>=6.0
requests
datasets==2.15.0
datasets>=2.15.0
flash-attn==2.5.5
sentencepiece
wandb
@@ -28,7 +28,7 @@ scipy
scikit-learn==1.2.2
pynvml
art
fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
fschat==0.2.36
gradio==3.50.2
tensorboard
@@ -39,6 +39,5 @@ s3fs
gcsfs
# adlfs
trl==0.8.5
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
zstandard==0.22.0
fastcore

View File

@@ -33,7 +33,7 @@ fi
if [ "$JUPYTER_DISABLE" != "1" ]; then
# Run Jupyter Lab in the background
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
fi
# Execute the passed arguments (CMD)

View File

@@ -24,7 +24,6 @@ from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
@@ -63,20 +62,6 @@ def print_axolotl_text_art(suffix=None):
if is_main_process():
print(ascii_art)
print_dep_versions()
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
if is_main_process():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {version[1]: <15}")
print("*" * 40)
def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
@@ -264,8 +249,8 @@ def do_inference_gradio(
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
temperature=cfg.get("gradio_temperature", 0.9),
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
@@ -300,13 +285,7 @@ def do_inference_gradio(
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)
demo.queue().launch(show_api=False, share=True)
def choose_config(path: Path):
@@ -439,23 +418,6 @@ def load_rl_datasets(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg)
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,

View File

@@ -25,8 +25,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
deepspeed=None,
fsdp=None,
**kwargs,
)

View File

@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
if parsed_cfg.rl and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
else:
register_chatml_template()
if cfg.rl: # and cfg.rl != "orpo":
if cfg.rl and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -30,20 +30,18 @@ from transformers import (
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
from trl import DPOTrainer
from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
SaveModelOnTrainEndCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
@@ -55,7 +53,6 @@ from axolotl.utils.collators import (
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
@@ -74,6 +71,10 @@ except ImportError:
LOG = logging.getLogger("axolotl.core.trainer_builder")
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
@@ -213,10 +214,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
class AxolotlTrainer(Trainer):
@@ -352,8 +349,6 @@ class AxolotlTrainer(Trainer):
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
def _get_eval_sampler(
@@ -818,14 +813,6 @@ class AxolotlDPOTrainer(DPOTrainer):
return res
class AxolotlORPOTrainer(ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -889,14 +876,6 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
return callbacks
@@ -942,27 +921,29 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
):
callbacks.append(SaveBetterTransformerModelCallback())
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
callbacks.append(SaveModelOnTrainEndCallback())
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
)
callbacks.append(LogPredictionCallback(self.cfg))
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
trainer, self.tokenizer
)
callbacks.append(LogPredictionCallback(self.cfg))
@@ -1201,7 +1182,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
False if self.cfg.ddp else None
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
report_to = None
if self.cfg.use_wandb:
report_to = "wandb"
@@ -1422,15 +1402,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
class HFRLTrainerBuilder(TrainerBuilderBase):
class HFDPOTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for DPO Trainer
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(SaveModelOnTrainEndCallback())
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -1466,7 +1444,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
else:
training_args_kwargs["evaluation_strategy"] = "no"
if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True
@@ -1518,16 +1495,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_cls = TrainingArguments
if self.cfg.rl == "orpo":
training_args_cls = ORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args = training_args_cls(
training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
@@ -1560,34 +1528,20 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
if self.cfg.rl == "dpo":
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
dpo_trainer = trainer_cls(
*trainer_cls_args,
dpo_trainer = AxolotlDPOTrainer(
self.model,
self.model_ref,
args=training_args,
beta=self.cfg.dpo_beta or 0.1,
train_dataset=self.train_dataset,
tokenizer=self.tokenizer,
max_length=self.cfg.sequence_len,
max_target_length=None,
max_prompt_length=self.cfg.sequence_len,
generate_during_eval=True,
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)
if self.cfg.fsdp:
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback)

View File

@@ -123,14 +123,6 @@ def get_turns( # pylint: disable=too-many-return-statements
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.GEMMA:
if self.system_message:
raise ValueError("Gemma chat template does not support system messages")
for i, (role, message) in enumerate(self.messages):
prefix = "<bos>" if i == 0 else ""
message_str = message if message else ""
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
return
if self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926

View File

@@ -516,18 +516,24 @@ def mistral_model_forward(
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = (
self._gradient_checkpointing_func( # pylint: disable=protected-access
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(

View File

@@ -6,4 +6,4 @@ from functools import partial
from ..base import load as load_base
load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
load = partial(load_base, module="axolotl.prompt_strategies.orpo")

View File

@@ -78,57 +78,6 @@ class ORPODatasetParsingStrategy:
)
return MessageList(messages=messages)
def get_prompt(self, prompt) -> MessageList:
"""Map the data to extract everything up to the last turn"""
total_msg_len = len(prompt["chosen"])
total_msg_turns, remainder = divmod(total_msg_len, 2)
assert remainder == 0, "invalid number of turns"
messages: List[Message] = []
if system := prompt.get("system", None):
messages.append(Message(role="system", content=system, label=False))
for i in range(total_msg_turns):
if "prompt" in prompt:
messages.append(
Message(role="user", content=prompt["prompt"], label=False)
)
else:
messages.append(
Message(
role="user",
content=prompt["chosen"][i * 2]["content"],
label=False,
)
)
if i < total_msg_turns - 1:
messages.append(
Message(
role="assistant",
content=prompt["chosen"][i * 2 + 1]["content"],
label=False,
)
)
return MessageList(messages=messages)
def get_chosen(self, prompt) -> MessageList:
res = self.get_prompt(prompt)
res.messages.append(
Message(
role="assistant", content=prompt["chosen"][-1]["content"], label=True
)
)
return res
def get_rejected(self, prompt) -> MessageList:
res = self.get_prompt(prompt)
res.messages.append(
Message(
role="assistant", content=prompt["rejected"][-1]["content"], label=True
)
)
return res
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
"""
@@ -237,36 +186,3 @@ class ORPOPrompter(Prompter):
chat_template=self.chat_template,
tokenize=False,
), True
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()
chat_template_str = chat_templates(cfg.chat_template)
def transform_fn(sample, tokenizer=None):
res = {}
res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=False,
)
prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]
return res
return transform_fn

View File

@@ -1,7 +1,7 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -39,40 +39,76 @@ def register_chatml_template(system_message=None):
)
def build_loader(
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
prompter_cls: Type["ShareGPTPrompterV2"],
default_conversation: Optional[str] = None,
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else default_conversation
)
field_human = (
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
)
field_model = (
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
)
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = tokenization_strategy_cls(
prompter_cls(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
strategy.strict = ds_cfg["strict"]
return strategy
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
return _load
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = UltrachatShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else "chatml_glaive"
)
return GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation=conversation),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -122,9 +158,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(
SimpleShareGPTPromptTokenizingStrategy
):
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
"""
@@ -175,16 +209,3 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
conversation = merge_consecutive_messages(conversation)
return conversation
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_ultrachat = build_loader(
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
)
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_glaive = build_loader(
GlaiveShareGPTPromptTokenizingStrategy,
ShareGPTPrompterV2,
default_conversation="chatml_glaive",
)

View File

@@ -348,10 +348,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
)
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
if (
role != "assistant"
): # back to back assistant calls may be okay for tool calls
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])

View File

@@ -3,14 +3,12 @@
import os
import signal
import sys
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
import torch
import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import Dataset
from peft import PeftModel
@@ -83,8 +81,6 @@ def train(
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
# we wait unitl the last possible moment to setup Accelerator
Accelerator()
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model.generation_config.do_sample = True
@@ -128,20 +124,14 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
_model = model_weakref()
if cfg.flash_optimum and BetterTransformer:
_model = BetterTransformer.reverse(_model)
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
def terminate_handler(_, __, model):
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
_model_weakref = weakref.ref(model)
signal.signal(
signal.SIGINT,
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
@@ -212,10 +202,6 @@ def train(
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id:

View File

@@ -1,8 +0,0 @@
"""
Basic utils for Axolotl
"""
import importlib
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None

View File

@@ -6,7 +6,7 @@ import logging
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Dict, List
import evaluate
import numpy as np
@@ -27,9 +27,7 @@ from transformers import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.distributed import (
barrier,
broadcast_dict,
@@ -542,7 +540,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
return CausalLMBenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
@@ -599,13 +597,15 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
return ranges
def log_table_from_dataloader(name: str, table_dataloader):
table_data: Dict[str, List[Any]] = {
"id": [],
"Prompt": [],
"Correct Completion": [],
"Predicted Completion (model.generate)": [],
"Predicted Completion (trainer.prediction_step)": [],
}
table = wandb.Table( # type: ignore[attr-defined]
columns=[
"id",
"Prompt",
"Correct Completion",
"Predicted Completion (model.generate)",
"Predicted Completion (trainer.prediction_step)",
]
)
row_index = 0
for batch in tqdm(table_dataloader):
@@ -709,29 +709,16 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
) in zip(
prompt_texts, completion_texts, predicted_texts, pred_step_texts
):
table_data["id"].append(row_index)
table_data["Prompt"].append(prompt_text)
table_data["Correct Completion"].append(completion_text)
table_data["Predicted Completion (model.generate)"].append(
prediction_text
table.add_data(
row_index,
prompt_text,
completion_text,
prediction_text,
pred_step_text,
)
table_data[
"Predicted Completion (trainer.prediction_step)"
].append(pred_step_text)
row_index += 1
if logger == "wandb":
wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined]
elif logger == "mlflow" and is_mlflow_available():
import mlflow
tracking_uri = AxolotlInputConfig(
**self.cfg.to_dict()
).mlflow_tracking_uri
mlflow.log_table(
data=table_data,
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined]
if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)
@@ -761,11 +748,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
artifact = wandb.Artifact(
f"config-{wandb.run.id}", type="axolotl-config"
)
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name)
LOG.info(
"The Axolotl config has been saved to the WandB run under files."
@@ -773,13 +755,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control
class SaveModelOnTrainEndCallback(TrainerCallback):
"""Callback to save model on train end"""
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs
):
control.should_save = True
return control

View File

@@ -383,9 +383,9 @@ def legacy_validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)
if cfg.gptq and cfg.revision_of_model:
@@ -448,14 +448,10 @@ def legacy_validate_config(cfg):
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
@@ -468,6 +464,11 @@ def legacy_validate_config(cfg):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if (
cfg.evaluation_strategy
and cfg.eval_steps

View File

@@ -98,7 +98,6 @@ class SFTDataset(BaseModel):
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
field: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
@@ -259,7 +258,6 @@ class ModelInputConfig(BaseModel):
base_model: str
base_model_config: Optional[str] = None
cls_model_config: Optional[str] = None
tokenizer_config: Optional[str] = None
tokenizer_use_fast: Optional[bool] = None
tokenizer_legacy: Optional[bool] = None
@@ -409,17 +407,6 @@ class WandbConfig(BaseModel):
return data
class GradioConfig(BaseModel):
"""Gradio configuration subset"""
gradio_title: Optional[str] = None
gradio_share: Optional[bool] = None
gradio_server_name: Optional[str] = None
gradio_server_port: Optional[int] = None
gradio_max_new_tokens: Optional[int] = None
gradio_temperature: Optional[float] = None
# pylint: disable=too-many-public-methods,too-many-ancestors
class AxolotlInputConfig(
ModelInputConfig,
@@ -430,7 +417,6 @@ class AxolotlInputConfig(
WandbConfig,
MLFlowConfig,
LISAConfig,
GradioConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,
@@ -491,7 +477,6 @@ class AxolotlInputConfig(
eval_causal_lm_metrics: Optional[List[str]] = None
do_bench_eval: Optional[bool] = None
bench_dataset: Optional[str] = None
bench_split: Optional[str] = None
metric_for_best_model: Optional[str] = None
greater_is_better: Optional[bool] = None
@@ -507,25 +492,15 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
default=False
)
gradient_checkpointing: Optional[bool] = Field(default=False)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
sequence_len: int = Field(default=512)
min_sample_len: Optional[int] = None
sample_packing: Optional[bool] = None
eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None
# for PoSE context length extension
use_pose: Optional[bool] = None
pose_split_on_token_ids: Optional[List[int]] = None
pose_max_context_len: Optional[int] = None
pose_num_chunks: Optional[int] = None
pretrain_multipack_buffer_size: Optional[int] = 10_000
pretrain_multipack_attn: Optional[bool] = Field(
@@ -792,11 +767,11 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_push_save(cls, data):
if data.get("hub_model_id") and (
data.get("save_strategy") not in ["steps", "epoch", None]
if data.get("hub_model_id") and not (
data.get("save_steps") or data.get("saves_per_epoch")
):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set save_strategy."
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)
return data
@@ -995,16 +970,9 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_fsdp_offload_w_8bit_optimizer(cls, data):
if (
data.get("fsdp")
and "8bit" in data.get("optimizer", "")
and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_offload_params")
):
raise ValueError(
f"FSDP Offload not compatible with {data.get('optimizer')}"
)
def check_fsdp_w_8bit_optimizer(cls, data):
if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
return data
@model_validator(mode="before")

View File

@@ -1,11 +1,11 @@
"""
Data processing modules
"""
from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper,
load_prepare_datasets,

View File

@@ -1,20 +1,17 @@
"""data handling specific to DPO"""
import inspect
import logging
from functools import partial
from pathlib import Path
from typing import Any, List
import yaml
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from datasets import concatenate_datasets, load_dataset, load_from_disk
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
LOG = logging.getLogger("axolotl")
@@ -75,29 +72,16 @@ def load_prepare_dpo_datasets(cfg):
)
split_datasets.insert(i, ds)
tokenizer = None
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
if _cfg.rl == "orpo":
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
sig = inspect.signature(ds_transform_fn)
if "tokenizer" in sig.parameters:
if not tokenizer:
tokenizer = load_tokenizer(_cfg)
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
data_set = data_set.map(
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
split_datasets[i] = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
)
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
split_datasets[i] = data_set
else:
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen" and "rejected" already preprocessed

View File

@@ -379,15 +379,14 @@ def load_tokenized_prepared_datasets(
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if isinstance(ds, DatasetDict):
if config_dataset.split and config_dataset.split in ds:
ds = ds[config_dataset.split]
elif split in ds:
ds = ds[split]
else:
raise ValueError(
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
)
if config_dataset.split and config_dataset.split in ds:
ds = ds[config_dataset.split]
elif split in ds:
ds = ds[split]
elif isinstance(ds, DatasetDict):
raise ValueError(
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
)
# support for using a subset of the data
if config_dataset.shards:
@@ -421,7 +420,7 @@ def load_tokenized_prepared_datasets(
if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(str(prepared_ds_path))
dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub:
LOG.info(
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"

View File

@@ -4,25 +4,27 @@ utility helpers for distributed checks
import os
import pickle # nosec
from contextlib import contextmanager
from datetime import timedelta
import torch
import torch.distributed as dist
from accelerate import PartialState
from accelerate import Accelerator
distributed_state = None # pylint: disable=invalid-name
accelerate = None # pylint: disable=invalid-name
def load_accelerate():
global accelerate # pylint: disable=global-statement
accelerate = Accelerator()
def is_distributed():
"""
Check if distributed training is initialized.
"""
global distributed_state # pylint: disable=global-statement
if not distributed_state:
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
return distributed_state.use_distributed and distributed_state.initialized
global accelerate # pylint: disable=global-statement
if not accelerate:
accelerate = Accelerator()
return dist.is_available() and dist.is_initialized()
def barrier():

View File

@@ -1,13 +0,0 @@
"""custom checkpointing utils"""
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
def hf_grad_checkpoint_unsloth_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer.__self__,
*args,
)

View File

@@ -1,52 +0,0 @@
"""Unsloth checkpointing"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
torch.autograd.Function
):
"""
Saves VRAM by smartly offloading to RAM.
Tiny hit to performance, since we mask the movement via non blocking calls.
"""
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (
None,
hidden_states.grad,
) + (
None,
) * len(ctx.args)

View File

@@ -1,259 +0,0 @@
"""
module to handle loading model on cpu/meta device for FSDP
"""
import os
import time
from typing import List, Optional, Type, Union
import safetensors
import torch
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
def _replace_linear(
model: nn.Module,
linear_replacement: Type[nn.Module],
quant_config: Union[dict, None] = None,
skip_modules=None,
**kwargs,
):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
List of modules names not to convert. Defaults to `lm_head`.
"""
if skip_modules is None:
skip_modules = ["lm_head"]
for name, module in model.named_children():
if len(list(module.children())) > 0:
_replace_linear(
module, linear_replacement, quant_config, skip_modules, **kwargs
)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
if issubclass(linear_replacement, Linear4bit):
model._modules[ # pylint: disable=protected-access
name
] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
**kwargs,
)
else:
raise ValueError(
f"Unsupported linear replacement: {type(linear_replacement)}"
)
return model
def load_and_quantize(
module: nn.Module,
name: str,
value: Tensor,
device: torch.device = None,
dtype: torch.dtype = None,
skip_names: Optional[List[str]] = None,
to_cpu: bool = False,
to_meta: bool = False,
verbose: bool = False,
quant_method: str = "bnb",
):
"""
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.
"""
if not skip_names:
skip_names = []
def place_on_device(value):
if to_meta:
device = "meta"
elif to_cpu:
device = "cpu"
return value.to(device=device, dtype=dtype)
if any(skip_name in name for skip_name in skip_names):
if verbose:
print(f"Skipping {name} because it is in skip_names")
return
module_key, _, value_key = name.rpartition(".")
try:
submodule = module.get_submodule(module_key)
except AttributeError as exc:
print(f"Module {module_key} not found:\n{exc}")
return
try:
if quant_method == "bnb":
param = submodule.get_parameter(value_key)
if isinstance(param, Params4bit):
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
# shape as the quantized Params4bit with an initialized quant_state. However,
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device)
if to_meta:
value = type(param)(value.data.to("meta"), **value.__dict__)
elif to_cpu:
value = type(param)(value.data.to("cpu"), **value.__dict__)
else:
value = type(param)(place_on_device(value).data)
except AttributeError:
# it's a buffer
value = place_on_device(value)
setattr(submodule, value_key, value)
def n_loading_workers(quant_method: str, param_count: float):
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
left = int(os.cpu_count() / torch.cuda.device_count())
model_params_b = 70
right = int(
(4 if quant_method == "hqq" else 8)
* (devprops.total_memory / 1e9 / 40)
* (model_params_b / (param_count / 1e9))
)
return min(left, right)
def load_sharded_model(
model_name,
model_config,
cfg,
torch_dtype=torch.bfloat16,
low_memory=True,
):
if (low_memory and cfg.local_rank == 0) or not low_memory:
model = AutoModelForCausalLM.from_pretrained(
model_name,
use_cache=False,
torch_dtype=torch.float32,
_attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access
trust_remote_code=cfg.trust_remote_code,
)
dtype = torch_dtype if not cfg.float32 else None
model.to(dtype=dtype, device="cpu" if low_memory else cfg.local_rank)
else:
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config,
torch_dtype=torch_dtype,
trust_remote_code=cfg.trust_remote_code,
)
return model
def load_sharded_model_quant(
model_name,
model_config,
cfg,
compute_dtype=torch.bfloat16,
quant_storage=torch.float32,
low_memory=True,
verbose=False,
loading_workers=2,
):
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config,
trust_remote_code=cfg.trust_remote_code,
)
if hasattr(model, "transformer"):
model.transformer = _replace_linear(
model.transformer,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
)
else:
# this is the more common case with HF transformers
model.model = _replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
)
model.is_loaded_in_4bit = True
# Grab the safetensors files that hold the weights
try:
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)
except OSError:
try:
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
files = []
files.append(hub.cached_file(model_name, SAFE_WEIGHTS_NAME))
except OSError as exc:
# This means the model probably doesn't have a safetensors file
raise exc
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
def load_and_quantize_parallel(name_param, model, **kwargs):
name, param = name_param
load_and_quantize(model, name, param, **kwargs)
quant_method = "bnb"
param_count = sum((p.numel() for n, p in model.named_parameters()))
n_workers = (
n_loading_workers(quant_method, param_count)
if loading_workers == -1
else loading_workers
)
if cfg.local_rank == 0 and verbose:
print(f"Using n_workers: {n_workers} for loading")
start = time.time()
for filename in tqdm(
files,
desc="Loading & Quantizing Model Shards",
disable=cfg.local_rank != 0,
position=0,
):
weights = safetensors.torch.load_file(filename)
parallel(
load_and_quantize_parallel,
iter(weights.items()),
n_workers=n_workers,
threadpool=True,
model=model,
dtype=quant_storage,
device=cfg.local_rank,
skip_names=[],
to_cpu=(low_memory and cfg.local_rank == 0),
to_meta=(low_memory and cfg.local_rank != 0),
verbose=verbose,
quant_method=quant_method,
)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading
torch.cuda.empty_cache()
return model

View File

@@ -1,5 +1,4 @@
"""Module for models and model loading"""
# pylint: disable=too-many-lines
import logging
@@ -12,7 +11,6 @@ import addict
import bitsandbytes as bnb
import torch
import transformers
import transformers.modeling_utils
from accelerate import init_empty_weights
from bitsandbytes.nn import Params4bit
from peft import (
@@ -46,37 +44,11 @@ from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
LOG = logging.getLogger("axolotl")
# copied from accelerator.FullyShardedDataParallelPlugin
def get_module_class_from_name(module, name):
"""
Gets a class from a module by its name.
Args:
module (`torch.nn.Module`): The module to get the class from.
name (`str`): The name of the class.
"""
modules_children = list(module.children())
if module.__class__.__name__ == name:
return module.__class__
if len(modules_children) == 0:
return None
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class
return None
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
quant_config_exists = (
hasattr(model_config, "quantization_config")
@@ -313,9 +285,6 @@ def load_model(
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
if cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
if cfg.flash_attention:
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
@@ -490,7 +459,7 @@ def load_model(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
if not cfg.deepspeed:
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -501,16 +470,6 @@ def load_model(
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif cfg.adapter == "lora" and cfg.load_in_8bit:
bnb_config = {
"load_in_8bit": True,
}
# Exclude mamba blocks from int8 quantization for jamba
if cfg.model_config_type == "jamba":
bnb_config["llm_int8_skip_modules"] = ["mamba"]
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
if cfg.load_in_8bit and cfg.adapter is not None:
model_kwargs["load_in_8bit"] = True
@@ -558,31 +517,7 @@ def load_model(
qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
try:
skip_move_to_device = False
if (
cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
) and not qlora_fsdp:
model = load_sharded_model(
base_model,
model_config,
cfg,
torch_dtype=cfg.torch_dtype,
)
skip_move_to_device = True
elif (
qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.model_config_type == "dbrx"
):
quant_storage = cfg.torch_dtype
model = load_sharded_model_quant(
base_model,
model_config,
cfg,
quant_storage=quant_storage,
)
skip_move_to_device = True
elif (
model_config.model_type == "llama"
and not cfg.trust_remote_code
and not cfg.gptq
@@ -662,11 +597,6 @@ def load_model(
**model_kwargs,
)
else:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
@@ -740,17 +670,13 @@ def load_model(
needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False
if is_deepspeed_zero3_enabled():
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
from deepspeed.utils import ( # pylint: disable=no-name-in-module
set_z3_leaf_modules,
)
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
if cfg.model_config_type == "mixtral":
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
set_z3_leaf_modules(model, [moe_block])
elif cfg.model_config_type == "dbrx":
moe_block = get_module_class_from_name(model, "DbrxFFN")
set_z3_leaf_modules(model, [moe_block])
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
@@ -760,8 +686,7 @@ def load_model(
if cfg.adapter == "lora" and loftq_bits:
skip_prepare_model_for_kbit_training = True
if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
# make sure everything is in the same dtype
if qlora_fsdp:
skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]:
@@ -802,7 +727,7 @@ def load_model(
cfg.ddp
and not load_in_8bit
and not (cfg.rl and cfg.load_in_4bit)
and not skip_move_to_device
and not qlora_fsdp
):
# TODO revaldate this conditional
model.to(f"cuda:{cfg.local_rank}")
@@ -958,12 +883,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
rank = int(os.environ.get("LOCAL_RANK", 0))
if (
cfg.fsdp
and cfg.adapter
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and rank != 0
):
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
setup_quantized_meta_for_peft(model)
if cfg.lora_model_dir:
@@ -988,22 +908,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
LOG.warning(
"Exception caught during model.print_trainable_parameters(): %s", exc
)
elif (
cfg.fsdp
and cfg.adapter
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and rank != 0
):
elif cfg.fsdp and cfg.adapter == "qlora":
setup_quantized_peft_meta_for_training(model)
return model, lora_config
def ensure_dtype(model, dtype=torch.bfloat16):
for name, module in model.named_modules():
try:
if module.weight.dtype != dtype:
print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
module.to(dtype)
except AttributeError:
pass

View File

@@ -1,5 +1,6 @@
"""Module for tokenization utilities"""
import logging
import re
from typing import Dict, List
@@ -9,19 +10,10 @@ from termcolor import colored
LOG = logging.getLogger("axolotl")
def check_dataset_labels(
dataset,
tokenizer,
num_examples=5,
text_only=False,
rl_mode=False,
):
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
# the dataset is already shuffled, so let's just check the first 5 elements
for idx in range(num_examples):
if not rl_mode:
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
else:
check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
def check_example_labels(example, tokenizer, text_only=False):
@@ -48,53 +40,6 @@ def check_example_labels(example, tokenizer, text_only=False):
return " ".join(colored_tokens)
def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
"""Helper function to color tokens based on their type."""
colored_text = colored(decoded_token, color)
return (
colored_text
if text_only
else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
)
def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
"""Helper function to process and color tokens."""
colored_tokens = [
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
for token in tokenizer.encode(tokens)
]
return colored_tokens
def check_rl_example_labels(example, tokenizer, text_only=False):
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
input_tokens = example[field_prompt]
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
# Process and color each type of token
colored_tokens = process_tokens_for_rl_debug(
input_tokens, "yellow", tokenizer, text_only
)
colored_chosens = process_tokens_for_rl_debug(
labels_chosen, "green", tokenizer, text_only
)
colored_rejecteds = process_tokens_for_rl_debug(
labels_rejected, "red", tokenizer, text_only
)
# Create a delimiter based on text_only flag
delimiter = "" if text_only else " "
# Logging information
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
return delimiter.join(colored_tokens)
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
GLAIVE_TO_SHAREGPT_ROLE = {
"SYSTEM": "system",

View File

@@ -1,10 +1,9 @@
"""Module containing the Trainer class and related functions"""
import math
import os
import random
from contextlib import contextmanager
from functools import partial
from typing import List, Optional
from typing import List
import numpy as np
import torch
@@ -14,7 +13,7 @@ from datasets import set_caching_enabled
from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -99,89 +98,17 @@ def add_position_ids(sample):
return sample
def add_pose_position_ids(
sample,
max_context_len=32768,
split_on_token_ids: Optional[List[int]] = None,
chunks: int = 2,
):
"""
use the PoSE technique to extend the context length by randomly skipping
positions in the context. We only want to skip right before tokens in
the split_on_token_ids list. We should attempt to randomly distribute
the skips, but we don't need the final position_ids to be the full
context_len. There may be multiple turns in the context, so we want to
make sure we take into account the maximum possible number of skips
remaining in each sample.
"""
input_ids = sample["input_ids"]
sample_len = len(input_ids)
max_skips = max_context_len - sample_len
if split_on_token_ids is None:
split_on_token_ids = []
if split_on_token_ids:
split_indices = [
i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
]
else:
chunk_len = sample_len // chunks
split_indices = [i * chunk_len for i in range(1, chunks)]
split_indices.append(len(input_ids)) # make sure we go to the end of the sample
if split_indices[0] < 2:
# drop the first split index if it's too close to the beginning
split_indices = split_indices[1:]
position_ids = []
prev_index = 0
total_skips = 0
for split_index in split_indices:
num_skips = (
random.randint(0, max_skips) # nosec B311
if prev_index != 0 and max_skips
else 0
)
max_skips -= num_skips
total_skips += num_skips
segment_position_ids = list(
range(prev_index + total_skips, split_index + total_skips)
)
position_ids.extend(segment_position_ids)
prev_index = split_index
sample["sequence_len"] = position_ids[-1]
position_ids = torch.tensor(position_ids)
sample["position_ids"] = position_ids
sample["length"] = len(position_ids)
assert len(position_ids) == len(input_ids)
return sample
def add_length(sample):
sample["length"] = len(sample["input_ids"])
return sample
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
return (
len(sample["input_ids"]) <= sequence_len
and len(sample["input_ids"]) >= min_sequence_len
)
def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
with zero_first(is_main_process()):
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
@@ -197,12 +124,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type == "olmo":
LOG.info("dropping position_ids column")
train_dataset = train_dataset.remove_columns("position_ids")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("position_ids")
if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
@@ -232,32 +153,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
desc="Group By Length",
)
if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
if cfg.sample_packing:
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
@@ -302,7 +198,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
if update:
cfg.total_num_tokens = total_num_tokens
@@ -316,7 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.sum()
)
LOG.debug(
f"`total_supervised_tokens: {total_supervised_tokens:_}`",
f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True,
)
if update:
@@ -343,7 +239,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
* cfg.num_epochs
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
main_process_only=True,
)
else:
@@ -410,8 +306,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_activation_checkpointing:
os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true"
if cfg.fsdp_config.fsdp_offload_params:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
if cfg.fsdp_config.fsdp_sync_module_states:
@@ -444,8 +338,8 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
if cfg.rl in ["dpo", "ipo", "kto_pair"]:
trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]
else:

View File

@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
import pytest
from axolotl.core.trainer_builder import HFRLTrainerBuilder
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
return load_model(cfg, tokenizer)
class TestHFRLTrainerBuilder:
class TestHFDPOTrainerBuilder:
"""
TestCase class for DPO trainer builder
"""
def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9

View File

@@ -7,6 +7,8 @@ import os
import unittest
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip("Skipping test due to timeout.")
class TestLlamaShiftedSparseAttention(unittest.TestCase):
"""
Test case for Llama models using S2 Attn

View File

@@ -30,7 +30,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
@@ -74,7 +74,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,

View File

@@ -22,7 +22,7 @@ class TestModelPatches(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,

View File

@@ -158,50 +158,3 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_orpo_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "orpo",
"orpo_alpha": 0.1,
"remove_unused_columns": False,
"chat_template": "chatml",
"datasets": [
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
"type": "chat_template.argilla",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

View File

@@ -33,7 +33,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sequence_len": 1024,
"load_in_4bit": True,
@@ -87,7 +87,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": False,
"sequence_len": 1024,
"load_in_4bit": True,
@@ -141,7 +141,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sequence_len": 1024,
"adapter": "lora",
@@ -198,7 +198,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": False,
"sequence_len": 1024,
"adapter": "lora",
@@ -255,7 +255,7 @@ class TestMixtral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.1,

View File

@@ -27,9 +27,7 @@ def fixture_alpaca_dataset():
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
# pylint: disable=all
tokenizer = AutoTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.add_special_tokens(
{
"eos_token": AddedToken(

View File

@@ -43,9 +43,7 @@ def fixture_sharegpt_dataset():
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.add_tokens(
[
AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False),

View File

@@ -96,9 +96,7 @@ def fixture_multi_role_dataset():
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.add_special_tokens(
{
"eos_token": AddedToken(

View File

@@ -110,7 +110,7 @@ class TestDatasetPreparation(unittest.TestCase):
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
self.dataset.save_to_disk(str(tmp_ds_name))
self.dataset.save_to_disk(tmp_ds_name)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(

View File

@@ -454,9 +454,7 @@ class OrpoTokenizationTest(unittest.TestCase):
def setUp(self) -> None:
# pylint: disable=duplicate-code
tokenizer = LlamaTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)
tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.add_special_tokens(
{
"eos_token": AddedToken(

View File

@@ -1067,52 +1067,18 @@ class TestValidation(BaseValidation):
):
validate_config(cfg)
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 1
def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 1
def test_hub_model_id_save_value_steps(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_epochs(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_none(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
def test_hub_model_id_save_value_warns(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert (
"set without any models being saved" in self._caplog.records[0].message
)
def test_hub_model_id_save_value(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0