Compare commits

..

10 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Wing Lian | b52e61a574 | pretrain fixes for mm | 2023-10-30 11:03:55 -04:00 |
| Wing Lian | 53f93f67bb | fix to set training args so projector properly saves | 2023-10-29 06:08:38 -04:00 |
| Wing Lian | ef95ea2977 | additional args for parity, fix to properly save projector during pretrain | 2023-10-29 05:12:34 -04:00 |
| Wing Lian | 1321608dc4 | add docs and tweak yml | 2023-10-28 13:07:59 -04:00 |
| Wing Lian | 7ff30c4033 | wip | 2023-10-25 09:19:19 -04:00 |
| Wing Lian | faa46fbcf8 | fix code for llava parity, add llama yml | 2023-10-24 09:45:47 -04:00 |
| Wing Lian | fdc3e4d505 | more fixes to try to get mm working | 2023-10-23 23:15:33 -04:00 |
| Wing Lian | b885169229 | handle load_model splat | 2023-10-23 21:55:05 -04:00 |
| Wing Lian | ab9d12ce34 | handle dataset loading for multimodal | 2023-10-23 21:44:07 -04:00 |
| Wing Lian | 866774737b | WIP llaval support | 2023-10-23 20:29:49 -04:00 |
53 changed files with 1002 additions and 693 deletions

View File

@@ -32,6 +32,7 @@ Features:
 - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
 - [Config](#config)
 - [Train](#train)
+- [Training w/ Deepspeed](#training-with-deepspeed)
 - [Inference](#inference)
 - [Merge LORA to Base](#merge-lora-to-base)
 - [Common Errors](#common-errors-)
@@ -114,25 +115,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
 docker compose up -d
 ```
-<details>
-<summary>Docker advanced</summary>
-
-A more powerful Docker command to run would be this:
-
-```bash
-docker run --gpus '"all"' --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
-```
-
-It additionally:
-* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
-* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
-* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
-
-[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
-</details>
 #### Conda/Pip venv
 1. Install python >=**3.9**
@@ -374,13 +356,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
 - typescript
   type: ... # unimplemented custom format
-# fastchat conversation
-# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
-datasets:
-  - path: ...
-    type: sharegpt
-    conversation: chatml
 # local
 datasets:
   - path: data.jsonl # or json
@@ -419,7 +394,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
 <details>
-<summary>All yaml options (click me)</summary>
+<summary>All yaml options</summary>
 ```yaml
 # This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
@@ -486,9 +461,7 @@ datasets:
     data_files: # Optional[str] path to source data files
     shards: # Optional[int] number of shards to split data into
     name: # Optional[str] name of dataset configuration to load
-    conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
+    # Optional[str] fastchat conversation type, only used with type: sharegpt
+    conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 # Custom user prompt
   - path: repo
@@ -618,14 +591,14 @@ gradient_accumulation_steps: 1
 # The number of samples to include in each batch. This is the number of samples sent to each GPU.
 micro_batch_size: 2
 eval_batch_size:
-num_epochs: 4
+num_epochs: 3
 warmup_steps: 100
 learning_rate: 0.00003
 lr_quadratic_warmup:
 logging_steps:
 save_strategy: # Set to `no` to skip checkpoint saves
 save_steps: # Leave empty to save at each epoch
-eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
+eval_steps: # Leave empty to eval at each epoch
 save_total_limit: # Checkpoints saved at a time
 # Maximum number of iterations to train for. It precedes num_epochs which means that
 # if both are set, num_epochs will not be guaranteed.
@@ -842,41 +815,14 @@ Run
 accelerate launch -m axolotl.cli.train your_config.yml
 ```
-#### Preprocess dataset
-
-You can optionally pre-tokenize dataset with the following before finetuning.
-This is recommended for large datasets.
-
-- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
-- Use `--debug` to see preprocessed examples.
-
-```bash
-python -m axolotl.cli.preprocess your_config.yml
-```
 #### Multi-GPU
-Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
-is the recommended multi-GPU option currently because FSDP may experience
-[loss instability](https://github.com/huggingface/transformers/issues/26498).
-
-##### DeepSpeed
-
-Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
-might typically be able to fit into your GPU's VRAM. More information about the various optimization types
-for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
-
-We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
-
-```yaml
-deepspeed: deepspeed/zero1.json
-```
-
-```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
-```
-
-##### FSDP
+You can optionally pre-tokenize dataset with the following before finetuning:
+```bash
+CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
+```
+
+##### Config
 - llama FSDP
 ```yaml
@@ -901,6 +847,24 @@ wandb_run_id:
 wandb_log_model:
 ```
+### Training with Deepspeed
+
+Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
+might typically be able to fit into your GPU's VRAM. More information about the various optimization types
+for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
+
+We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
+
+```shell
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
+```
+
+or
+
+```yaml
+deepspeed: deepspeed/zero1.json
+```
 ### Inference
 Pass the appropriate flag to the train command:
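Note on the re-added "Training with Deepspeed" section: the referenced `deepspeed/zero1.json` is one of the stock ZeRO stage 1/2/3 configs the repo ships. For orientation, a minimal ZeRO stage-1 config can be sketched as a Python dict and dumped to JSON; the keys below are standard DeepSpeed options chosen for illustration, not a copy of the shipped file.

```python
# Illustrative sketch only: a minimal ZeRO stage-1 DeepSpeed config,
# not the deepspeed/zero1.json that ships with the repo.
import json

zero1 = {
    "zero_optimization": {"stage": 1, "overlap_comm": True},
    "bf16": {"enabled": "auto"},
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

with open("deepspeed/zero1.json", "w", encoding="utf-8") as fout:
    json.dump(zero1, fout, indent=2)
```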

View File

@@ -1,6 +1,14 @@
 {
   "zero_optimization": {
     "stage": 3,
+    "offload_optimizer": {
+      "device": "cpu",
+      "pin_memory": true
+    },
+    "offload_param": {
+      "device": "cpu",
+      "pin_memory": true
+    },
     "overlap_comm": true,
     "contiguous_gradients": true,
     "sub_group_size": 0,
@@ -33,13 +41,12 @@
     }
   },
   "scheduler": {
-    "type": "WarmupDecayLR",
+    "type": "WarmupLR",
     "params": {
       "warmup_min_lr": "auto",
       "warmup_max_lr": "auto",
       "warmup_num_steps": "auto",
-      "warmup_type": "linear",
-      "total_num_steps": "auto"
+      "warmup_type": "linear"
     }
   },
   "gradient_accumulation_steps": "auto",

View File

@@ -12,7 +12,3 @@ This usually happens when you run out of system RAM.
 > Exitcode -7 while using deepspeed
 Try upgrading deepspeed w: `pip install -U deepspeed`
-
-> AttributeError: 'DummyOptim' object has no attribute 'step'
-
-You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.

docs/llava.md Normal file (36 lines added)
View File

@@ -0,0 +1,36 @@
+# LLaVA
+
+### Installing dependencies
+
+```shell
+git clone https://github.com/haotian-liu/LLaVA.git
+cd LLaVA
+pip install --no-deps -e .
+```
+
+### Downloading assets
+
+LLaVA doesn't support remote datasets, so both the JSON and image assets need to be downloaded locally
+
+```shell
+mkdir llava
+mkdir data
+cd llava
+curl -L -O https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/images.zip
+unzip images.zip
+cd ../data
+curl -L -O https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/blip_laion_cc_sbu_558k.json
+```
+
+### Pretraining
+
+Pretraining aligns the vision model with the language model.
+
+```shell
+accelerate launch -m axolotl.cli.train_mm examples/multimodal/pretrain-llava-llama.yml
+```
+
+### Finetuning
+
+TBD
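Before launching the pretrain command above, it can help to confirm the downloads landed where the example configs expect them (`mm_image_folder: ./llava/` and `datasets[0].path: ./data/blip_laion_cc_sbu_558k.json`). A hypothetical helper, not part of the repo, assumed to run from the repo root:

```python
# Hypothetical sanity check (not part of the repo) for the asset layout above.
import os

assert os.path.isdir("llava") and os.listdir("llava"), "unzip images.zip inside ./llava first"
assert os.path.isfile(os.path.join("data", "blip_laion_cc_sbu_558k.json")), (
    "download blip_laion_cc_sbu_558k.json into ./data first"
)
print("assets look good; run the accelerate command above")
```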

View File

@@ -49,7 +49,7 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 save_steps:
 debug:
 deepspeed:

View File

@@ -34,7 +34,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
+num_epochs: 3
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -54,7 +54,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 save_steps:
 debug:
 deepspeed:

View File

@@ -36,7 +36,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
+num_epochs: 3
 optimizer: paged_adamw_32bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -56,7 +56,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 save_steps:
 debug:
 deepspeed:

View File

@@ -34,7 +34,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
+num_epochs: 3
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -54,7 +54,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 save_steps:
 debug:
 deepspeed:

View File

@@ -36,7 +36,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
+num_epochs: 3
 optimizer: paged_adamw_32bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -56,7 +56,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 save_steps:
 debug:
 deepspeed:

View File

@@ -34,7 +34,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
+num_epochs: 3
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -54,7 +54,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 save_steps:
 debug:
 deepspeed:

View File

@@ -36,7 +36,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
+num_epochs: 3
 optimizer: paged_adamw_32bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -56,7 +56,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 save_steps:
 debug:
 deepspeed:

View File

@@ -53,7 +53,7 @@ output_dir: ./qlora-out
 # decrease if OOM, increase for max VRAM utilization
 micro_batch_size: 1
 gradient_accumulation_steps: 2
-num_epochs: 4
+num_epochs: 3
 # Optimizer for QLoRA
 optimizer: paged_adamw_32bit
 torchdistx_path:

View File

@@ -46,7 +46,7 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 save_steps:
 debug:
 deepspeed:

View File

@@ -24,7 +24,7 @@ wandb_log_model:
 output_dir: ./jeopardy-bot-7b
 gradient_accumulation_steps: 1
 micro_batch_size: 1
-num_epochs: 4
+num_epochs: 3
 optimizer: adamw_bnb_8bit
 torchdistx_path:
 lr_scheduler: cosine

View File

@@ -37,7 +37,7 @@ wandb_log_model:
 output_dir: ./model-out
 gradient_accumulation_steps: 1
 micro_batch_size: 1
-num_epochs: 4
+num_epochs: 3
 optimizer: adamw_torch
 adam_beta2: 0.95
 adam_eps: 0.00001

View File

@@ -34,7 +34,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
+num_epochs: 3
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -54,7 +54,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 eval_table_size:
 eval_table_max_new_tokens: 128
 save_steps:

View File

@@ -36,7 +36,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
+num_epochs: 3
 optimizer: paged_adamw_32bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -56,7 +56,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 eval_table_size:
 save_steps:
 debug:

View File

@@ -40,7 +40,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 4
-num_epochs: 4
+num_epochs: 3
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -60,7 +60,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 save_steps: 50
 debug:
 deepspeed:

View File

@@ -34,7 +34,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
+num_epochs: 3
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -54,7 +54,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 eval_table_size:
 save_steps:
 debug:

View File

@@ -26,7 +26,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
+num_epochs: 3
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.000005
@@ -46,7 +46,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 eval_table_size:
 eval_table_max_new_tokens: 128
 save_steps:

View File

@@ -63,7 +63,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 0.05
+eval_steps: 20
 eval_table_size:
 eval_table_max_new_tokens: 128
 save_steps:

View File

@@ -26,7 +26,7 @@ wandb_log_model:
 output_dir: ./mpt-alpaca-7b
 gradient_accumulation_steps: 1
 micro_batch_size: 1
-num_epochs: 4
+num_epochs: 3
 optimizer: adamw_bnb_8bit
 torchdistx_path:
 lr_scheduler: cosine

View File

@@ -0,0 +1,66 @@
+base_model: NousResearch/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+# multimodal pretrain
+multimodal: true
+mm_vision_tower: openai/clip-vit-large-patch14
+tune_mm_mlp_adapter: true
+mm_freeze_backbone: true
+mm_vision_select_layer: -2
+mm_projector_type: mlp2x_gelu
+mm_image_folder: ./llava/
+mm_use_im_patch_token: false
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: ./data/blip_laion_cc_sbu_558k.json
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 2048
+sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 0.002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps:
+save_steps: 0.1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  pad_token: "<unk>"
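For context on the projector settings above: in LLaVA, `mm_projector_type: mlp2x_gelu` selects a two-layer MLP with a GELU in between that maps vision-tower features into the language model's hidden space, and with `tune_mm_mlp_adapter: true` plus `mm_freeze_backbone: true` pretraining updates only that module. A rough sketch; the widths are assumptions for CLIP ViT-L/14 (1024) and Llama-2-7b (4096), not values read from this config:

```python
# Rough sketch of an mlp2x_gelu projector; dimensions are illustrative.
import torch.nn as nn

vision_hidden, lm_hidden = 1024, 4096  # assumed CLIP ViT-L/14 -> Llama-2-7b
mm_projector = nn.Sequential(
    nn.Linear(vision_hidden, lm_hidden),
    nn.GELU(),
    nn.Linear(lm_hidden, lm_hidden),
)
```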

View File

@@ -0,0 +1,66 @@
+base_model: mistralai/Mistral-7B-v0.1
+model_type: MistralForCausalLM
+tokenizer_type: LlamaTokenizer
+is_mistral_derived_model: true
+
+# multimodal pretrain
+multimodal: true
+mm_vision_tower: openai/clip-vit-large-patch14
+tune_mm_mlp_adapter: true
+mm_freeze_backbone: true
+mm_vision_select_layer: -2
+mm_projector_type: mlp2x_gelu
+mm_image_folder: ./llava/
+mm_use_im_patch_token: false
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: ./data/blip_laion_cc_sbu_558k.json
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 2048
+sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 0.002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps:
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  pad_token: "<unk>"

View File

@@ -23,7 +23,7 @@ wandb_log_model:
 output_dir: ./lora-alpaca-pythia
 gradient_accumulation_steps: 1
 micro_batch_size: 4
-num_epochs: 4
+num_epochs: 3
 learning_rate: 0.00001
 train_on_inputs: false
 group_by_length: false
@@ -33,5 +33,5 @@ early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
 weight_decay: 0.1
-eval_steps: 0.05
+eval_steps: 20
 logging_steps: 1

View File

@@ -27,7 +27,7 @@ wandb_log_model:
 output_dir: ./redpajama-alpaca-3b
 batch_size: 4
 micro_batch_size: 1
-num_epochs: 4
+num_epochs: 3
 optimizer: adamw_bnb_8bit
 torchdistx_path:
 lr_scheduler: cosine

View File

@@ -26,7 +26,7 @@ wandb_log_model:
 output_dir: ./lora-replit
 batch_size: 8
 micro_batch_size: 1
-num_epochs: 4
+num_epochs: 3
 optimizer:
 torchdistx_path:
 lr_scheduler:

View File

@@ -51,7 +51,7 @@ output_dir: ./qlora-out
 # decrease if OOM, increase for max VRAM utilization
 micro_batch_size: 1
 gradient_accumulation_steps: 1
-num_epochs: 4
+num_epochs: 3
 # Optimizer for QLoRA
 optimizer: paged_adamw_32bit
 torchdistx_path:

View File

@@ -1 +0,0 @@
-# Page

View File

@@ -1,4 +0,0 @@
-# Table of contents
-
-* [Page](README.md)
-* [Small dev details](small-dev-details.md)

View File

@@ -1,3 +0,0 @@
-# Small dev details
-
-/

View File

@@ -31,4 +31,3 @@ scikit-learn==1.2.2
 pynvml
 art
 fschat==0.2.29
-tensor_parallel

View File

@@ -45,6 +45,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+        if parsed_cli_args.prepare_ds_only:
+            return
         train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

View File

@@ -2,6 +2,7 @@
 import importlib
 import logging
+import math
 import os
 import random
 import sys
@@ -215,6 +216,46 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
     return cfg
+
+def load_mm_dataset(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,  # pylint: disable=unused-argument
+    model,
+):
+    # pylint: disable=duplicate-code
+    from llava.train.train import DataArguments, LazySupervisedDataset
+
+    vision_tower = model.get_vision_tower()
+    data_args = DataArguments(
+        data_path=cfg.datasets[0]["path"],
+        lazy_preprocess=cfg.mm_lazy_preprocess
+        if cfg.mm_lazy_preprocess is not None
+        else True,
+        is_multimodal=True,
+        image_folder=cfg.mm_image_folder or None,
+        image_aspect_ratio=cfg.mm_image_aspect_ratio or "square",
+        image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None,
+    )
+    data_args.image_processor = vision_tower.image_processor
+    data_args.mm_use_im_start_end = cfg.mm_use_im_start_end or False
+
+    tokenizer = load_tokenizer(cfg)
+    train_dataset = LazySupervisedDataset(
+        tokenizer=tokenizer,
+        data_path=data_args.data_path,
+        data_args=data_args,
+    )
+
+    total_num_steps = int(
+        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+    )
+    return TrainDatasetMeta(
+        train_dataset=train_dataset,
+        eval_dataset=None,
+        total_num_steps=total_num_steps,
+    )
+
 def load_datasets(
     *,
     cfg: DictDefault,
@@ -222,9 +263,7 @@ def load_datasets(
 ) -> TrainDatasetMeta:
     tokenizer = load_tokenizer(cfg)

-    train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
-        cfg, tokenizer
-    )
+    train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)

     if cli_args.debug or cfg.debug:
         LOG.info("check_dataset_labels...")
@@ -240,10 +279,6 @@
             text_only=cli_args.debug_text_only,
         )

-        LOG.info("printing prompters...")
-        for prompter in prompters:
-            LOG.info(prompter)
-
     return TrainDatasetMeta(
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
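The `total_num_steps` estimate in `load_mm_dataset` above is simply `ceil(len(train_dataset) * num_epochs / batch_size)`. A worked example with assumed numbers (the dataset size and effective batch size below are illustrative, not read from any config):

```python
# Worked example of the step estimate; the inputs are assumptions.
import math

dataset_len = 558_000  # roughly the LLaVA-Pretrain blip_laion_cc_sbu_558k set
num_epochs = 1
batch_size = 8         # e.g. micro_batch_size 2 x gradient_accumulation_steps 4
print(math.ceil(dataset_len * num_epochs / batch_size))  # 69750
```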

View File

@@ -6,6 +6,7 @@ from pathlib import Path
 import fire
 import transformers
+from colorama import Fore

 from axolotl.cli import (
     check_accelerate_default_config,
@@ -15,6 +16,7 @@ from axolotl.cli import (
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
+from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.train import train

 LOG = logging.getLogger("axolotl.cli.train")
@@ -30,7 +32,18 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
+    if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
+        msg = (
+            Fore.RED
+            + "--prepare_ds_only called without dataset_prepared_path set."
+            + Fore.RESET
+        )
+        LOG.warning(msg)
+        parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
+
     dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    if parsed_cli_args.prepare_ds_only:
+        return

     train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

View File

@@ -5,6 +5,7 @@ import logging
 from pathlib import Path

 import fire
+import torch
 import transformers
 from colorama import Fore

@@ -12,13 +13,15 @@ from axolotl.cli import (
     check_accelerate_default_config,
     check_user_token,
     load_cfg,
-    load_datasets,
+    load_mm_dataset,
     print_axolotl_text_art,
 )
-from axolotl.common.cli import PreprocessCliArgs
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
+from axolotl.train import train
+from axolotl.utils.models import load_model, load_tokenizer

-LOG = logging.getLogger("axolotl.cli.preprocess")
+LOG = logging.getLogger("axolotl.cli.train")

 def do_cli(config: Path = Path("examples/"), **kwargs):
@@ -27,26 +30,29 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     parsed_cfg = load_cfg(config, **kwargs)
     check_accelerate_default_config()
     check_user_token()
-    parser = transformers.HfArgumentParser((PreprocessCliArgs))
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )

-    if not parsed_cfg.dataset_prepared_path:
+    if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
         msg = (
             Fore.RED
-            + "preprocess CLI called without dataset_prepared_path set, "
-            + f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
+            + "--prepare_ds_only called without dataset_prepared_path set."
             + Fore.RESET
         )
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

-    _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    LOG.info(
-        Fore.GREEN
-        + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
-        + Fore.RESET
-    )
+    tokenizer = load_tokenizer(parsed_cfg)
+    model, _ = load_model(parsed_cfg, tokenizer)
+    dataset_meta = load_mm_dataset(
+        cfg=parsed_cfg, cli_args=parsed_cli_args, model=model
+    )
+    del model
+    torch.cuda.empty_cache()
+
+    if parsed_cli_args.prepare_ds_only:
+        return
+
+    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

 if __name__ == "__main__":

View File

@@ -25,22 +25,11 @@ class TrainerCliArgs:
     debug_num_examples: int = field(default=5)
     inference: bool = field(default=False)
     merge_lora: bool = field(default=False)
+    prepare_ds_only: bool = field(default=False)
     prompter: Optional[str] = field(default=None)
     shard: bool = field(default=False)

-@dataclass
-class PreprocessCliArgs:
-    """
-    dataclass representing arguments for preprocessing only
-    """
-
-    debug: bool = field(default=False)
-    debug_text_only: bool = field(default=False)
-    debug_num_examples: int = field(default=1)
-    prompter: Optional[str] = field(default=None)
-
 def load_model_and_tokenizer(
     *,
     cfg: DictDefault,

View File

@@ -14,7 +14,6 @@ from functools import partial
 from pathlib import Path
 from typing import Optional, Union

-import tensor_parallel as tp
 import torch
 import transformers
 from datasets import Dataset
@@ -34,7 +33,6 @@ from axolotl.utils.callbacks import (
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
-from axolotl.utils.distributed import is_distributed
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

 try:
@@ -42,6 +40,14 @@
 except ImportError:
     pass

+try:
+    from llava.train.llava_trainer import get_mm_adapter_state_maybe_zero_3
+except ImportError:
+
+    def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+        raise ImportError("missing LLaVA package")
+
 LOG = logging.getLogger("axolotl.core.trainer_builder")
@@ -104,8 +110,16 @@ class AxolotlTrainingArguments(TrainingArguments):
     bench_source_max_len: int = field(
         default=2048, metadata={"help": "Maximum source sequence length for bench."}
     )
-    tensor_parallel: bool = field(
-        default=False, metadata={"help": "Use tensor parallelism to train"}
+    tune_mm_mlp_adapter: bool = field(
+        default=False,
+        metadata={"help": "Whether to train the multimodal projector adapter"},
+    )
+    freeze_mm_mlp_adapter: bool = field(
+        default=False,
+        metadata={"help": "Whether to freeze the multimodal projector adapter"},
+    )
+    mm_projector_lr: Optional[float] = field(
+        default=None, metadata={"help": "Learning rate for the multimodal projector"}
     )
@@ -116,8 +130,7 @@ class AxolotlTrainer(Trainer):
     args = None  # type: AxolotlTrainingArguments

-    def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
-        self.num_epochs = num_epochs
+    def __init__(self, *args, bench_data_collator=None, **kwargs):
         self.bench_data_collator = bench_data_collator
         super().__init__(*args, **kwargs)
@@ -188,7 +201,6 @@
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
                     sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
                     device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                    num_epochs=self.num_epochs,
                 )
             )
         return super().get_train_dataloader()
@@ -212,7 +224,6 @@
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
                     sample_packing_seq_len_multiplier=self.args.eval_batch_size,
                     device_count=int(os.environ.get("WORLD_SIZE", 1)),
-                    num_epochs=self.num_epochs,
                 )
             )
         return super().get_eval_dataloader(eval_dataset)
@@ -251,13 +262,40 @@
         # return (loss, outputs) if return_outputs else loss
         return super().compute_loss(model, inputs, return_outputs=return_outputs)

-    def _wrap_model(self, model, training=True, dataloader=None):
-        if self.args.tensor_parallel:
-            model = tp.tensor_parallel(model, distributed=is_distributed())
-            model.hf_device_map = tp.infer_sharded_device_map(model)
+    def _save_checkpoint(self, model, trial, metrics=None):
+        if getattr(self.args, "tune_mm_mlp_adapter", False):
+            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+            run_dir = self._get_output_dir(trial=trial)
+            output_dir = os.path.join(run_dir, checkpoint_folder)
+
+            weights_to_save = self._get_mm_mlp_adapter_weights()
+            if self.args.local_rank in (0, -1):
+                self.model.config.save_pretrained(output_dir)
+                torch.save(
+                    weights_to_save, os.path.join(output_dir, "mm_projector.bin")
+                )
         else:
-            model = super()._wrap_model(model, training=training, dataloader=dataloader)
-        return model
+            super()._save_checkpoint(model, trial, metrics)
+
+    def _get_mm_mlp_adapter_weights(self):
+        # Only save Adapter
+        keys_to_match = ["mm_projector", "vision_resampler"]
+        if getattr(self.args, "use_im_start_end", False):
+            keys_to_match.extend(["embed_tokens", "embed_in"])
+
+        return get_mm_adapter_state_maybe_zero_3(
+            self.model.named_parameters(), keys_to_match
+        )
+
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        if getattr(self.args, "tune_mm_mlp_adapter", False):
+            pass
+        else:
+            super()._save(output_dir, state_dict)

 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
@@ -384,10 +422,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         return trainer_kwargs, trainer_cls

     def hook_post_create_trainer(self, trainer):
-        if self.cfg.tensor_parallel:
-            trainer.model = trainer.accelerator.prepare_model(
-                trainer.model, device_placement=True
-            )
+        # TODO
         return trainer

     def get_callbacks(self):
@@ -629,9 +664,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "sample_packing_seq_len_multiplier"
             ] = self.cfg.micro_batch_size
         training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
-        training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True
+
+        # multimodal: llava
+        training_arguments_kwargs["tune_mm_mlp_adapter"] = self.cfg.tune_mm_mlp_adapter
+        training_arguments_kwargs[
+            "freeze_mm_mlp_adapter"
+        ] = self.cfg.freeze_mm_mlp_adapter
+        training_arguments_kwargs["mm_projector_lr"] = self.cfg.mm_projector_lr

         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
@@ -649,18 +691,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             sys.path.append(self.cfg.torchdistx_path)
             importlib.import_module("torchdistx")

-        data_collator_kwargs = {
-            "padding": True,  # True/"longest" is the default
-        }
-        if self.cfg.pad_to_sequence_len:
-            data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
-                self.cfg.sequence_len / 64
-            )
-        else:
-            # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
-            # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
-            data_collator_kwargs["pad_to_multiple_of"] = 64
-
         if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
             from axolotl.monkeypatch.llama_landmark_attn import (
                 add_mem_tokens,
@@ -685,23 +715,15 @@
         trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
             trainer_kwargs, trainer_cls
         )
+        trainer_collator_kwargs = self.build_data_collator()
         trainer = trainer_cls(
             model=self.model,
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=DataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **data_collator_kwargs,
-            ),
-            bench_data_collator=transformers.DataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **data_collator_kwargs,
-            ),
             callbacks=self.get_callbacks(),
-            num_epochs=self.cfg.num_epochs,
+            **trainer_collator_kwargs,
             **trainer_kwargs,
         )
         trainer = self.hook_post_create_trainer(trainer)
@@ -709,3 +731,41 @@
             trainer.add_callback(callback)

         return trainer
+
+    def build_data_collator(self):
+        data_collator_kwargs = {
+            "padding": True,  # True/"longest" is the default
+        }
+        if self.cfg.pad_to_sequence_len:
+            data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
+                self.cfg.sequence_len / 64
+            )
+        else:
+            # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
+            # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
+            data_collator_kwargs["pad_to_multiple_of"] = 64
+
+        collator_kwargs = {}
+        if self.cfg.multimodal:
+            from llava.train.train import DataCollatorForSupervisedDataset
+
+            collator_kwargs["data_collator"] = DataCollatorForSupervisedDataset(
+                tokenizer=self.tokenizer,
+            )
+        else:
+            collator_kwargs["data_collator"] = DataCollatorForSeq2Seq(
+                self.tokenizer,
+                return_tensors="pt",
+                **data_collator_kwargs,
+            )
+        if self.cfg.do_bench_eval:
+            collator_kwargs[
+                "bench_data_collator"
+            ] = transformers.DataCollatorForSeq2Seq(
+                self.tokenizer,
+                return_tensors="pt",
+                **data_collator_kwargs,
+            )
+        return collator_kwargs
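The padding rule `build_data_collator` inherits is unchanged from the old inline code: with `pad_to_sequence_len` set, sequences pad out to the next multiple of 64 at or above `sequence_len`; otherwise each batch pads to a multiple of 64 for tensor-core-friendly shapes. The arithmetic in isolation:

```python
# The pad_to_multiple_of rule from build_data_collator above.
import math

def pad_multiple(sequence_len: int, pad_to_sequence_len: bool) -> int:
    if pad_to_sequence_len:
        return 64 * math.ceil(sequence_len / 64)  # round max length up to a multiple of 64
    return 64  # otherwise just keep each batch a multiple of 64 wide

print(pad_multiple(2048, True))   # 2048 (already a multiple of 64)
print(pad_multiple(2100, True))   # 2112
print(pad_multiple(2100, False))  # 64
```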

View File

View File

@@ -0,0 +1,167 @@
+"""
+LLaVA Mistral classes
+"""
+from typing import List, Optional, Tuple, Union
+
+import torch
+from llava.model.llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    MistralConfig,
+    MistralForCausalLM,
+    MistralModel,
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+class LlavaMistralConfig(MistralConfig):
+    """
+    HF Transformers Config for Mistral w LLaVA
+    """
+
+    model_type = "llava_mistral"
+
+
+class LlavaMistralModel(LlavaMetaModel, MistralModel):
+    """
+    HF Transformers Model for Mistral w LLaVA
+    """
+
+    config_class = LlavaMistralConfig
+
+    def __init__(
+        self, config: LlavaMistralConfig
+    ):  # pylint: disable=useless-parent-delegation
+        super().__init__(config)
+
+
+class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
+    """
+    HF Transformers Causal Model for Mistral w LLaVA
+    """
+
+    config_class = LlavaMistralConfig
+
+    def __init__(self, config: LlavaMistralConfig):
+        super().__init__(config)
+        self.model = LlavaMistralModel(config)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        (
+            input_ids,
+            attention_mask,
+            past_key_values,
+            inputs_embeds,
+            labels,
+        ) = self.prepare_inputs_labels_for_multimodal(
+            input_ids, attention_mask, past_key_values, labels, images
+        )
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model/pipeline parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "images": kwargs.get("images", None),
+            }
+        )
+        return model_inputs
+
+
+AutoConfig.register("llava_mistral", LlavaMistralConfig)
+AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
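The loss in `forward` above is the standard causal-LM shift: the logit at position t is scored against the token at position t+1, so the last logit and the first label are dropped before flattening. That step in isolation:

```python
# Self-contained sketch of the shift-by-one loss used in forward() above.
import torch
from torch.nn import CrossEntropyLoss

vocab = 32000
logits = torch.randn(2, 16, vocab)         # (batch, seq, vocab)
labels = torch.randint(0, vocab, (2, 16))  # (batch, seq)

shift_logits = logits[..., :-1, :].contiguous()  # positions 0..T-2 predict...
shift_labels = labels[..., 1:].contiguous()      # ...tokens 1..T-1
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
```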

View File

@@ -0,0 +1,40 @@
+"""
+patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+import torch
+import transformers.models.llama.modeling_llama
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
+    # pylint: disable=duplicate-code
+    def noised_embed(orig_embed, noise_alpha, model):
+        def new_func(input_ids):
+            # during training, we add noise to the embedding
+            # during generation, we don't add noise to the embedding
+            if model.training:
+                embed_init = orig_embed(input_ids)
+                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+                mag_norm = noise_alpha / torch.sqrt(dims)
+                return embed_init + torch.zeros_like(embed_init).uniform_(
+                    -mag_norm, mag_norm
+                )
+            return orig_embed(input_ids)
+
+        return new_func
+
+    def post_init(orig_post_init):
+        def new_func(self):
+            orig_post_init(self)
+            self.embed_tokens.forward = noised_embed(
+                self.embed_tokens.forward, noise_alpha, self
+            )
+
+        return new_func
+
+    transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
+        transformers.models.llama.modeling_llama.LlamaModel.post_init
+    )
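Both this patch and the Mistral twin below implement NEFTune: during training only, uniform noise with per-element magnitude `noise_alpha / sqrt(seq_len * hidden_dim)` is added to the embedding output. The noise math on its own (shapes are assumptions, not values from the patch):

```python
# The NEFTune noise magnitude from the patch above; shapes are illustrative.
import math
import torch

alpha, seq_len, dim = 5.0, 1024, 4096
embeds = torch.randn(1, seq_len, dim)
mag_norm = alpha / math.sqrt(seq_len * dim)  # 5 / 2048 ~= 0.0024
noisy = embeds + torch.zeros_like(embeds).uniform_(-mag_norm, mag_norm)
```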

View File

@@ -0,0 +1,40 @@
+"""
+patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+import torch
+import transformers.models.mistral.modeling_mistral
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
+    # pylint: disable=duplicate-code
+    def noised_embed(orig_embed, noise_alpha, model):
+        def new_func(input_ids):
+            # during training, we add noise to the embedding
+            # during generation, we don't add noise to the embedding
+            if model.training:
+                embed_init = orig_embed(input_ids)
+                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+                mag_norm = noise_alpha / torch.sqrt(dims)
+                return embed_init + torch.zeros_like(embed_init).uniform_(
+                    -mag_norm, mag_norm
+                )
+            return orig_embed(input_ids)
+
+        return new_func
+
+    def post_init(orig_post_init):
+        def new_func(self):
+            orig_post_init(self)
+            self.embed_tokens.forward = noised_embed(
+                self.embed_tokens.forward, noise_alpha, self
+            )
+
+        return new_func
+
+    transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
+        transformers.models.mistral.modeling_mistral.MistralModel.post_init
+    )

View File

@@ -1,65 +0,0 @@
-"""
-patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
-"""
-import torch
-from peft import PeftModel
-from transformers import PreTrainedModel
-
-
-def patch_neft(alpha, model):
-    embeddings = None
-    if isinstance(model, PreTrainedModel):
-        embeddings = model.get_input_embeddings()
-    if isinstance(model, PeftModel):
-        embeddings = model.base_model.get_input_embeddings()
-    if not embeddings:
-        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
-    embeddings.noisy_embedding_alpha = alpha
-    old_forward = embeddings.forward
-
-    # This hack seems to be needed to properly use a custom forward pass
-    # all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
-    bound_method = neft_forward.__get__(  # pylint: disable=no-value-for-parameter
-        embeddings, embeddings.__class__
-    )
-    setattr(embeddings, "forward", bound_method)
-    embeddings._old_forward = old_forward  # pylint: disable=protected-access
-    return model
-
-
-def unpatch_neft(model):
-    embeddings = None
-    if isinstance(model, PreTrainedModel):
-        embeddings = model.get_input_embeddings()
-    if isinstance(model, PeftModel):
-        embeddings = model.base_model.get_input_embeddings()
-    if not embeddings:
-        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
-    if hasattr(embeddings, "_old_forward"):
-        embeddings.forward = embeddings._old_forward  # pylint: disable=protected-access
-        del embeddings._old_forward  # pylint: disable=protected-access
-        del embeddings.noisy_embedding_alpha
-
-
-def neft_forward(self, inputs: torch.Tensor):
-    embeddings = self._old_forward(inputs)  # pylint: disable=protected-access
-
-    if self.training:
-        dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
-        mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
-        embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
-            -mag_norm, mag_norm
-        )
-
-    return embeddings
-
-
-def pretrain_hook(cfg, trainer):
-    if cfg.noisy_embedding_alpha:
-        trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
-
-
-def post_train_hook(cfg, trainer):
-    if cfg.noisy_embedding_alpha:
-        unpatch_neft(trainer.model)
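The deleted module leaned on a classic trick: binding a free function to an instance with `__get__` so it receives `self` like a real method, while stashing the original forward under `_old_forward` for `unpatch_neft` to restore. The trick in isolation, with toy names:

```python
# The __get__ method-binding trick used by patch_neft above, with toy names.
class Greeter:
    def hello(self):
        return "hello"

def shouty_hello(self):
    return self._old_hello().upper()  # delegate to the saved original

g = Greeter()
g._old_hello = g.hello                      # keep the original bound method
g.hello = shouty_hello.__get__(g, Greeter)  # rebind so g.hello() sees self
assert g.hello() == "HELLO"
```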

View File

@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
     field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    strategy = SimpleShareGPTPromptTokenizingStrategy(
+    return SimpleShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(
             conversation=conversation,
             role_key_model=field_model,
@@ -34,9 +34,6 @@
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
-    if ds_cfg and "strict" in ds_cfg:
-        strategy.strict = ds_cfg["strict"]
-    return strategy

 def load_role(tokenizer, cfg):
@@ -62,26 +59,8 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     basic sharegpt strategy to grab conversations from the sample row
     """

-    _strict = True
-
-    @property
-    def strict(self):
-        return self._strict
-
-    @strict.setter
-    def strict(self, strict):
-        self._strict = strict
-
     def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
-        if self.strict:
-            return conversations
-        # remap roles - allow for assistant turn
-        role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
-        turns = [
-            {"from": role_map[t["from"]], "value": t["value"]} for t in conversations
-        ]
-        return turns
+        return prompt["conversations"]

 class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):

View File

@@ -245,7 +245,6 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         raise NotImplementedError

     def tokenize_prompt(self, prompt):
-        # pylint: disable=duplicate-code
         (
             instruction,
             input,  # pylint: disable=redefined-builtin

View File

@@ -4,12 +4,10 @@ import logging
from enum import Enum from enum import Enum
from typing import Generator, Optional, Union from typing import Generator, Optional, Union
from colorama import Fore
from fastchat.conversation import Conversation, get_conv_template from fastchat.conversation import Conversation, get_conv_template
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
class PromptStyle(Enum): class PromptStyle(Enum):
@@ -57,15 +55,20 @@ class AlpacaPrompter:
) )
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
def _build_result(self, instruction, input_text, output): def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended. # if a label (=response, =output) is provided, it's also appended.
if input_text: if input:
res = ( res = (
self.system_format.format(system=self.system_prompt) self.system_format.format(system=self.system_prompt)
if self.system_prompt if self.system_prompt
else "" else ""
) + self.turn_format.format(instruction=instruction, input=input_text) ) + self.turn_format.format(instruction=instruction, input=input)
else: else:
res = ( res = (
self.system_format.format(system=self.system_no_input_prompt) self.system_format.format(system=self.system_no_input_prompt)
@@ -74,21 +77,7 @@ class AlpacaPrompter:
) + self.turn_no_input_format.format(instruction=instruction) ) + self.turn_no_input_format.format(instruction=instruction)
if output: if output:
res = f"{res}{output}" res = f"{res}{output}"
yield res
return res
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
-    ) -> Generator[str, None, None]:
-        yield self._build_result(instruction, input, output)
-
-    def __repr__(self) -> str:
-        return REPR_TEMPLATE.format(
-            full_prompt=self._build_result("{instruction}", "{input}", "{output}")
-        )

 class UnpromptedPrompter(AlpacaPrompter):
@@ -202,14 +191,14 @@ class ReflectAlpacaPrompter:
         )
         self.response_split = "ASSISTANT:"

-    def _build_result(
+    def build_prompt(
         self,
         instruction: str,
         input: Union[None, str] = None,  # pylint: disable=redefined-builtin
         output: Union[None, str] = None,
         reflection: Union[None, str] = None,
         corrected: Union[None, str] = None,
-    ):
+    ) -> Generator[str, None, None]:
         # returns the full prompt from instruction and optional input
         # if a label (=response, =output) is provided, it's also appended.
         if input:
@@ -223,30 +212,7 @@ class ReflectAlpacaPrompter:
                 corrected=corrected,
             )
             res = f"{res}{label}"
-        return res
-
-    def build_prompt(
-        self,
-        instruction: str,
-        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
-        output: Union[None, str] = None,
-        reflection: Union[None, str] = None,
-        corrected: Union[None, str] = None,
-    ) -> Generator[str, None, None]:
-        # pylint: disable=duplicate-code
-        yield self._build_result(
-            instruction,
-            input,
-            output,
-            reflection,
-            corrected,
-        )
-
-    def __repr__(self) -> str:
-        return REPR_TEMPLATE.format(
-            full_prompt=self._build_result("{instruction}", "{input}", "{output}")
-        )
+        yield res

 SHAREGPT_ASSERTION_FAILED_ROLE = (
@@ -281,7 +247,7 @@ class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
         if role_key_model:
             self.role_key_model = role_key_model

-    def _build_result(self, source):
+    def build_prompt(self, source) -> Generator[str, None, None]:
         if len(source) < 2:
             # If there isn't a back and forth conversation, ignore it
             # also happens on the data splitting leaving empty conversations
@@ -316,20 +282,11 @@ class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
                 LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
             conv.append_message(role, sentence["value"])
-        return conv.get_turns()
-
-    def build_prompt(self, source) -> Generator[str, None, None]:
-        turns = self._build_result(source)
-        for part in turns:
+        for part in conv.get_turns():
             if part[0] and not part[1]:
                 LOG.warning(f"role with empty message: {part[0]}")
             yield part
-
-    def __repr__(self) -> str:
-        turns = self._build_result([{"from": "{from}", "value": "{value}"}])
-        return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])

 class ShareGPTPrompterV2(ShareGPTPrompter):
     """
@@ -347,15 +304,3 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
             role_key_human=role_key_human,
             role_key_model=role_key_model,
         )
-
-
-class UnsupportedPrompter:
-    """
-    A dummy class for custom prompters
-    """
-
-    def __init__(self) -> None:
-        pass
-
-    def __repr__(self):
-        return "Pre-tokenized or custom dataset types are unsupported for logging"

View File

@@ -16,11 +16,18 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
-from axolotl.monkeypatch import neft_embeddings
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer

+try:
+    from llava.train.train import safe_save_model_for_hf_trainer
+except ImportError:
+
+    def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+        raise ImportError("missing LLaVA package")
+
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
@@ -108,7 +115,6 @@ def train(
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")

-    pretrain_hooks(cfg, trainer)
     if cfg.flash_optimum:
         with torch.backends.cuda.sdp_kernel(
             enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -116,7 +122,6 @@ def train(
             trainer.train(resume_from_checkpoint=resume_from_checkpoint)
     else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
-    post_train_hooks(cfg, trainer)

     LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
@@ -140,6 +145,8 @@ def train(
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
         trainer.save_model(cfg.output_dir)
+    elif cfg.multimodal:
+        safe_save_model_for_hf_trainer(trainer=trainer, output_dir=cfg.output_dir)
     elif cfg.deepspeed and is_deepspeed_zero3_enabled():
         # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
         trainer.accelerator.wait_for_everyone()
@@ -152,37 +159,17 @@ def train(
         # The model name saved is `pytorch_model.bin`
         unwrapped_model.save_pretrained(
             cfg.output_dir,
-            is_main_process=trainer.accelerator.is_main_process,
+            is_main_process=trainer.args.should_save,
             save_function=trainer.accelerator.save,
             state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
         )
-    elif cfg.local_rank == 0:
+    elif trainer.args.should_save:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
-        # TODO figure out if `trainer.save_model(cfg.output_dir)` is sufficient here
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

     if not cfg.hub_model_id:
         trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))

     return model, tokenizer
-
-
-def pretrain_hooks(cfg, trainer):
-    """
-    Run hooks right before kicking off the training
-    :param cfg:
-    :param trainer:
-    :return:
-    """
-    neft_embeddings.pretrain_hook(cfg, trainer)
-
-
-def post_train_hooks(cfg, trainer):
-    """
-    Run hooks right after training completes
-    :param cfg:
-    :param trainer:
-    :return:
-    """
-    neft_embeddings.post_train_hook(cfg, trainer)
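The import shim at the top of this hunk is worth calling out: LLaVA stays an optional dependency because the `except ImportError` branch installs a stub with the same signature that only fails when multimodal saving is actually requested. The same pattern in isolation, with a hypothetical optional package standing in for llava:

```python
try:
    from some_optional_pkg import fancy_save  # hypothetical optional dependency
except ImportError:

    def fancy_save(trainer, output_dir: str) -> None:
        # Same signature as the real function, so the module imports cleanly
        # everywhere; the error surfaces only if the feature is actually used.
        raise ImportError("install some_optional_pkg to use fancy_save")
```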

View File

@@ -1,13 +1,10 @@
 """Benchmarking and measurement utilities"""
 import functools
-import logging

 import pynvml
 import torch
 from pynvml.nvml import NVMLError

-LOG = logging.getLogger("axolotl.utils.bench")

 def check_cuda_device(default_value):
     """
@@ -65,14 +62,7 @@ def gpu_memory_usage_smi(device=0):

 def log_gpu_memory_usage(log, msg, device):
-    if not torch.cuda.is_available():
-        return (0, 0, 0)
-
-    try:
-        usage, cache, misc = gpu_memory_usage_all(device)
-    except ValueError as exc:
-        LOG.exception(exc)
-        return (0, 0, 0)
+    usage, cache, misc = gpu_memory_usage_all(device)
     extras = []
     if cache > 0:
         extras.append(f"+{cache:.03f}GB cache")

View File

@@ -369,10 +369,15 @@ def validate_config(cfg):
             "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
         )

-    if cfg.tensor_parallel and cfg.gradient_checkpointing:
-        raise ValueError(
-            "TensorParallelPreTrainedModel does not support gradient checkpointing"
-        )
+    if cfg.multimodal:
+        try:
+            import llava  # noqa: F401 # pylint:disable=unused-import
+        except ImportError as exc:
+            LOG.warning(
+                "LLaVA package required for multimodal training. See docs/llava.md for more information."
+            )
+            raise exc

     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
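The replacement branch validates the optional dependency at config time instead of guarding tensor parallelism: if `multimodal: true` is set and `llava` cannot be imported, the user gets a pointer to the docs before the original `ImportError` is re-raised. The guard in isolation (a sketch; `cfg` is any attribute-style config object):

```python
import logging

LOG = logging.getLogger("axolotl")


def check_multimodal_deps(cfg) -> None:
    # Fail at validation time, not deep inside model loading.
    if cfg.multimodal:
        try:
            import llava  # noqa: F401
        except ImportError as exc:
            LOG.warning(
                "LLaVA package required for multimodal training. "
                "See docs/llava.md for more information."
            )
            raise exc
```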

View File

@@ -3,7 +3,7 @@ import functools
 import hashlib
 import logging
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union

 import torch
 from datasets import (
@@ -36,7 +36,6 @@ from axolotl.prompters import (
     MultipleChoiceExplainPrompter,
     ReflectAlpacaPrompter,
     SummarizeTLDRPrompter,
-    UnsupportedPrompter,
 )
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
@@ -55,11 +54,21 @@
     return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec

-def prepare_dataset(cfg, tokenizer):
-    prompters = []
-    if not cfg.pretraining_dataset:
-        with zero_first(is_main_process()):
-            train_dataset, eval_dataset, prompters = load_prepare_datasets(
+def prepare_dataset(cfg, tokenizer, model=None):
+    if cfg.multimodal:
+        if not model:
+            raise ValueError("missing model argument")
+        from llava.train.train import LazySupervisedDataset
+
+        with zero_first(is_main_process()):
+            eval_dataset = None
+            train_dataset = LazySupervisedDataset(
+                tokenizer=tokenizer,
+            )
+    elif not cfg.pretraining_dataset:
+        with zero_first(is_main_process()):
+            train_dataset, eval_dataset = load_prepare_datasets(
                 tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
             )
     else:
@@ -72,7 +81,7 @@ def prepare_dataset(cfg, tokenizer):
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
-        return train_dataset, eval_dataset, cfg.max_steps, prompters
+        return train_dataset, eval_dataset, cfg.max_steps

     with zero_first(is_main_process()):
         train_dataset, eval_dataset = process_datasets_for_packing(
@@ -85,7 +94,7 @@ def prepare_dataset(cfg, tokenizer):
         LOG.info(f"Maximum number of steps set at {total_num_steps}")
     else:
         total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
-    return train_dataset, eval_dataset, total_num_steps, prompters
+    return train_dataset, eval_dataset, total_num_steps

 def load_tokenized_prepared_datasets(
@@ -111,7 +120,6 @@ def load_tokenized_prepared_datasets(
         else Path(default_dataset_prepared_path) / ds_hash
     )
     dataset = None
-    prompters = []
     use_auth_token = cfg.hf_use_auth_token
     try:
         if cfg.push_dataset_to_hub:
@@ -150,13 +158,13 @@ def load_tokenized_prepared_datasets(
             yield dataset

     # pylint: disable=invalid-name
-    for config_dataset in for_d_in_datasets(cfg.datasets):
+    for d in for_d_in_datasets(cfg.datasets):
         ds: Union[Dataset, DatasetDict] = None
         ds_from_hub = False
         try:
             load_dataset(
-                config_dataset.path,
-                name=config_dataset.name,
+                d.path,
+                name=d.name,
                 streaming=True,
                 token=use_auth_token,
             )
@@ -165,33 +173,33 @@ def load_tokenized_prepared_datasets(
             pass

         # prefer local dataset, even if hub exists
-        local_path = Path(config_dataset.path)
+        local_path = Path(d.path)
         if local_path.exists():
             if local_path.is_dir():
                 # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
                 ds = load_dataset(
-                    config_dataset.path,
-                    name=config_dataset.name,
-                    data_files=config_dataset.data_files,
+                    d.path,
+                    name=d.name,
+                    data_files=d.data_files,
                     streaming=False,
                     split=None,
                 )
             elif local_path.is_file():
                 ds_type = "json"
-                if config_dataset.ds_type:
-                    ds_type = config_dataset.ds_type
-                elif ".parquet" in config_dataset.path:
+                if d.ds_type:
+                    ds_type = d.ds_type
+                elif ".parquet" in d.path:
                     ds_type = "parquet"
-                elif ".arrow" in config_dataset.path:
+                elif ".arrow" in d.path:
                     ds_type = "arrow"
-                elif ".csv" in config_dataset.path:
+                elif ".csv" in d.path:
                     ds_type = "csv"
-                elif ".txt" in config_dataset.path:
+                elif ".txt" in d.path:
                     ds_type = "text"
                 ds = load_dataset(
                     ds_type,
-                    name=config_dataset.name,
-                    data_files=config_dataset.path,
+                    name=d.name,
+                    data_files=d.path,
                     streaming=False,
                     split=None,
                 )
@@ -201,25 +209,25 @@ def load_tokenized_prepared_datasets(
             )
         elif ds_from_hub:
             ds = load_dataset(
-                config_dataset.path,
-                name=config_dataset.name,
+                d.path,
+                name=d.name,
                 streaming=False,
-                data_files=config_dataset.data_files,
+                data_files=d.data_files,
                 token=use_auth_token,
             )
         else:
-            if isinstance(config_dataset.data_files, str):
+            if isinstance(d.data_files, str):
                 fp = hf_hub_download(
-                    repo_id=config_dataset.path,
+                    repo_id=d.path,
                     repo_type="dataset",
-                    filename=config_dataset.data_files,
+                    filename=d.data_files,
                 )
-            elif isinstance(config_dataset.data_files, list):
+            elif isinstance(d.data_files, list):
                 fp = []
-                for file in config_dataset.data_files:
+                for file in d.data_files:
                     fp.append(
                         hf_hub_download(
-                            repo_id=config_dataset.path,
+                            repo_id=d.path,
                             repo_type="dataset",
                             filename=file,
                         )
@@ -229,27 +237,21 @@ def load_tokenized_prepared_datasets(
                     "data_files must be either a string or list of strings"
                 )
             ds = load_dataset(
-                "json",
-                name=config_dataset.name,
-                data_files=fp,
-                streaming=False,
-                split=None,
+                "json", name=d.name, data_files=fp, streaming=False, split=None
             )
         if not ds:
             raise ValueError("unhandled dataset load")
         # support for using a subset of the data
-        if config_dataset.shards:
+        if d.shards:
             if "train" in ds:
                 ds = ds.shuffle(seed=seed)["train"].shard(
-                    num_shards=config_dataset.shards, index=0
+                    num_shards=d.shards, index=0
                 )
             else:
-                ds = ds.shuffle(seed=seed).shard(
-                    num_shards=config_dataset.shards, index=0
-                )
+                ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)

         d_base_type = d_prompt_style = None
-        d_type = config_dataset.type
+        d_type = d.type
         if isinstance(d_type, str):
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
@@ -258,26 +260,108 @@ def load_tokenized_prepared_datasets(
             ds = ds["train"]
         elif (
             isinstance(ds, DatasetDict)
-            and config_dataset.train_on_split
-            and config_dataset.train_on_split in ds
+            and d.train_on_split
+            and d.train_on_split in ds
         ):
-            ds = ds[config_dataset.train_on_split]
+            ds = ds[d.train_on_split]
         elif isinstance(ds, DatasetDict):
             raise ValueError(
-                f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
+                f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
            )
-        dataset_wrapper, dataset_prompter = get_dataset_wrapper(
-            config_dataset=config_dataset,
-            dataset=ds,
-            tokenizer=tokenizer,
-            cfg=cfg,
-            d_base_type=d_base_type,
-            d_prompt_style=d_prompt_style,
-        )
-        datasets.append(dataset_wrapper)
-        prompters.append(dataset_prompter)
+        if (
+            "input_ids" in ds.features
+            and "attention_mask" in ds.features
+            and "labels" in ds.features
+        ):
+            # dataset is already tokenized, just drop it straight in
+            datasets.append(ds)
+        elif isinstance(d.type, DictDefault):
+            ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            datasets.append(ds_wrapper)
+        elif ds_strategy := load(d.type, tokenizer, cfg, d):
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            datasets.append(ds_wrapper)
+        elif d_base_type == "alpaca":
+            ds_strategy = AlpacaPromptTokenizingStrategy(
+                AlpacaPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            datasets.append(ds_wrapper)
+        elif d_base_type == "explainchoice":
+            ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
+                MultipleChoiceExplainPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            datasets.append(ds_wrapper)
+        elif d_base_type == "concisechoice":
+            ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
+                MultipleChoiceConcisePrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            datasets.append(ds_wrapper)
+        elif d_base_type == "summarizetldr":
+            ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
+                SummarizeTLDRPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            datasets.append(ds_wrapper)
+        elif d_base_type == "jeopardy":
+            ds_strategy = JeopardyPromptTokenizingStrategy(
+                JeopardyPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            datasets.append(ds_wrapper)
+        elif d_base_type == "oasst":
+            ds_strategy = OpenAssistantPromptTokenizingStrategy(
+                AlpacaPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            datasets.append(ds_wrapper)
+        elif d_base_type == "gpteacher":
+            ds_strategy = GPTeacherPromptTokenizingStrategy(
+                GPTeacherPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            datasets.append(ds_wrapper)
+        elif d_base_type == "reflection":
+            ds_strategy = AlpacaReflectionPTStrategy(
+                ReflectAlpacaPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            datasets.append(ds_wrapper)
+        else:
+            suffix = ""
+            if ":load_" in d.type:
+                suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
+            LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
+            raise ValueError(
+                f"unhandled prompt tokenization strategy: {d.type} {suffix}"
+            )
     LOG.info("merging datasets")
     dataset = concatenate_datasets(datasets)
@@ -295,14 +379,14 @@ def load_tokenized_prepared_datasets(
             f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
         )
-    return dataset, prompters
+    return dataset

 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
     default_dataset_prepared_path,
-) -> Tuple[Dataset, Dataset, List[Any]]:
+) -> Tuple[Dataset, Dataset]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
@@ -311,7 +395,6 @@ def load_prepare_datasets(
     )  # make sure we don't accidentally set it larger than sequence_len

     tokenizer_name = tokenizer.__class__.__name__
-    prompters = []
     if cfg.max_packed_sequence_len is not None:
         # see if we can go ahead and load the stacked dataset
         seed = f"@{str(cfg.seed)}" if cfg.seed else ""
@@ -367,7 +450,7 @@ def load_prepare_datasets(
                     f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
                 )
         else:
-            dataset, prompters = load_tokenized_prepared_datasets(
+            dataset = load_tokenized_prepared_datasets(
                 tokenizer, cfg, default_dataset_prepared_path
             )
@@ -409,7 +492,7 @@ def load_prepare_datasets(
                     private=True,
                 )
     else:
-        dataset, prompters = load_tokenized_prepared_datasets(
+        dataset = load_tokenized_prepared_datasets(
            tokenizer, cfg, default_dataset_prepared_path
        )
@@ -460,124 +543,7 @@ def load_prepare_datasets(
         train_dataset = dataset
         eval_dataset = None

-    return train_dataset, eval_dataset, prompters
+    return train_dataset, eval_dataset
-
-
-def get_dataset_wrapper(
-    config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
-):
-    dataset_wrapper = None
-    dataset_prompter = None
-    if (
-        "input_ids" in dataset.features
-        and "attention_mask" in dataset.features
-        and "labels" in dataset.features
-    ):
-        # dataset is already tokenized, just drop it straight in
-        dataset_prompter = UnsupportedPrompter()
-        dataset_wrapper = dataset
-    elif isinstance(config_dataset.type, DictDefault):
-        ds_strategy = load(
-            "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
-        )
-        dataset_prompter = UnsupportedPrompter()
-        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
-    elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
-        dataset_prompter = UnsupportedPrompter()
-        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
-    elif d_base_type == "alpaca":
-        dataset_prompter = AlpacaPrompter(d_prompt_style)
-        ds_strategy = AlpacaPromptTokenizingStrategy(
-            dataset_prompter,
-            tokenizer,
-            cfg.train_on_inputs,
-            cfg.sequence_len,
-        )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
-        dataset_wrapper = ds_wrapper
-    elif d_base_type == "explainchoice":
-        dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
-        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-            dataset_prompter,
-            tokenizer,
-            cfg.train_on_inputs,
-            cfg.sequence_len,
-        )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
-        dataset_wrapper = ds_wrapper
-    elif d_base_type == "concisechoice":
-        dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
-        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-            dataset_prompter,
-            tokenizer,
-            cfg.train_on_inputs,
-            cfg.sequence_len,
-        )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
-        dataset_wrapper = ds_wrapper
-    elif d_base_type == "summarizetldr":
-        dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
-        ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
-            dataset_prompter,
-            tokenizer,
-            cfg.train_on_inputs,
-            cfg.sequence_len,
-        )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
-        dataset_wrapper = ds_wrapper
-    elif d_base_type == "jeopardy":
-        dataset_prompter = JeopardyPrompter(d_prompt_style)
-        ds_strategy = JeopardyPromptTokenizingStrategy(
-            dataset_prompter,
-            tokenizer,
-            cfg.train_on_inputs,
-            cfg.sequence_len,
-        )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
-        dataset_wrapper = ds_wrapper
-    elif d_base_type == "oasst":
-        dataset_prompter = AlpacaPrompter(d_prompt_style)
-        ds_strategy = OpenAssistantPromptTokenizingStrategy(
-            dataset_prompter,
-            tokenizer,
-            cfg.train_on_inputs,
-            cfg.sequence_len,
-        )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
-        dataset_wrapper = ds_wrapper
-    elif d_base_type == "gpteacher":
-        dataset_prompter = GPTeacherPrompter(d_prompt_style)
-        ds_strategy = GPTeacherPromptTokenizingStrategy(
-            dataset_prompter,
-            tokenizer,
-            cfg.train_on_inputs,
-            cfg.sequence_len,
-        )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
-        dataset_wrapper = ds_wrapper
-    elif d_base_type == "reflection":
-        dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
-        ds_strategy = AlpacaReflectionPTStrategy(
-            dataset_prompter,
-            tokenizer,
-            cfg.train_on_inputs,
-            cfg.sequence_len,
-        )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
-        dataset_wrapper = ds_wrapper
-    else:
-        suffix = ""
-        if ":load_" in config_dataset.type:
-            suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
-        LOG.error(
-            f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
-        )
-        raise ValueError(
-            f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
-        )
-    return dataset_wrapper, dataset_prompter

 def encode_pretraining(
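This file trades the `get_dataset_wrapper` helper (plus the prompter bookkeeping it returned) for an inlined `elif` chain keyed on `d_base_type`. Since every branch builds the same prompter/strategy/wrapper triple, the chain could equally be table-driven; a hypothetical refactor sketch, assuming the import paths this file already uses (`axolotl.prompters`, `axolotl.prompt_tokenizers`, `axolotl.datasets`):

```python
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_tokenizers import (
    AlpacaMultipleChoicePromptTokenizingStrategy,
    AlpacaPromptTokenizingStrategy,
    AlpacaReflectionPTStrategy,
    GPTeacherPromptTokenizingStrategy,
    JeopardyPromptTokenizingStrategy,
    OpenAssistantPromptTokenizingStrategy,
    SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
    AlpacaPrompter,
    GPTeacherPrompter,
    JeopardyPrompter,
    MultipleChoiceConcisePrompter,
    MultipleChoiceExplainPrompter,
    ReflectAlpacaPrompter,
    SummarizeTLDRPrompter,
)

# Hypothetical dispatch table: d_base_type -> (prompter class, strategy class).
STRATEGIES = {
    "alpaca": (AlpacaPrompter, AlpacaPromptTokenizingStrategy),
    "explainchoice": (MultipleChoiceExplainPrompter, AlpacaMultipleChoicePromptTokenizingStrategy),
    "concisechoice": (MultipleChoiceConcisePrompter, AlpacaMultipleChoicePromptTokenizingStrategy),
    "summarizetldr": (SummarizeTLDRPrompter, SummarizeTLDRPromptTokenizingStrategy),
    "jeopardy": (JeopardyPrompter, JeopardyPromptTokenizingStrategy),
    "oasst": (AlpacaPrompter, OpenAssistantPromptTokenizingStrategy),
    "gpteacher": (GPTeacherPrompter, GPTeacherPromptTokenizingStrategy),
    "reflection": (ReflectAlpacaPrompter, AlpacaReflectionPTStrategy),
}


def build_wrapper(d_base_type, d_prompt_style, tokenizer, cfg, ds):
    # Equivalent to one elif branch: prompter -> strategy -> wrapped dataset.
    prompter_cls, strategy_cls = STRATEGIES[d_base_type]
    strategy = strategy_cls(
        prompter_cls(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    return TokenizedPromptDataset(strategy, ds)
```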

View File

@@ -3,9 +3,6 @@ import hashlib
 import itertools
 import logging
 import math
-import time
-from queue import Queue
-from threading import Thread
 from typing import Any, Callable, List, Union

 import numba
@@ -152,8 +149,6 @@ class MultipackDistributedDataloader:
         packing_efficiency_estimate: float = 1.0,
         sample_packing_seq_len_multiplier: int = 1,
         device_count: int = 1,
-        prefetch_max: int = 1000,
-        num_epochs: int = 1,
     ):
         # Dataset
         self.dataset = dataset
@@ -172,7 +167,6 @@ class MultipackDistributedDataloader:
         self.seq_max_length = seq_max_length
         self.batch_max_length = batch_size * seq_max_length
         self.collate_fn = collate_fn
-        self.num_epochs = num_epochs

         self.num_replicas = 1
         self.rank = 0
@@ -183,44 +177,6 @@ class MultipackDistributedDataloader:
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
         self.device_count = device_count

-        # maxsize is maximum number of samples in queue
-        self.prefetch_max = prefetch_max
-        self.queue: Queue = Queue(maxsize=prefetch_max)
-        self.thread = None
-
-    def _worker(self):
-        LOG.info(
-            f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
-        )
-        for epoch in range(self.num_epochs):
-            for sample in self._internal_batch_generator():
-                while True:
-                    if self.queue.full():
-                        time.sleep(1)
-                    else:
-                        break
-                self.queue.put(sample)
-        # stop the queue when epoch is done
-        self.queue.put(None)
-
-    def __iter__(self):
-        if hasattr(self.sampler, "set_epoch"):
-            new_epoch = self.sampler.epoch + 1
-            self.sampler.set_epoch(new_epoch)
-            LOG.info(f"calling sampler.set_epoch({new_epoch})")
-        if self.thread is None:
-            self.thread = Thread(target=self._worker, daemon=True)
-            self.thread.start()
-        while True:
-            item = self.queue.get()
-            if item is None:
-                break
-            yield item
-
     def generate_batches(self, set_stats=False):
         LOG.info("generating packed batches")
         if self.sampler:
@@ -250,7 +206,11 @@ class MultipackDistributedDataloader:
         return batches, totseqs

-    def _internal_batch_generator(self):
+    def __iter__(self):
+        if hasattr(self.sampler, "set_epoch"):
+            new_epoch = self.sampler.epoch + 1
+            self.sampler.set_epoch(new_epoch)
+            LOG.info(f"calling sampler.set_epoch({new_epoch})")
         all_batches, _ = self.generate_batches(set_stats=True)
         features = self.dataset.features.keys()
         len_remaining = self._len_est()
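The removed `_worker`/`__iter__` pair had implemented background prefetching: a daemon thread pushed packed batches into a bounded `Queue` across `num_epochs` epochs, with `None` as the end-of-stream sentinel, while `__iter__` drained the queue. A stripped-down sketch of that producer/consumer shape (generic, not the dataloader itself); a blocking `queue.put` stands in for the original's `time.sleep` polling loop:

```python
from queue import Queue
from threading import Thread
from typing import Iterable, Iterator


def prefetch(source: Iterable, maxsize: int = 1000) -> Iterator:
    # Producer fills a bounded queue on a daemon thread; None marks the end.
    queue: Queue = Queue(maxsize=maxsize)

    def worker() -> None:
        for item in source:
            queue.put(item)  # blocks while the queue is full
        queue.put(None)

    Thread(target=worker, daemon=True).start()
    while True:
        item = queue.get()
        if item is None:
            break
        yield item


for batch in prefetch(range(5), maxsize=2):
    print(batch)
```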

View File

@@ -7,7 +7,6 @@ from typing import Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
-import transformers.utils.bitsandbytes
 from optimum.bettertransformer import BetterTransformer
 from peft import PeftConfig, prepare_model_for_kbit_training
 from peft.tuners.lora import QuantLinear
@@ -32,7 +31,7 @@ LOG = logging.getLogger("axolotl")

 def load_model_config(cfg):
     model_config_name = cfg.base_model_config or cfg.base_model
-    trust_remote_code = cfg.trust_remote_code is True
+    trust_remote_code: bool = False or cfg.trust_remote_code
     return AutoConfig.from_pretrained(
         model_config_name, trust_remote_code=trust_remote_code
     )
@@ -73,6 +72,11 @@ def load_tokenizer(cfg):
         # set a pad_token, but use eos_token so we don't add a new token
         tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN

+    LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
+    LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
+    LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
+    LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
+
     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -94,11 +98,6 @@ def load_tokenizer(cfg):
             ]
         )

-    LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
-    LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
-    LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
-    LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
-
     return tokenizer
@@ -181,6 +180,26 @@ def load_model(
             LOG.info("patching with flash attention")
             replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)

+    if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
+        from axolotl.monkeypatch.llama_embeddings_hijack import (
+            replace_llama_embeddings_with_uniform_distribution,
+        )
+
+        LOG.info("patching with noisy embeddings")
+        replace_llama_embeddings_with_uniform_distribution(
+            noise_alpha=cfg.noisy_embedding_alpha
+        )
+
+    if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
+        from axolotl.monkeypatch.mistral_embeddings_hijack import (
+            replace_mistral_embeddings_with_uniform_distribution,
+        )
+
+        LOG.info("patching with noisy embeddings")
+        replace_mistral_embeddings_with_uniform_distribution(
+            noise_alpha=cfg.noisy_embedding_alpha
+        )
+
     if cfg.is_llama_derived_model and cfg.xpos_rope:
         from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
             replace_llama_rope_with_xpos_rope,
@@ -222,7 +241,7 @@ def load_model(
             load_in_4bit=True,
             llm_int8_threshold=6.0,
             llm_int8_has_fp16_weight=False,
-            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_compute_dtype=cfg.torch_dtype,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
@@ -236,12 +255,102 @@ def load_model(
         model_kwargs["use_flash_attention_2"] = True
     try:
-        if (
-            cfg.is_llama_derived_model
-            and not cfg.trust_remote_code
-            and not cfg.gptq
-            and not cfg.tensor_parallel
-        ):
+        if cfg.multimodal:
+            from llava.train.train import DataArguments, ModelArguments
+
+            if cfg.is_llama_derived_model:
+                from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM
+
+                model = LlavaLlamaForCausalLM.from_pretrained(
+                    cfg.base_model,
+                )
+            elif cfg.is_mistral_derived_model:
+                from axolotl.models.llava.llava_mistral import LlavaMistralForCausalLM
+
+                model = LlavaMistralForCausalLM.from_pretrained(
+                    cfg.base_model,
+                )
+            else:
+                raise NotImplementedError(
+                    "unhandled model architecture for multimodal training"
+                )
+            if cfg.mm_freeze_backbone:
+                model.model.requires_grad_(False)
+            if cfg.gradient_checkpointing:
+                if hasattr(model, "enable_input_require_grads"):
+                    model.enable_input_require_grads()
+                else:
+
+                    def make_inputs_require_grad(
+                        module,
+                        input,
+                        output,
+                    ):  # pylint: disable=redefined-builtin,unused-argument
+                        output.requires_grad_(True)
+
+                    model.get_input_embeddings().register_forward_hook(
+                        make_inputs_require_grad
+                    )
+            model_args = ModelArguments(
+                model_name_or_path=cfg.base_model,
+                version="v0",
+                freeze_backbone=cfg.mm_freeze_backbone or False,
+                tune_mm_mlp_adapter=cfg.tune_mm_mlp_adapter or False,
+                vision_tower=cfg.mm_vision_tower,
+                mm_vision_select_layer=cfg.mm_vision_select_layer or -1,
+                pretrain_mm_mlp_adapter=cfg.pretrain_mm_mlp_adapter,
+                mm_projector_type=cfg.mm_projector_type or "linear",
+                mm_use_im_start_end=cfg.mm_use_im_start_end or False,
+                mm_use_im_patch_token=cfg.mm_use_im_patch_token,
+                mm_vision_select_feature=cfg.mm_vision_select_feature or "patch",
+            )
+            if cfg.mm_vision_tower is not None:
+                model.get_model().initialize_vision_modules(
+                    model_args=model_args, fsdp=cfg.fsdp
+                )
+                vision_tower = model.get_vision_tower()
+                vision_tower.to(dtype=cfg.torch_dtype, device=cfg.device)
+                # pylint: disable=duplicate-code
+                data_args = DataArguments(
+                    data_path=cfg.datasets[0]["path"],
+                    lazy_preprocess=cfg.mm_lazy_preprocess
+                    if cfg.mm_lazy_preprocess is not None
+                    else True,
+                    is_multimodal=True,
+                    image_folder=cfg.mm_image_folder or None,
+                    image_aspect_ratio=cfg.mm_image_aspect_ratio or "square",
+                    image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None,
+                )
+                data_args.image_processor = vision_tower.image_processor
+                model.config.image_aspect_ratio = data_args.image_aspect_ratio
+                model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
+                model.config.tune_mm_mlp_adapter = cfg.tune_mm_mlp_adapter
+                if cfg.tune_mm_mlp_adapter:
+                    model.requires_grad_(False)
+                    for (
+                        p  # pylint: disable=invalid-name
+                    ) in model.get_model().mm_projector.parameters():
+                        p.requires_grad = True
+                model.config.freeze_mm_mlp_adapter = cfg.freeze_mm_mlp_adapter
+                if cfg.freeze_mm_mlp_adapter:
+                    for (
+                        p  # pylint: disable=invalid-name
+                    ) in model.get_model().mm_projector.parameters():
+                        p.requires_grad = False
+                model.config.mm_use_im_start_end = (
+                    data_args.mm_use_im_start_end
+                ) = cfg.mm_use_im_start_end
+                model.config.mm_use_im_patch_token = cfg.mm_use_im_patch_token
+                model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
+        elif cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM

             config_kwargs = {}
@@ -307,7 +416,7 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
             )
-        elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel:
+        elif model_type and not cfg.trust_remote_code:
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
@@ -322,17 +431,6 @@ def load_model(
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
-        elif cfg.tensor_parallel:
-            model_kwargs.pop("device_map")
-            model = AutoModelForCausalLM.from_pretrained(
-                base_model,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                low_cpu_mem_usage=True,
-                offload_state_dict=True,
-                trust_remote_code=cfg.trust_remote_code or False,
-                **model_kwargs,
-            )
         else:
             config = AutoConfig.from_pretrained(
                 base_model,
@@ -383,18 +481,15 @@ def load_model(
             **model_kwargs,
         )

-    try:
-        embeddings_len = (
-            math.ceil(len(tokenizer) / 32) * 32
-            if cfg.resize_token_embeddings_to_32x
-            else len(tokenizer)
-        )
-        if model.get_input_embeddings().num_embeddings < embeddings_len:
-            model.resize_token_embeddings(embeddings_len)
-        else:
-            model.tie_weights()
-    except NotImplementedError:
-        LOG.warning("`resize_token_embeddings` not implemented on model")
+    embeddings_len = (
+        math.ceil(len(tokenizer) / 32) * 32
+        if cfg.resize_token_embeddings_to_32x
+        else len(tokenizer)
+    )
+    if model.get_input_embeddings().num_embeddings < embeddings_len:
+        model.resize_token_embeddings(embeddings_len)
+    else:
+        model.tie_weights()

     if (
         hasattr(model.config, "max_position_embeddings")
@@ -406,20 +501,6 @@ def load_model(
         )
         model.config.max_position_embeddings = cfg.sequence_len

-    if (
-        hasattr(model.config, "bos_token_id")
-        and model.config.bos_token_id
-        and model.config.bos_token_id != tokenizer.bos_token_id
-    ):
-        model.config.bos_token_id = tokenizer.bos_token_id
-
-    if (
-        hasattr(model.config, "eos_token_id")
-        and model.config.eos_token_id
-        and model.config.eos_token_id != tokenizer.eos_token_id
-    ):
-        model.config.eos_token_id = tokenizer.eos_token_id
-
     if model.device.type == "cuda":
         log_gpu_memory_usage(LOG, "after model load", model.device)
@@ -497,12 +578,7 @@ def load_adapter(model, cfg, adapter, inference=False):
     if adapter is None:
         return model, None
     if hasattr(model, "enable_input_require_grads"):
-        try:
-            model.enable_input_require_grads()
-        except NotImplementedError:
-            LOG.warning("enable_input_require_grads not implemented on model")
-    if adapter == "qlora" and cfg.tensor_parallel:
-        model, _ = load_tp_qlora(model)
+        model.enable_input_require_grads()
     if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg, inference=inference)
     if adapter == "llama-adapter":
@@ -539,7 +615,14 @@ def load_llama_adapter(model, cfg):
 def find_all_linear_names(model):
     cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
     lora_module_names = set()
+    multimodal_keywords = [
+        "mm_projector",
+        "vision_tower",
+        "vision_resampler",
+    ]  # for LLaVA
     for name, module in model.named_modules():
+        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+            continue
         if (
             isinstance(module, cls)
             or "Linear" in module.__class__.__name__
@@ -554,25 +637,6 @@ def find_all_linear_names(model):
     return list(lora_module_names)

-def load_tp_qlora(model):
-    from transformers.utils.bitsandbytes import replace_with_bnb_linear
-
-    model = replace_with_bnb_linear(
-        model,
-        quantization_config=BitsAndBytesConfig(
-            load_in_4bit=True,
-            llm_int8_threshold=6.0,
-            llm_int8_has_fp16_weight=False,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4",
-        ),
-    )
-    model.is_loaded_in_4bit = True
-    return model, None

 def load_lora(model, cfg, inference=False):
     # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
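Among the changes above, `find_all_linear_names` now skips any module whose qualified name contains a LLaVA keyword, so automatic LoRA target discovery never selects the vision tower or the multimodal projector. A condensed, standalone version of that filter (the bitsandbytes and GPTQ linear classes from the diff are omitted for brevity):

```python
import torch


def find_lora_target_modules(model: torch.nn.Module) -> list:
    multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
    names = set()
    for name, module in model.named_modules():
        # Never offer multimodal submodules as LoRA targets.
        if any(keyword in name for keyword in multimodal_keywords):
            continue
        if isinstance(module, torch.nn.Linear):
            names.add(name.split(".")[-1])
    return list(names)
```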

View File

@@ -13,7 +13,7 @@ import torch.distributed as dist
 from datasets import set_caching_enabled
 from torch.utils.data import DistributedSampler, RandomSampler

-from axolotl.core.trainer_builder import HFCausalTrainerBuilder
+from axolotl.core.trainer_builder import AxolotlTrainer, HFCausalTrainerBuilder
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
 from axolotl.utils.distributed import (
@@ -216,7 +216,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
             packing_efficiency_estimate=cfg.sample_packing_eff_est,
             sample_packing_seq_len_multiplier=cfg.micro_batch_size,
             device_count=int(os.environ.get("WORLD_SIZE", 1)),
-            num_epochs=cfg.num_epochs,
         )
         data_loader_len = data_loader.len_w_stats()
         actual_eff = data_loader.efficiency()
@@ -260,7 +259,9 @@ def setup_fsdp_envs(cfg):
         ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap

-def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
+def setup_trainer(
+    cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
+) -> AxolotlTrainer:
     if cfg.fsdp:
         setup_fsdp_envs(cfg)
     elif cfg.deepspeed: