Compare commits
10 Commits
tensor-par
...
llava
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b52e61a574 | ||
|
|
53f93f67bb | ||
|
|
ef95ea2977 | ||
|
|
1321608dc4 | ||
|
|
7ff30c4033 | ||
|
|
faa46fbcf8 | ||
|
|
fdc3e4d505 | ||
|
|
b885169229 | ||
|
|
ab9d12ce34 | ||
|
|
866774737b |
90
README.md
90
README.md
@@ -32,6 +32,7 @@ Features:
|
|||||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||||
- [Config](#config)
|
- [Config](#config)
|
||||||
- [Train](#train)
|
- [Train](#train)
|
||||||
|
- [Training w/ Deepspeed](#training-with-deepspeed)
|
||||||
- [Inference](#inference)
|
- [Inference](#inference)
|
||||||
- [Merge LORA to Base](#merge-lora-to-base)
|
- [Merge LORA to Base](#merge-lora-to-base)
|
||||||
- [Common Errors](#common-errors-)
|
- [Common Errors](#common-errors-)
|
||||||
@@ -114,25 +115,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
docker compose up -d
|
docker compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
<details>
|
|
||||||
|
|
||||||
<summary>Docker advanced</summary>
|
|
||||||
|
|
||||||
A more powerful Docker command to run would be this:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker run --gpus '"all"' --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
|
||||||
```
|
|
||||||
|
|
||||||
It additionally:
|
|
||||||
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
|
|
||||||
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
|
|
||||||
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
|
|
||||||
|
|
||||||
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
#### Conda/Pip venv
|
#### Conda/Pip venv
|
||||||
1. Install python >=**3.9**
|
1. Install python >=**3.9**
|
||||||
|
|
||||||
@@ -374,13 +356,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- typescript
|
- typescript
|
||||||
type: ... # unimplemented custom format
|
type: ... # unimplemented custom format
|
||||||
|
|
||||||
# fastchat conversation
|
|
||||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
|
||||||
datasets:
|
|
||||||
- path: ...
|
|
||||||
type: sharegpt
|
|
||||||
conversation: chatml
|
|
||||||
|
|
||||||
# local
|
# local
|
||||||
datasets:
|
datasets:
|
||||||
- path: data.jsonl # or json
|
- path: data.jsonl # or json
|
||||||
@@ -419,7 +394,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
|
||||||
<summary>All yaml options (click me)</summary>
|
<summary>All yaml options</summary>
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||||
@@ -486,9 +461,7 @@ datasets:
|
|||||||
data_files: # Optional[str] path to source data files
|
data_files: # Optional[str] path to source data files
|
||||||
shards: # Optional[int] number of shards to split data into
|
shards: # Optional[int] number of shards to split data into
|
||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
|
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
|
||||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
|
||||||
|
|
||||||
# Custom user prompt
|
# Custom user prompt
|
||||||
- path: repo
|
- path: repo
|
||||||
@@ -618,14 +591,14 @@ gradient_accumulation_steps: 1
|
|||||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
eval_batch_size:
|
eval_batch_size:
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
learning_rate: 0.00003
|
learning_rate: 0.00003
|
||||||
lr_quadratic_warmup:
|
lr_quadratic_warmup:
|
||||||
logging_steps:
|
logging_steps:
|
||||||
save_strategy: # Set to `no` to skip checkpoint saves
|
save_strategy: # Set to `no` to skip checkpoint saves
|
||||||
save_steps: # Leave empty to save at each epoch
|
save_steps: # Leave empty to save at each epoch
|
||||||
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
eval_steps: # Leave empty to eval at each epoch
|
||||||
save_total_limit: # Checkpoints saved at a time
|
save_total_limit: # Checkpoints saved at a time
|
||||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||||
# if both are set, num_epochs will not be guaranteed.
|
# if both are set, num_epochs will not be guaranteed.
|
||||||
@@ -842,41 +815,14 @@ Run
|
|||||||
accelerate launch -m axolotl.cli.train your_config.yml
|
accelerate launch -m axolotl.cli.train your_config.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Preprocess dataset
|
|
||||||
|
|
||||||
You can optionally pre-tokenize dataset with the following before finetuning.
|
|
||||||
This is recommended for large datasets.
|
|
||||||
|
|
||||||
- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
|
|
||||||
- Use `--debug` to see preprocessed examples.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m axolotl.cli.preprocess your_config.yml
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Multi-GPU
|
#### Multi-GPU
|
||||||
|
|
||||||
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
|
You can optionally pre-tokenize dataset with the following before finetuning:
|
||||||
is the recommended multi-GPU option currently because FSDP may experience
|
```bash
|
||||||
[loss instability](https://github.com/huggingface/transformers/issues/26498).
|
CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
|
||||||
|
|
||||||
##### DeepSpeed
|
|
||||||
|
|
||||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
|
||||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
|
||||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
|
||||||
|
|
||||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
deepspeed: deepspeed/zero1.json
|
|
||||||
```
|
```
|
||||||
|
|
||||||
```shell
|
##### Config
|
||||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
|
||||||
```
|
|
||||||
|
|
||||||
##### FSDP
|
|
||||||
|
|
||||||
- llama FSDP
|
- llama FSDP
|
||||||
```yaml
|
```yaml
|
||||||
@@ -901,6 +847,24 @@ wandb_run_id:
|
|||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Training with Deepspeed
|
||||||
|
|
||||||
|
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||||
|
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||||
|
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||||
|
|
||||||
|
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||||
|
```
|
||||||
|
|
||||||
|
or
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
deepspeed: deepspeed/zero1.json
|
||||||
|
```
|
||||||
|
|
||||||
### Inference
|
### Inference
|
||||||
|
|
||||||
Pass the appropriate flag to the train command:
|
Pass the appropriate flag to the train command:
|
||||||
|
|||||||
@@ -1,6 +1,14 @@
|
|||||||
{
|
{
|
||||||
"zero_optimization": {
|
"zero_optimization": {
|
||||||
"stage": 3,
|
"stage": 3,
|
||||||
|
"offload_optimizer": {
|
||||||
|
"device": "cpu",
|
||||||
|
"pin_memory": true
|
||||||
|
},
|
||||||
|
"offload_param": {
|
||||||
|
"device": "cpu",
|
||||||
|
"pin_memory": true
|
||||||
|
},
|
||||||
"overlap_comm": true,
|
"overlap_comm": true,
|
||||||
"contiguous_gradients": true,
|
"contiguous_gradients": true,
|
||||||
"sub_group_size": 0,
|
"sub_group_size": 0,
|
||||||
@@ -33,13 +41,12 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"scheduler": {
|
"scheduler": {
|
||||||
"type": "WarmupDecayLR",
|
"type": "WarmupLR",
|
||||||
"params": {
|
"params": {
|
||||||
"warmup_min_lr": "auto",
|
"warmup_min_lr": "auto",
|
||||||
"warmup_max_lr": "auto",
|
"warmup_max_lr": "auto",
|
||||||
"warmup_num_steps": "auto",
|
"warmup_num_steps": "auto",
|
||||||
"warmup_type": "linear",
|
"warmup_type": "linear"
|
||||||
"total_num_steps": "auto"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
|
|||||||
@@ -12,7 +12,3 @@ This usually happens when you run out of system RAM.
|
|||||||
> Exitcode -7 while using deepspeed
|
> Exitcode -7 while using deepspeed
|
||||||
|
|
||||||
Try upgrading deepspeed w: `pip install -U deepspeed`
|
Try upgrading deepspeed w: `pip install -U deepspeed`
|
||||||
|
|
||||||
> AttributeError: 'DummyOptim' object has no attribute 'step'
|
|
||||||
|
|
||||||
You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
|
||||||
|
|||||||
36
docs/llava.md
Normal file
36
docs/llava.md
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# LLaVA
|
||||||
|
|
||||||
|
### Installing dependencies
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/haotian-liu/LLaVA.git
|
||||||
|
cd LLaVA
|
||||||
|
pip install --no-deps -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
### Downloading assets
|
||||||
|
|
||||||
|
LLaVA doesn't support remote datasets, so both the JSON and image assets need to be downloaded locally
|
||||||
|
|
||||||
|
```shell
|
||||||
|
mkdir llava
|
||||||
|
mkdir data
|
||||||
|
cd llava
|
||||||
|
curl -L -O https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/images.zip
|
||||||
|
unzip images.zip
|
||||||
|
|
||||||
|
cd ../data
|
||||||
|
curl -L -O https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/blip_laion_cc_sbu_558k.json
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pretraining
|
||||||
|
|
||||||
|
Pretraining aligns the vision model with the language model.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
accelerate launch -m axolotl.cli.train_mm examples/multimodal/pretrain-llava-llama.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
### Finetuning
|
||||||
|
|
||||||
|
TBD
|
||||||
@@ -49,7 +49,7 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ wandb_log_model:
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -54,7 +54,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ wandb_log_model:
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -56,7 +56,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ wandb_log_model:
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -54,7 +54,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ wandb_log_model:
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -56,7 +56,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ wandb_log_model:
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -54,7 +54,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ wandb_log_model:
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -56,7 +56,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ output_dir: ./qlora-out
|
|||||||
# decrease if OOM, increase for max VRAM utilization
|
# decrease if OOM, increase for max VRAM utilization
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
# Optimizer for QLoRA
|
# Optimizer for QLoRA
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ wandb_log_model:
|
|||||||
output_dir: ./jeopardy-bot-7b
|
output_dir: ./jeopardy-bot-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ wandb_log_model:
|
|||||||
output_dir: ./model-out
|
output_dir: ./model-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_torch
|
optimizer: adamw_torch
|
||||||
adam_beta2: 0.95
|
adam_beta2: 0.95
|
||||||
adam_eps: 0.00001
|
adam_eps: 0.00001
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ wandb_log_model:
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -54,7 +54,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
save_steps:
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ wandb_log_model:
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -56,7 +56,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ wandb_log_model:
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 4
|
micro_batch_size: 4
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -60,7 +60,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
save_steps: 50
|
save_steps: 50
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ wandb_log_model:
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -54,7 +54,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ wandb_log_model:
|
|||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.000005
|
learning_rate: 0.000005
|
||||||
@@ -46,7 +46,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
save_steps:
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
save_steps:
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ wandb_log_model:
|
|||||||
output_dir: ./mpt-alpaca-7b
|
output_dir: ./mpt-alpaca-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
|
|||||||
66
examples/multimodal/pretrain-llava-llama.yml
Normal file
66
examples/multimodal/pretrain-llava-llama.yml
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
base_model: NousResearch/Llama-2-7b-hf
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_llama_derived_model: true
|
||||||
|
|
||||||
|
# multimodal pretrain
|
||||||
|
multimodal: true
|
||||||
|
mm_vision_tower: openai/clip-vit-large-patch14
|
||||||
|
tune_mm_mlp_adapter: true
|
||||||
|
mm_freeze_backbone: true
|
||||||
|
mm_vision_select_layer: -2
|
||||||
|
mm_projector_type: mlp2x_gelu
|
||||||
|
mm_image_folder: ./llava/
|
||||||
|
mm_use_im_patch_token: false
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: ./data/blip_laion_cc_sbu_558k.json
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_run_id:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
eval_steps:
|
||||||
|
save_steps: 0.1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: "<unk>"
|
||||||
66
examples/multimodal/pretrain-llava-mistral.yml
Normal file
66
examples/multimodal/pretrain-llava-mistral.yml
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
base_model: mistralai/Mistral-7B-v0.1
|
||||||
|
model_type: MistralForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_mistral_derived_model: true
|
||||||
|
|
||||||
|
# multimodal pretrain
|
||||||
|
multimodal: true
|
||||||
|
mm_vision_tower: openai/clip-vit-large-patch14
|
||||||
|
tune_mm_mlp_adapter: true
|
||||||
|
mm_freeze_backbone: true
|
||||||
|
mm_vision_select_layer: -2
|
||||||
|
mm_projector_type: mlp2x_gelu
|
||||||
|
mm_image_folder: ./llava/
|
||||||
|
mm_use_im_patch_token: false
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: ./data/blip_laion_cc_sbu_558k.json
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_run_id:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
eval_steps:
|
||||||
|
save_steps:
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: "<unk>"
|
||||||
@@ -23,7 +23,7 @@ wandb_log_model:
|
|||||||
output_dir: ./lora-alpaca-pythia
|
output_dir: ./lora-alpaca-pythia
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 4
|
micro_batch_size: 4
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
learning_rate: 0.00001
|
learning_rate: 0.00001
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
@@ -33,5 +33,5 @@ early_stopping_patience:
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
eval_steps: 0.05
|
eval_steps: 20
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ wandb_log_model:
|
|||||||
output_dir: ./redpajama-alpaca-3b
|
output_dir: ./redpajama-alpaca-3b
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ wandb_log_model:
|
|||||||
output_dir: ./lora-replit
|
output_dir: ./lora-replit
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
optimizer:
|
optimizer:
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
lr_scheduler:
|
lr_scheduler:
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ output_dir: ./qlora-out
|
|||||||
# decrease if OOM, increase for max VRAM utilization
|
# decrease if OOM, increase for max VRAM utilization
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
num_epochs: 4
|
num_epochs: 3
|
||||||
# Optimizer for QLoRA
|
# Optimizer for QLoRA
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
# Page
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
# Table of contents
|
|
||||||
|
|
||||||
* [Page](README.md)
|
|
||||||
* [Small dev details](small-dev-details.md)
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# Small dev details
|
|
||||||
|
|
||||||
/
|
|
||||||
@@ -31,4 +31,3 @@ scikit-learn==1.2.2
|
|||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
fschat==0.2.29
|
fschat==0.2.29
|
||||||
tensor_parallel
|
|
||||||
|
|||||||
@@ -45,6 +45,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
if parsed_cli_args.prepare_ds_only:
|
||||||
|
return
|
||||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
@@ -215,6 +216,46 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def load_mm_dataset(
|
||||||
|
*,
|
||||||
|
cfg: DictDefault,
|
||||||
|
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
|
||||||
|
model,
|
||||||
|
):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
from llava.train.train import DataArguments, LazySupervisedDataset
|
||||||
|
|
||||||
|
vision_tower = model.get_vision_tower()
|
||||||
|
data_args = DataArguments(
|
||||||
|
data_path=cfg.datasets[0]["path"],
|
||||||
|
lazy_preprocess=cfg.mm_lazy_preprocess
|
||||||
|
if cfg.mm_lazy_preprocess is not None
|
||||||
|
else True,
|
||||||
|
is_multimodal=True,
|
||||||
|
image_folder=cfg.mm_image_folder or None,
|
||||||
|
image_aspect_ratio=cfg.mm_image_aspect_ratio or "square",
|
||||||
|
image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None,
|
||||||
|
)
|
||||||
|
data_args.image_processor = vision_tower.image_processor
|
||||||
|
data_args.mm_use_im_start_end = cfg.mm_use_im_start_end or False
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
train_dataset = LazySupervisedDataset(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
data_path=data_args.data_path,
|
||||||
|
data_args=data_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
total_num_steps = int(
|
||||||
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
return TrainDatasetMeta(
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=None,
|
||||||
|
total_num_steps=total_num_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_datasets(
|
def load_datasets(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
@@ -222,9 +263,7 @@ def load_datasets(
|
|||||||
) -> TrainDatasetMeta:
|
) -> TrainDatasetMeta:
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
||||||
cfg, tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
if cli_args.debug or cfg.debug:
|
if cli_args.debug or cfg.debug:
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
@@ -240,10 +279,6 @@ def load_datasets(
|
|||||||
text_only=cli_args.debug_text_only,
|
text_only=cli_args.debug_text_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("printing prompters...")
|
|
||||||
for prompter in prompters:
|
|
||||||
LOG.info(prompter)
|
|
||||||
|
|
||||||
return TrainDatasetMeta(
|
return TrainDatasetMeta(
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
|
from colorama import Fore
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
@@ -15,6 +16,7 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.train")
|
LOG = logging.getLogger("axolotl.cli.train")
|
||||||
@@ -30,7 +32,18 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
|
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
|
||||||
|
msg = (
|
||||||
|
Fore.RED
|
||||||
|
+ "--prepare_ds_only called without dataset_prepared_path set."
|
||||||
|
+ Fore.RESET
|
||||||
|
)
|
||||||
|
LOG.warning(msg)
|
||||||
|
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
if parsed_cli_args.prepare_ds_only:
|
||||||
|
return
|
||||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from colorama import Fore
|
from colorama import Fore
|
||||||
|
|
||||||
@@ -12,13 +13,15 @@ from axolotl.cli import (
|
|||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
check_user_token,
|
check_user_token,
|
||||||
load_cfg,
|
load_cfg,
|
||||||
load_datasets,
|
load_mm_dataset,
|
||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import PreprocessCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
LOG = logging.getLogger("axolotl.cli.train")
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
@@ -27,26 +30,29 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
parser = transformers.HfArgumentParser((PreprocessCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
if not parsed_cfg.dataset_prepared_path:
|
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
|
||||||
msg = (
|
msg = (
|
||||||
Fore.RED
|
Fore.RED
|
||||||
+ "preprocess CLI called without dataset_prepared_path set, "
|
+ "--prepare_ds_only called without dataset_prepared_path set."
|
||||||
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
|
|
||||||
+ Fore.RESET
|
+ Fore.RESET
|
||||||
)
|
)
|
||||||
LOG.warning(msg)
|
LOG.warning(msg)
|
||||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
tokenizer = load_tokenizer(parsed_cfg)
|
||||||
LOG.info(
|
model, _ = load_model(parsed_cfg, tokenizer)
|
||||||
Fore.GREEN
|
dataset_meta = load_mm_dataset(
|
||||||
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
cfg=parsed_cfg, cli_args=parsed_cli_args, model=model
|
||||||
+ Fore.RESET
|
|
||||||
)
|
)
|
||||||
|
del model
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
if parsed_cli_args.prepare_ds_only:
|
||||||
|
return
|
||||||
|
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -25,22 +25,11 @@ class TrainerCliArgs:
|
|||||||
debug_num_examples: int = field(default=5)
|
debug_num_examples: int = field(default=5)
|
||||||
inference: bool = field(default=False)
|
inference: bool = field(default=False)
|
||||||
merge_lora: bool = field(default=False)
|
merge_lora: bool = field(default=False)
|
||||||
|
prepare_ds_only: bool = field(default=False)
|
||||||
prompter: Optional[str] = field(default=None)
|
prompter: Optional[str] = field(default=None)
|
||||||
shard: bool = field(default=False)
|
shard: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PreprocessCliArgs:
|
|
||||||
"""
|
|
||||||
dataclass representing arguments for preprocessing only
|
|
||||||
"""
|
|
||||||
|
|
||||||
debug: bool = field(default=False)
|
|
||||||
debug_text_only: bool = field(default=False)
|
|
||||||
debug_num_examples: int = field(default=1)
|
|
||||||
prompter: Optional[str] = field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_tokenizer(
|
def load_model_and_tokenizer(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from functools import partial
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import tensor_parallel as tp
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
@@ -34,7 +33,6 @@ from axolotl.utils.callbacks import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||||
from axolotl.utils.distributed import is_distributed
|
|
||||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -42,6 +40,14 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from llava.train.llava_trainer import get_mm_adapter_state_maybe_zero_3
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
|
||||||
|
raise ImportError("missing LLaVA package")
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||||
|
|
||||||
|
|
||||||
@@ -104,8 +110,16 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
bench_source_max_len: int = field(
|
bench_source_max_len: int = field(
|
||||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||||
)
|
)
|
||||||
tensor_parallel: bool = field(
|
tune_mm_mlp_adapter: bool = field(
|
||||||
default=False, metadata={"help": "Use tensor parallelism to train"}
|
default=False,
|
||||||
|
metadata={"help": "Whether to train the multimodal projector adapter"},
|
||||||
|
)
|
||||||
|
freeze_mm_mlp_adapter: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to freeze the multimodal projector adapter"},
|
||||||
|
)
|
||||||
|
mm_projector_lr: Optional[float] = field(
|
||||||
|
default=None, metadata={"help": "Learning rate for the multimodal projector"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -116,8 +130,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
|
|
||||||
args = None # type: AxolotlTrainingArguments
|
args = None # type: AxolotlTrainingArguments
|
||||||
|
|
||||||
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
def __init__(self, *args, bench_data_collator=None, **kwargs):
|
||||||
self.num_epochs = num_epochs
|
|
||||||
self.bench_data_collator = bench_data_collator
|
self.bench_data_collator = bench_data_collator
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@@ -188,7 +201,6 @@ class AxolotlTrainer(Trainer):
|
|||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
num_epochs=self.num_epochs,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
@@ -212,7 +224,6 @@ class AxolotlTrainer(Trainer):
|
|||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
num_epochs=self.num_epochs,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
@@ -251,13 +262,40 @@ class AxolotlTrainer(Trainer):
|
|||||||
# return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _save_checkpoint(self, model, trial, metrics=None):
|
||||||
if self.args.tensor_parallel:
|
if getattr(self.args, "tune_mm_mlp_adapter", False):
|
||||||
model = tp.tensor_parallel(model, distributed=is_distributed())
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||||
model.hf_device_map = tp.infer_sharded_device_map(model)
|
|
||||||
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
|
|
||||||
|
run_dir = self._get_output_dir(trial=trial)
|
||||||
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
|
|
||||||
|
weights_to_save = self._get_mm_mlp_adapter_weights()
|
||||||
|
|
||||||
|
if self.args.local_rank in (0, -1):
|
||||||
|
self.model.config.save_pretrained(output_dir)
|
||||||
|
torch.save(
|
||||||
|
weights_to_save, os.path.join(output_dir, "mm_projector.bin")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
model = super()._wrap_model(model, training=training, dataloader=dataloader)
|
super()._save_checkpoint(model, trial, metrics)
|
||||||
return model
|
|
||||||
|
def _get_mm_mlp_adapter_weights(self):
|
||||||
|
# Only save Adapter
|
||||||
|
keys_to_match = ["mm_projector", "vision_resampler"]
|
||||||
|
if getattr(self.args, "use_im_start_end", False):
|
||||||
|
keys_to_match.extend(["embed_tokens", "embed_in"])
|
||||||
|
|
||||||
|
return get_mm_adapter_state_maybe_zero_3(
|
||||||
|
self.model.named_parameters(), keys_to_match
|
||||||
|
)
|
||||||
|
|
||||||
|
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||||
|
if getattr(self.args, "tune_mm_mlp_adapter", False):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
super()._save(output_dir, state_dict)
|
||||||
|
|
||||||
|
|
||||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||||
@@ -384,10 +422,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return trainer_kwargs, trainer_cls
|
return trainer_kwargs, trainer_cls
|
||||||
|
|
||||||
def hook_post_create_trainer(self, trainer):
|
def hook_post_create_trainer(self, trainer):
|
||||||
if self.cfg.tensor_parallel:
|
# TODO
|
||||||
trainer.model = trainer.accelerator.prepare_model(
|
|
||||||
trainer.model, device_placement=True
|
|
||||||
)
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
@@ -629,9 +664,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"sample_packing_seq_len_multiplier"
|
"sample_packing_seq_len_multiplier"
|
||||||
] = self.cfg.micro_batch_size
|
] = self.cfg.micro_batch_size
|
||||||
|
|
||||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||||
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
||||||
training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True
|
|
||||||
|
# multimodal: llava
|
||||||
|
training_arguments_kwargs["tune_mm_mlp_adapter"] = self.cfg.tune_mm_mlp_adapter
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"freeze_mm_mlp_adapter"
|
||||||
|
] = self.cfg.freeze_mm_mlp_adapter
|
||||||
|
training_arguments_kwargs["mm_projector_lr"] = self.cfg.mm_projector_lr
|
||||||
|
|
||||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||||
training_arguments_kwargs
|
training_arguments_kwargs
|
||||||
@@ -649,18 +691,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
sys.path.append(self.cfg.torchdistx_path)
|
sys.path.append(self.cfg.torchdistx_path)
|
||||||
importlib.import_module("torchdistx")
|
importlib.import_module("torchdistx")
|
||||||
|
|
||||||
data_collator_kwargs = {
|
|
||||||
"padding": True, # True/"longest" is the default
|
|
||||||
}
|
|
||||||
if self.cfg.pad_to_sequence_len:
|
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
|
|
||||||
self.cfg.sequence_len / 64
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
|
||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
|
||||||
|
|
||||||
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
|
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||||
add_mem_tokens,
|
add_mem_tokens,
|
||||||
@@ -685,23 +715,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||||
trainer_kwargs, trainer_cls
|
trainer_kwargs, trainer_cls
|
||||||
)
|
)
|
||||||
|
trainer_collator_kwargs = self.build_data_collator()
|
||||||
|
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
eval_dataset=self.eval_dataset,
|
eval_dataset=self.eval_dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
data_collator=DataCollatorForSeq2Seq(
|
|
||||||
self.tokenizer,
|
|
||||||
return_tensors="pt",
|
|
||||||
**data_collator_kwargs,
|
|
||||||
),
|
|
||||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
|
||||||
self.tokenizer,
|
|
||||||
return_tensors="pt",
|
|
||||||
**data_collator_kwargs,
|
|
||||||
),
|
|
||||||
callbacks=self.get_callbacks(),
|
callbacks=self.get_callbacks(),
|
||||||
num_epochs=self.cfg.num_epochs,
|
**trainer_collator_kwargs,
|
||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
)
|
)
|
||||||
trainer = self.hook_post_create_trainer(trainer)
|
trainer = self.hook_post_create_trainer(trainer)
|
||||||
@@ -709,3 +731,41 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer.add_callback(callback)
|
trainer.add_callback(callback)
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
|
def build_data_collator(self):
|
||||||
|
data_collator_kwargs = {
|
||||||
|
"padding": True, # True/"longest" is the default
|
||||||
|
}
|
||||||
|
if self.cfg.pad_to_sequence_len:
|
||||||
|
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
|
||||||
|
self.cfg.sequence_len / 64
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||||
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
|
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||||
|
|
||||||
|
collator_kwargs = {}
|
||||||
|
if self.cfg.multimodal:
|
||||||
|
from llava.train.train import DataCollatorForSupervisedDataset
|
||||||
|
|
||||||
|
collator_kwargs["data_collator"] = DataCollatorForSupervisedDataset(
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
collator_kwargs["data_collator"] = DataCollatorForSeq2Seq(
|
||||||
|
self.tokenizer,
|
||||||
|
return_tensors="pt",
|
||||||
|
**data_collator_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.cfg.do_bench_eval:
|
||||||
|
collator_kwargs[
|
||||||
|
"bench_data_collator"
|
||||||
|
] = transformers.DataCollatorForSeq2Seq(
|
||||||
|
self.tokenizer,
|
||||||
|
return_tensors="pt",
|
||||||
|
**data_collator_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return collator_kwargs
|
||||||
|
|||||||
0
src/axolotl/models/llava/__init__.py
Normal file
0
src/axolotl/models/llava/__init__.py
Normal file
167
src/axolotl/models/llava/llava_mistral.py
Normal file
167
src/axolotl/models/llava/llava_mistral.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""
|
||||||
|
LLaVA Mistral classes
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from llava.model.llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
MistralConfig,
|
||||||
|
MistralForCausalLM,
|
||||||
|
MistralModel,
|
||||||
|
)
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaMistralConfig(MistralConfig):
|
||||||
|
"""
|
||||||
|
HF Transformers Config for Mistral w LLaVA
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "llava_mistral"
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaMistralModel(LlavaMetaModel, MistralModel):
|
||||||
|
"""
|
||||||
|
HF Transformers Model for Mistral w LLaVA
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = LlavaMistralConfig
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, config: LlavaMistralConfig
|
||||||
|
): # pylint: disable=useless-parent-delegation
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
|
||||||
|
"""
|
||||||
|
HF Transformers Causal Model for Mistral w LLaVA
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = LlavaMistralConfig
|
||||||
|
|
||||||
|
def __init__(self, config: LlavaMistralConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = LlavaMistralModel(config)
|
||||||
|
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_model(self):
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
images: Optional[torch.FloatTensor] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
past_key_values,
|
||||||
|
inputs_embeds,
|
||||||
|
labels,
|
||||||
|
) = self.prepare_inputs_labels_for_multimodal(
|
||||||
|
input_ids, attention_mask, past_key_values, labels, images
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model/pipeline parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
attention_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if past_key_values:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
|
if inputs_embeds is not None and past_key_values is None:
|
||||||
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
|
else:
|
||||||
|
model_inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": kwargs.get("use_cache"),
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"images": kwargs.get("images", None),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
|
AutoConfig.register("llava_mistral", LlavaMistralConfig)
|
||||||
|
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
|
||||||
40
src/axolotl/monkeypatch/llama_embeddings_hijack.py
Normal file
40
src/axolotl/monkeypatch/llama_embeddings_hijack.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
def noised_embed(orig_embed, noise_alpha, model):
|
||||||
|
def new_func(input_ids):
|
||||||
|
# during training, we add noise to the embedding
|
||||||
|
# during generation, we don't add noise to the embedding
|
||||||
|
if model.training:
|
||||||
|
embed_init = orig_embed(input_ids)
|
||||||
|
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||||
|
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||||
|
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||||
|
-mag_norm, mag_norm
|
||||||
|
)
|
||||||
|
return orig_embed(input_ids)
|
||||||
|
|
||||||
|
return new_func
|
||||||
|
|
||||||
|
def post_init(orig_post_init):
|
||||||
|
def new_func(self):
|
||||||
|
orig_post_init(self)
|
||||||
|
self.embed_tokens.forward = noised_embed(
|
||||||
|
self.embed_tokens.forward, noise_alpha, self
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_func
|
||||||
|
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel.post_init
|
||||||
|
)
|
||||||
40
src/axolotl/monkeypatch/mistral_embeddings_hijack.py
Normal file
40
src/axolotl/monkeypatch/mistral_embeddings_hijack.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers.models.mistral.modeling_mistral
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
def noised_embed(orig_embed, noise_alpha, model):
|
||||||
|
def new_func(input_ids):
|
||||||
|
# during training, we add noise to the embedding
|
||||||
|
# during generation, we don't add noise to the embedding
|
||||||
|
if model.training:
|
||||||
|
embed_init = orig_embed(input_ids)
|
||||||
|
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||||
|
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||||
|
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||||
|
-mag_norm, mag_norm
|
||||||
|
)
|
||||||
|
return orig_embed(input_ids)
|
||||||
|
|
||||||
|
return new_func
|
||||||
|
|
||||||
|
def post_init(orig_post_init):
|
||||||
|
def new_func(self):
|
||||||
|
orig_post_init(self)
|
||||||
|
self.embed_tokens.forward = noised_embed(
|
||||||
|
self.embed_tokens.forward, noise_alpha, self
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_func
|
||||||
|
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralModel.post_init
|
||||||
|
)
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
"""
|
|
||||||
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
from peft import PeftModel
|
|
||||||
from transformers import PreTrainedModel
|
|
||||||
|
|
||||||
|
|
||||||
def patch_neft(alpha, model):
|
|
||||||
embeddings = None
|
|
||||||
if isinstance(model, PreTrainedModel):
|
|
||||||
embeddings = model.get_input_embeddings()
|
|
||||||
if isinstance(model, PeftModel):
|
|
||||||
embeddings = model.base_model.get_input_embeddings()
|
|
||||||
if not embeddings:
|
|
||||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
|
||||||
embeddings.noisy_embedding_alpha = alpha
|
|
||||||
old_forward = embeddings.forward
|
|
||||||
|
|
||||||
# This hack seems to be needed to properly use a custom forward pass
|
|
||||||
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
|
||||||
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
|
|
||||||
embeddings, embeddings.__class__
|
|
||||||
)
|
|
||||||
setattr(embeddings, "forward", bound_method)
|
|
||||||
|
|
||||||
embeddings._old_forward = old_forward # pylint: disable=protected-access
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def unpatch_neft(model):
|
|
||||||
embeddings = None
|
|
||||||
if isinstance(model, PreTrainedModel):
|
|
||||||
embeddings = model.get_input_embeddings()
|
|
||||||
if isinstance(model, PeftModel):
|
|
||||||
embeddings = model.base_model.get_input_embeddings()
|
|
||||||
if not embeddings:
|
|
||||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
|
||||||
if hasattr(embeddings, "_old_forward"):
|
|
||||||
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
|
|
||||||
del embeddings._old_forward # pylint: disable=protected-access
|
|
||||||
del embeddings.noisy_embedding_alpha
|
|
||||||
|
|
||||||
|
|
||||||
def neft_forward(self, inputs: torch.Tensor):
|
|
||||||
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
|
|
||||||
|
|
||||||
if self.training:
|
|
||||||
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
|
|
||||||
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
|
|
||||||
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
|
|
||||||
-mag_norm, mag_norm
|
|
||||||
)
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def pretrain_hook(cfg, trainer):
|
|
||||||
if cfg.noisy_embedding_alpha:
|
|
||||||
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
|
|
||||||
|
|
||||||
|
|
||||||
def post_train_hook(cfg, trainer):
|
|
||||||
if cfg.noisy_embedding_alpha:
|
|
||||||
unpatch_neft(trainer.model)
|
|
||||||
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
)
|
)
|
||||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
return SimpleShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompterV2(
|
ShareGPTPrompterV2(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
role_key_model=field_model,
|
role_key_model=field_model,
|
||||||
@@ -34,9 +34,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
if ds_cfg and "strict" in ds_cfg:
|
|
||||||
strategy.strict = ds_cfg["strict"]
|
|
||||||
return strategy
|
|
||||||
|
|
||||||
|
|
||||||
def load_role(tokenizer, cfg):
|
def load_role(tokenizer, cfg):
|
||||||
@@ -62,26 +59,8 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|||||||
basic sharegpt strategy to grab conversations from the sample row
|
basic sharegpt strategy to grab conversations from the sample row
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_strict = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def strict(self):
|
|
||||||
return self._strict
|
|
||||||
|
|
||||||
@strict.setter
|
|
||||||
def strict(self, strict):
|
|
||||||
self._strict = strict
|
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self, prompt):
|
||||||
conversations = prompt["conversations"]
|
return prompt["conversations"]
|
||||||
if self.strict:
|
|
||||||
return conversations
|
|
||||||
# remap roles - allow for assistant turn
|
|
||||||
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
|
|
||||||
turns = [
|
|
||||||
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
|
|
||||||
]
|
|
||||||
return turns
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||||
|
|||||||
@@ -245,7 +245,6 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
(
|
(
|
||||||
instruction,
|
instruction,
|
||||||
input, # pylint: disable=redefined-builtin
|
input, # pylint: disable=redefined-builtin
|
||||||
|
|||||||
@@ -4,12 +4,10 @@ import logging
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Generator, Optional, Union
|
from typing import Generator, Optional, Union
|
||||||
|
|
||||||
from colorama import Fore
|
|
||||||
from fastchat.conversation import Conversation, get_conv_template
|
from fastchat.conversation import Conversation, get_conv_template
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
|
|
||||||
|
|
||||||
|
|
||||||
class PromptStyle(Enum):
|
class PromptStyle(Enum):
|
||||||
@@ -57,15 +55,20 @@ class AlpacaPrompter:
|
|||||||
)
|
)
|
||||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||||
|
|
||||||
def _build_result(self, instruction, input_text, output):
|
def build_prompt(
|
||||||
|
self,
|
||||||
|
instruction: str,
|
||||||
|
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||||
|
output: Union[None, str] = None,
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
# returns the full prompt from instruction and optional input
|
# returns the full prompt from instruction and optional input
|
||||||
# if a label (=response, =output) is provided, it's also appended.
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
if input_text:
|
if input:
|
||||||
res = (
|
res = (
|
||||||
self.system_format.format(system=self.system_prompt)
|
self.system_format.format(system=self.system_prompt)
|
||||||
if self.system_prompt
|
if self.system_prompt
|
||||||
else ""
|
else ""
|
||||||
) + self.turn_format.format(instruction=instruction, input=input_text)
|
) + self.turn_format.format(instruction=instruction, input=input)
|
||||||
else:
|
else:
|
||||||
res = (
|
res = (
|
||||||
self.system_format.format(system=self.system_no_input_prompt)
|
self.system_format.format(system=self.system_no_input_prompt)
|
||||||
@@ -74,21 +77,7 @@ class AlpacaPrompter:
|
|||||||
) + self.turn_no_input_format.format(instruction=instruction)
|
) + self.turn_no_input_format.format(instruction=instruction)
|
||||||
if output:
|
if output:
|
||||||
res = f"{res}{output}"
|
res = f"{res}{output}"
|
||||||
|
yield res
|
||||||
return res
|
|
||||||
|
|
||||||
def build_prompt(
|
|
||||||
self,
|
|
||||||
instruction: str,
|
|
||||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
|
||||||
output: Union[None, str] = None,
|
|
||||||
) -> Generator[str, None, None]:
|
|
||||||
yield self._build_result(instruction, input, output)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return REPR_TEMPLATE.format(
|
|
||||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UnpromptedPrompter(AlpacaPrompter):
|
class UnpromptedPrompter(AlpacaPrompter):
|
||||||
@@ -202,14 +191,14 @@ class ReflectAlpacaPrompter:
|
|||||||
)
|
)
|
||||||
self.response_split = "ASSISTANT:"
|
self.response_split = "ASSISTANT:"
|
||||||
|
|
||||||
def _build_result(
|
def build_prompt(
|
||||||
self,
|
self,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||||
output: Union[None, str] = None,
|
output: Union[None, str] = None,
|
||||||
reflection: Union[None, str] = None,
|
reflection: Union[None, str] = None,
|
||||||
corrected: Union[None, str] = None,
|
corrected: Union[None, str] = None,
|
||||||
):
|
) -> Generator[str, None, None]:
|
||||||
# returns the full prompt from instruction and optional input
|
# returns the full prompt from instruction and optional input
|
||||||
# if a label (=response, =output) is provided, it's also appended.
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
if input:
|
if input:
|
||||||
@@ -223,30 +212,7 @@ class ReflectAlpacaPrompter:
|
|||||||
corrected=corrected,
|
corrected=corrected,
|
||||||
)
|
)
|
||||||
res = f"{res}{label}"
|
res = f"{res}{label}"
|
||||||
|
yield res
|
||||||
return res
|
|
||||||
|
|
||||||
def build_prompt(
|
|
||||||
self,
|
|
||||||
instruction: str,
|
|
||||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
|
||||||
output: Union[None, str] = None,
|
|
||||||
reflection: Union[None, str] = None,
|
|
||||||
corrected: Union[None, str] = None,
|
|
||||||
) -> Generator[str, None, None]:
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
yield self._build_result(
|
|
||||||
instruction,
|
|
||||||
input,
|
|
||||||
output,
|
|
||||||
reflection,
|
|
||||||
corrected,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return REPR_TEMPLATE.format(
|
|
||||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||||
@@ -281,7 +247,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
if role_key_model:
|
if role_key_model:
|
||||||
self.role_key_model = role_key_model
|
self.role_key_model = role_key_model
|
||||||
|
|
||||||
def _build_result(self, source):
|
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
# If there isn't a back and forth conversation, ignore it
|
# If there isn't a back and forth conversation, ignore it
|
||||||
# also happens on the data splitting leaving empty conversations
|
# also happens on the data splitting leaving empty conversations
|
||||||
@@ -316,20 +282,11 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
return conv.get_turns()
|
for part in conv.get_turns():
|
||||||
|
|
||||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
|
||||||
turns = self._build_result(source)
|
|
||||||
|
|
||||||
for part in turns:
|
|
||||||
if part[0] and not part[1]:
|
if part[0] and not part[1]:
|
||||||
LOG.warning(f"role with empty message: {part[0]}")
|
LOG.warning(f"role with empty message: {part[0]}")
|
||||||
yield part
|
yield part
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
|
|
||||||
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
|
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||||
"""
|
"""
|
||||||
@@ -347,15 +304,3 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
|||||||
role_key_human=role_key_human,
|
role_key_human=role_key_human,
|
||||||
role_key_model=role_key_model,
|
role_key_model=role_key_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedPrompter:
|
|
||||||
"""
|
|
||||||
A dummy class for custom prompters
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "Pre-tokenized or custom dataset types are unsupported for logging"
|
|
||||||
|
|||||||
@@ -16,11 +16,18 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
|
|||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.monkeypatch import neft_embeddings
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
|
try:
|
||||||
|
from llava.train.train import safe_save_model_for_hf_trainer
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
|
||||||
|
raise ImportError("missing LLaVA package")
|
||||||
|
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
src_dir = os.path.join(project_root, "src")
|
src_dir = os.path.join(project_root, "src")
|
||||||
sys.path.insert(0, src_dir)
|
sys.path.insert(0, src_dir)
|
||||||
@@ -108,7 +115,6 @@ def train(
|
|||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||||
|
|
||||||
pretrain_hooks(cfg, trainer)
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with torch.backends.cuda.sdp_kernel(
|
||||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||||
@@ -116,7 +122,6 @@ def train(
|
|||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
else:
|
else:
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
post_train_hooks(cfg, trainer)
|
|
||||||
|
|
||||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||||
|
|
||||||
@@ -140,6 +145,8 @@ def train(
|
|||||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
trainer.save_model(cfg.output_dir)
|
trainer.save_model(cfg.output_dir)
|
||||||
|
elif cfg.multimodal:
|
||||||
|
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=cfg.output_dir)
|
||||||
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
||||||
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
||||||
trainer.accelerator.wait_for_everyone()
|
trainer.accelerator.wait_for_everyone()
|
||||||
@@ -152,37 +159,17 @@ def train(
|
|||||||
# The model name saved is `pytorch_model.bin`
|
# The model name saved is `pytorch_model.bin`
|
||||||
unwrapped_model.save_pretrained(
|
unwrapped_model.save_pretrained(
|
||||||
cfg.output_dir,
|
cfg.output_dir,
|
||||||
is_main_process=trainer.accelerator.is_main_process,
|
is_main_process=trainer.args.should_save,
|
||||||
save_function=trainer.accelerator.save,
|
save_function=trainer.accelerator.save,
|
||||||
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
||||||
)
|
)
|
||||||
elif cfg.local_rank == 0:
|
elif trainer.args.should_save:
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
# TODO figure out if `trainer.save_model(cfg.output_dir)` is sufficient here
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
def pretrain_hooks(cfg, trainer):
|
|
||||||
"""
|
|
||||||
Run hooks right before kicking off the training
|
|
||||||
:param cfg:
|
|
||||||
:param trainer:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
neft_embeddings.pretrain_hook(cfg, trainer)
|
|
||||||
|
|
||||||
|
|
||||||
def post_train_hooks(cfg, trainer):
|
|
||||||
"""
|
|
||||||
Run hooks right after training completes
|
|
||||||
:param cfg:
|
|
||||||
:param trainer:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
neft_embeddings.post_train_hook(cfg, trainer)
|
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
"""Benchmarking and measurement utilities"""
|
"""Benchmarking and measurement utilities"""
|
||||||
import functools
|
import functools
|
||||||
import logging
|
|
||||||
|
|
||||||
import pynvml
|
import pynvml
|
||||||
import torch
|
import torch
|
||||||
from pynvml.nvml import NVMLError
|
from pynvml.nvml import NVMLError
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.bench")
|
|
||||||
|
|
||||||
|
|
||||||
def check_cuda_device(default_value):
|
def check_cuda_device(default_value):
|
||||||
"""
|
"""
|
||||||
@@ -65,14 +62,7 @@ def gpu_memory_usage_smi(device=0):
|
|||||||
|
|
||||||
|
|
||||||
def log_gpu_memory_usage(log, msg, device):
|
def log_gpu_memory_usage(log, msg, device):
|
||||||
if not torch.cuda.is_available():
|
usage, cache, misc = gpu_memory_usage_all(device)
|
||||||
return (0, 0, 0)
|
|
||||||
|
|
||||||
try:
|
|
||||||
usage, cache, misc = gpu_memory_usage_all(device)
|
|
||||||
except ValueError as exc:
|
|
||||||
LOG.exception(exc)
|
|
||||||
return (0, 0, 0)
|
|
||||||
extras = []
|
extras = []
|
||||||
if cache > 0:
|
if cache > 0:
|
||||||
extras.append(f"+{cache:.03f}GB cache")
|
extras.append(f"+{cache:.03f}GB cache")
|
||||||
|
|||||||
@@ -369,10 +369,15 @@ def validate_config(cfg):
|
|||||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.tensor_parallel and cfg.gradient_checkpointing:
|
if cfg.multimodal:
|
||||||
raise ValueError(
|
try:
|
||||||
"TensorParallelPreTrainedModel does not support gradient checkpointing"
|
import llava # noqa: F401 # pylint:disable=unused-import
|
||||||
)
|
except ImportError as exc:
|
||||||
|
LOG.warning(
|
||||||
|
"LLaVA package required for multimodal training. See docs/llava.md for more information."
|
||||||
|
)
|
||||||
|
raise exc
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import functools
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import (
|
from datasets import (
|
||||||
@@ -36,7 +36,6 @@ from axolotl.prompters import (
|
|||||||
MultipleChoiceExplainPrompter,
|
MultipleChoiceExplainPrompter,
|
||||||
ReflectAlpacaPrompter,
|
ReflectAlpacaPrompter,
|
||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
UnsupportedPrompter,
|
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
@@ -55,11 +54,21 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
|||||||
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer, model=None):
|
||||||
prompters = []
|
if cfg.multimodal:
|
||||||
if not cfg.pretraining_dataset:
|
if not model:
|
||||||
|
raise ValueError("missing model argument")
|
||||||
|
from llava.train.train import LazySupervisedDataset
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
eval_dataset = None
|
||||||
|
train_dataset = LazySupervisedDataset(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif not cfg.pretraining_dataset:
|
||||||
|
with zero_first(is_main_process()):
|
||||||
|
train_dataset, eval_dataset = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -72,7 +81,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
return train_dataset, eval_dataset, cfg.max_steps
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
@@ -85,7 +94,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||||
else:
|
else:
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
return train_dataset, eval_dataset, total_num_steps
|
||||||
|
|
||||||
|
|
||||||
def load_tokenized_prepared_datasets(
|
def load_tokenized_prepared_datasets(
|
||||||
@@ -111,7 +120,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
else Path(default_dataset_prepared_path) / ds_hash
|
else Path(default_dataset_prepared_path) / ds_hash
|
||||||
)
|
)
|
||||||
dataset = None
|
dataset = None
|
||||||
prompters = []
|
|
||||||
use_auth_token = cfg.hf_use_auth_token
|
use_auth_token = cfg.hf_use_auth_token
|
||||||
try:
|
try:
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
@@ -150,13 +158,13 @@ def load_tokenized_prepared_datasets(
|
|||||||
yield dataset
|
yield dataset
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
for config_dataset in for_d_in_datasets(cfg.datasets):
|
for d in for_d_in_datasets(cfg.datasets):
|
||||||
ds: Union[Dataset, DatasetDict] = None
|
ds: Union[Dataset, DatasetDict] = None
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
try:
|
try:
|
||||||
load_dataset(
|
load_dataset(
|
||||||
config_dataset.path,
|
d.path,
|
||||||
name=config_dataset.name,
|
name=d.name,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
@@ -165,33 +173,33 @@ def load_tokenized_prepared_datasets(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# prefer local dataset, even if hub exists
|
# prefer local dataset, even if hub exists
|
||||||
local_path = Path(config_dataset.path)
|
local_path = Path(d.path)
|
||||||
if local_path.exists():
|
if local_path.exists():
|
||||||
if local_path.is_dir():
|
if local_path.is_dir():
|
||||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
config_dataset.path,
|
d.path,
|
||||||
name=config_dataset.name,
|
name=d.name,
|
||||||
data_files=config_dataset.data_files,
|
data_files=d.data_files,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds_type = "json"
|
ds_type = "json"
|
||||||
if config_dataset.ds_type:
|
if d.ds_type:
|
||||||
ds_type = config_dataset.ds_type
|
ds_type = d.ds_type
|
||||||
elif ".parquet" in config_dataset.path:
|
elif ".parquet" in d.path:
|
||||||
ds_type = "parquet"
|
ds_type = "parquet"
|
||||||
elif ".arrow" in config_dataset.path:
|
elif ".arrow" in d.path:
|
||||||
ds_type = "arrow"
|
ds_type = "arrow"
|
||||||
elif ".csv" in config_dataset.path:
|
elif ".csv" in d.path:
|
||||||
ds_type = "csv"
|
ds_type = "csv"
|
||||||
elif ".txt" in config_dataset.path:
|
elif ".txt" in d.path:
|
||||||
ds_type = "text"
|
ds_type = "text"
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
ds_type,
|
ds_type,
|
||||||
name=config_dataset.name,
|
name=d.name,
|
||||||
data_files=config_dataset.path,
|
data_files=d.path,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
@@ -201,25 +209,25 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
elif ds_from_hub:
|
elif ds_from_hub:
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
config_dataset.path,
|
d.path,
|
||||||
name=config_dataset.name,
|
name=d.name,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
data_files=config_dataset.data_files,
|
data_files=d.data_files,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if isinstance(config_dataset.data_files, str):
|
if isinstance(d.data_files, str):
|
||||||
fp = hf_hub_download(
|
fp = hf_hub_download(
|
||||||
repo_id=config_dataset.path,
|
repo_id=d.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=config_dataset.data_files,
|
filename=d.data_files,
|
||||||
)
|
)
|
||||||
elif isinstance(config_dataset.data_files, list):
|
elif isinstance(d.data_files, list):
|
||||||
fp = []
|
fp = []
|
||||||
for file in config_dataset.data_files:
|
for file in d.data_files:
|
||||||
fp.append(
|
fp.append(
|
||||||
hf_hub_download(
|
hf_hub_download(
|
||||||
repo_id=config_dataset.path,
|
repo_id=d.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=file,
|
filename=file,
|
||||||
)
|
)
|
||||||
@@ -229,27 +237,21 @@ def load_tokenized_prepared_datasets(
|
|||||||
"data_files must be either a string or list of strings"
|
"data_files must be either a string or list of strings"
|
||||||
)
|
)
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
"json",
|
"json", name=d.name, data_files=fp, streaming=False, split=None
|
||||||
name=config_dataset.name,
|
|
||||||
data_files=fp,
|
|
||||||
streaming=False,
|
|
||||||
split=None,
|
|
||||||
)
|
)
|
||||||
if not ds:
|
if not ds:
|
||||||
raise ValueError("unhandled dataset load")
|
raise ValueError("unhandled dataset load")
|
||||||
# support for using a subset of the data
|
# support for using a subset of the data
|
||||||
if config_dataset.shards:
|
if d.shards:
|
||||||
if "train" in ds:
|
if "train" in ds:
|
||||||
ds = ds.shuffle(seed=seed)["train"].shard(
|
ds = ds.shuffle(seed=seed)["train"].shard(
|
||||||
num_shards=config_dataset.shards, index=0
|
num_shards=d.shards, index=0
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ds = ds.shuffle(seed=seed).shard(
|
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
||||||
num_shards=config_dataset.shards, index=0
|
|
||||||
)
|
|
||||||
|
|
||||||
d_base_type = d_prompt_style = None
|
d_base_type = d_prompt_style = None
|
||||||
d_type = config_dataset.type
|
d_type = d.type
|
||||||
if isinstance(d_type, str):
|
if isinstance(d_type, str):
|
||||||
d_type_split = d_type.split(":")
|
d_type_split = d_type.split(":")
|
||||||
d_base_type = d_type_split[0]
|
d_base_type = d_type_split[0]
|
||||||
@@ -258,26 +260,108 @@ def load_tokenized_prepared_datasets(
|
|||||||
ds = ds["train"]
|
ds = ds["train"]
|
||||||
elif (
|
elif (
|
||||||
isinstance(ds, DatasetDict)
|
isinstance(ds, DatasetDict)
|
||||||
and config_dataset.train_on_split
|
and d.train_on_split
|
||||||
and config_dataset.train_on_split in ds
|
and d.train_on_split in ds
|
||||||
):
|
):
|
||||||
ds = ds[config_dataset.train_on_split]
|
ds = ds[d.train_on_split]
|
||||||
elif isinstance(ds, DatasetDict):
|
elif isinstance(ds, DatasetDict):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
|
f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"input_ids" in ds.features
|
||||||
|
and "attention_mask" in ds.features
|
||||||
|
and "labels" in ds.features
|
||||||
|
):
|
||||||
|
# dataset is already tokenized, just drop it straight in
|
||||||
|
datasets.append(ds)
|
||||||
|
elif isinstance(d.type, DictDefault):
|
||||||
|
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
|
elif ds_strategy := load(d.type, tokenizer, cfg, d):
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
|
elif d_base_type == "alpaca":
|
||||||
|
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||||
|
AlpacaPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
|
elif d_base_type == "explainchoice":
|
||||||
|
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||||
|
MultipleChoiceExplainPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
|
elif d_base_type == "concisechoice":
|
||||||
|
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||||
|
MultipleChoiceConcisePrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
|
elif d_base_type == "summarizetldr":
|
||||||
|
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
||||||
|
SummarizeTLDRPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
|
elif d_base_type == "jeopardy":
|
||||||
|
ds_strategy = JeopardyPromptTokenizingStrategy(
|
||||||
|
JeopardyPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
|
elif d_base_type == "oasst":
|
||||||
|
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||||
|
AlpacaPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
|
elif d_base_type == "gpteacher":
|
||||||
|
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
||||||
|
GPTeacherPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
|
elif d_base_type == "reflection":
|
||||||
|
ds_strategy = AlpacaReflectionPTStrategy(
|
||||||
|
ReflectAlpacaPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
|
else:
|
||||||
|
suffix = ""
|
||||||
|
if ":load_" in d.type:
|
||||||
|
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
||||||
|
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
|
||||||
|
raise ValueError(
|
||||||
|
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
|
||||||
config_dataset=config_dataset,
|
|
||||||
dataset=ds,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
cfg=cfg,
|
|
||||||
d_base_type=d_base_type,
|
|
||||||
d_prompt_style=d_prompt_style,
|
|
||||||
)
|
|
||||||
datasets.append(dataset_wrapper)
|
|
||||||
prompters.append(dataset_prompter)
|
|
||||||
|
|
||||||
LOG.info("merging datasets")
|
LOG.info("merging datasets")
|
||||||
dataset = concatenate_datasets(datasets)
|
dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
@@ -295,14 +379,14 @@ def load_tokenized_prepared_datasets(
|
|||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataset, prompters
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_datasets(
|
def load_prepare_datasets(
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
cfg,
|
cfg,
|
||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
) -> Tuple[Dataset, Dataset, List[Any]]:
|
) -> Tuple[Dataset, Dataset]:
|
||||||
max_packed_sequence_len = (
|
max_packed_sequence_len = (
|
||||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||||
)
|
)
|
||||||
@@ -311,7 +395,6 @@ def load_prepare_datasets(
|
|||||||
) # make sure we don't accidentally set it larger than sequence_len
|
) # make sure we don't accidentally set it larger than sequence_len
|
||||||
|
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
prompters = []
|
|
||||||
if cfg.max_packed_sequence_len is not None:
|
if cfg.max_packed_sequence_len is not None:
|
||||||
# see if we can go ahead and load the stacked dataset
|
# see if we can go ahead and load the stacked dataset
|
||||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||||
@@ -367,7 +450,7 @@ def load_prepare_datasets(
|
|||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dataset, prompters = load_tokenized_prepared_datasets(
|
dataset = load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -409,7 +492,7 @@ def load_prepare_datasets(
|
|||||||
private=True,
|
private=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dataset, prompters = load_tokenized_prepared_datasets(
|
dataset = load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -460,124 +543,7 @@ def load_prepare_datasets(
|
|||||||
train_dataset = dataset
|
train_dataset = dataset
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
|
||||||
return train_dataset, eval_dataset, prompters
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_wrapper(
|
|
||||||
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
|
|
||||||
):
|
|
||||||
dataset_wrapper = None
|
|
||||||
dataset_prompter = None
|
|
||||||
|
|
||||||
if (
|
|
||||||
"input_ids" in dataset.features
|
|
||||||
and "attention_mask" in dataset.features
|
|
||||||
and "labels" in dataset.features
|
|
||||||
):
|
|
||||||
# dataset is already tokenized, just drop it straight in
|
|
||||||
dataset_prompter = UnsupportedPrompter()
|
|
||||||
dataset_wrapper = dataset
|
|
||||||
elif isinstance(config_dataset.type, DictDefault):
|
|
||||||
ds_strategy = load(
|
|
||||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
|
||||||
)
|
|
||||||
dataset_prompter = UnsupportedPrompter()
|
|
||||||
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
|
||||||
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
|
||||||
dataset_prompter = UnsupportedPrompter()
|
|
||||||
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
|
||||||
elif d_base_type == "alpaca":
|
|
||||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
|
||||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
|
||||||
dataset_prompter,
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
|
||||||
dataset_wrapper = ds_wrapper
|
|
||||||
elif d_base_type == "explainchoice":
|
|
||||||
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
|
||||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
|
||||||
dataset_prompter,
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
|
||||||
dataset_wrapper = ds_wrapper
|
|
||||||
elif d_base_type == "concisechoice":
|
|
||||||
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
|
||||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
|
||||||
dataset_prompter,
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
|
||||||
dataset_wrapper = ds_wrapper
|
|
||||||
elif d_base_type == "summarizetldr":
|
|
||||||
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
|
||||||
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
|
||||||
dataset_prompter,
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
|
||||||
dataset_wrapper = ds_wrapper
|
|
||||||
elif d_base_type == "jeopardy":
|
|
||||||
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
|
||||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
|
||||||
dataset_prompter,
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
|
||||||
dataset_wrapper = ds_wrapper
|
|
||||||
elif d_base_type == "oasst":
|
|
||||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
|
||||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
|
||||||
dataset_prompter,
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
|
||||||
dataset_wrapper = ds_wrapper
|
|
||||||
elif d_base_type == "gpteacher":
|
|
||||||
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
|
||||||
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
|
||||||
dataset_prompter,
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
|
||||||
dataset_wrapper = ds_wrapper
|
|
||||||
elif d_base_type == "reflection":
|
|
||||||
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
|
||||||
ds_strategy = AlpacaReflectionPTStrategy(
|
|
||||||
dataset_prompter,
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
|
||||||
dataset_wrapper = ds_wrapper
|
|
||||||
else:
|
|
||||||
suffix = ""
|
|
||||||
if ":load_" in config_dataset.type:
|
|
||||||
suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
|
|
||||||
LOG.error(
|
|
||||||
f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
|
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return dataset_wrapper, dataset_prompter
|
|
||||||
|
|
||||||
|
|
||||||
def encode_pretraining(
|
def encode_pretraining(
|
||||||
|
|||||||
@@ -3,9 +3,6 @@ import hashlib
|
|||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import time
|
|
||||||
from queue import Queue
|
|
||||||
from threading import Thread
|
|
||||||
from typing import Any, Callable, List, Union
|
from typing import Any, Callable, List, Union
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
@@ -152,8 +149,6 @@ class MultipackDistributedDataloader:
|
|||||||
packing_efficiency_estimate: float = 1.0,
|
packing_efficiency_estimate: float = 1.0,
|
||||||
sample_packing_seq_len_multiplier: int = 1,
|
sample_packing_seq_len_multiplier: int = 1,
|
||||||
device_count: int = 1,
|
device_count: int = 1,
|
||||||
prefetch_max: int = 1000,
|
|
||||||
num_epochs: int = 1,
|
|
||||||
):
|
):
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
@@ -172,7 +167,6 @@ class MultipackDistributedDataloader:
|
|||||||
self.seq_max_length = seq_max_length
|
self.seq_max_length = seq_max_length
|
||||||
self.batch_max_length = batch_size * seq_max_length
|
self.batch_max_length = batch_size * seq_max_length
|
||||||
self.collate_fn = collate_fn
|
self.collate_fn = collate_fn
|
||||||
self.num_epochs = num_epochs
|
|
||||||
|
|
||||||
self.num_replicas = 1
|
self.num_replicas = 1
|
||||||
self.rank = 0
|
self.rank = 0
|
||||||
@@ -183,44 +177,6 @@ class MultipackDistributedDataloader:
|
|||||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||||
self.device_count = device_count
|
self.device_count = device_count
|
||||||
|
|
||||||
# maxsize is maximum number of samples in queue
|
|
||||||
self.prefetch_max = prefetch_max
|
|
||||||
self.queue: Queue = Queue(maxsize=prefetch_max)
|
|
||||||
self.thread = None
|
|
||||||
|
|
||||||
def _worker(self):
|
|
||||||
LOG.info(
|
|
||||||
f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
|
|
||||||
)
|
|
||||||
for epoch in range(self.num_epochs):
|
|
||||||
for sample in self._internal_batch_generator():
|
|
||||||
while True:
|
|
||||||
if self.queue.full():
|
|
||||||
time.sleep(1)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
self.queue.put(sample)
|
|
||||||
|
|
||||||
# stop the queue when epoch is done
|
|
||||||
self.queue.put(None)
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
if hasattr(self.sampler, "set_epoch"):
|
|
||||||
new_epoch = self.sampler.epoch + 1
|
|
||||||
self.sampler.set_epoch(new_epoch)
|
|
||||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
|
||||||
|
|
||||||
if self.thread is None:
|
|
||||||
self.thread = Thread(target=self._worker, daemon=True)
|
|
||||||
self.thread.start()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
item = self.queue.get()
|
|
||||||
|
|
||||||
if item is None:
|
|
||||||
break
|
|
||||||
yield item
|
|
||||||
|
|
||||||
def generate_batches(self, set_stats=False):
|
def generate_batches(self, set_stats=False):
|
||||||
LOG.info("generating packed batches")
|
LOG.info("generating packed batches")
|
||||||
if self.sampler:
|
if self.sampler:
|
||||||
@@ -250,7 +206,11 @@ class MultipackDistributedDataloader:
|
|||||||
|
|
||||||
return batches, totseqs
|
return batches, totseqs
|
||||||
|
|
||||||
def _internal_batch_generator(self):
|
def __iter__(self):
|
||||||
|
if hasattr(self.sampler, "set_epoch"):
|
||||||
|
new_epoch = self.sampler.epoch + 1
|
||||||
|
self.sampler.set_epoch(new_epoch)
|
||||||
|
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||||
all_batches, _ = self.generate_batches(set_stats=True)
|
all_batches, _ = self.generate_batches(set_stats=True)
|
||||||
features = self.dataset.features.keys()
|
features = self.dataset.features.keys()
|
||||||
len_remaining = self._len_est()
|
len_remaining = self._len_est()
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Optional, Tuple # noqa: F401
|
|||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
import transformers.utils.bitsandbytes
|
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from peft import PeftConfig, prepare_model_for_kbit_training
|
from peft import PeftConfig, prepare_model_for_kbit_training
|
||||||
from peft.tuners.lora import QuantLinear
|
from peft.tuners.lora import QuantLinear
|
||||||
@@ -32,7 +31,7 @@ LOG = logging.getLogger("axolotl")
|
|||||||
|
|
||||||
def load_model_config(cfg):
|
def load_model_config(cfg):
|
||||||
model_config_name = cfg.base_model_config or cfg.base_model
|
model_config_name = cfg.base_model_config or cfg.base_model
|
||||||
trust_remote_code = cfg.trust_remote_code is True
|
trust_remote_code: bool = False or cfg.trust_remote_code
|
||||||
return AutoConfig.from_pretrained(
|
return AutoConfig.from_pretrained(
|
||||||
model_config_name, trust_remote_code=trust_remote_code
|
model_config_name, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
@@ -73,6 +72,11 @@ def load_tokenizer(cfg):
|
|||||||
# set a pad_token, but use eos_token so we don't add a new token
|
# set a pad_token, but use eos_token so we don't add a new token
|
||||||
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||||
|
|
||||||
|
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||||
|
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||||
|
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
|
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -94,11 +98,6 @@ def load_tokenizer(cfg):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
|
||||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
|
||||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
|
||||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@@ -181,6 +180,26 @@ def load_model(
|
|||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with flash attention")
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||||
|
|
||||||
|
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
|
||||||
|
from axolotl.monkeypatch.llama_embeddings_hijack import (
|
||||||
|
replace_llama_embeddings_with_uniform_distribution,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("patching with noisy embeddings")
|
||||||
|
replace_llama_embeddings_with_uniform_distribution(
|
||||||
|
noise_alpha=cfg.noisy_embedding_alpha
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
|
||||||
|
from axolotl.monkeypatch.mistral_embeddings_hijack import (
|
||||||
|
replace_mistral_embeddings_with_uniform_distribution,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("patching with noisy embeddings")
|
||||||
|
replace_mistral_embeddings_with_uniform_distribution(
|
||||||
|
noise_alpha=cfg.noisy_embedding_alpha
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||||
replace_llama_rope_with_xpos_rope,
|
replace_llama_rope_with_xpos_rope,
|
||||||
@@ -222,7 +241,7 @@ def load_model(
|
|||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
llm_int8_threshold=6.0,
|
llm_int8_threshold=6.0,
|
||||||
llm_int8_has_fp16_weight=False,
|
llm_int8_has_fp16_weight=False,
|
||||||
bnb_4bit_compute_dtype=torch.float16,
|
bnb_4bit_compute_dtype=cfg.torch_dtype,
|
||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4",
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
@@ -236,12 +255,102 @@ def load_model(
|
|||||||
model_kwargs["use_flash_attention_2"] = True
|
model_kwargs["use_flash_attention_2"] = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if (
|
if cfg.multimodal:
|
||||||
cfg.is_llama_derived_model
|
from llava.train.train import DataArguments, ModelArguments
|
||||||
and not cfg.trust_remote_code
|
|
||||||
and not cfg.gptq
|
if cfg.is_llama_derived_model:
|
||||||
and not cfg.tensor_parallel
|
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM
|
||||||
):
|
|
||||||
|
model = LlavaLlamaForCausalLM.from_pretrained(
|
||||||
|
cfg.base_model,
|
||||||
|
)
|
||||||
|
elif cfg.is_mistral_derived_model:
|
||||||
|
from axolotl.models.llava.llava_mistral import LlavaMistralForCausalLM
|
||||||
|
|
||||||
|
model = LlavaMistralForCausalLM.from_pretrained(
|
||||||
|
cfg.base_model,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"unhandled model architecture for multimodal training"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.mm_freeze_backbone:
|
||||||
|
model.model.requires_grad_(False)
|
||||||
|
|
||||||
|
if cfg.gradient_checkpointing:
|
||||||
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
|
model.enable_input_require_grads()
|
||||||
|
else:
|
||||||
|
|
||||||
|
def make_inputs_require_grad(
|
||||||
|
module,
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
): # pylint: disable=redefined-builtin,unused-argument
|
||||||
|
output.requires_grad_(True)
|
||||||
|
|
||||||
|
model.get_input_embeddings().register_forward_hook(
|
||||||
|
make_inputs_require_grad
|
||||||
|
)
|
||||||
|
|
||||||
|
model_args = ModelArguments(
|
||||||
|
model_name_or_path=cfg.base_model,
|
||||||
|
version="v0",
|
||||||
|
freeze_backbone=cfg.mm_freeze_backbone or False,
|
||||||
|
tune_mm_mlp_adapter=cfg.tune_mm_mlp_adapter or False,
|
||||||
|
vision_tower=cfg.mm_vision_tower,
|
||||||
|
mm_vision_select_layer=cfg.mm_vision_select_layer or -1,
|
||||||
|
pretrain_mm_mlp_adapter=cfg.pretrain_mm_mlp_adapter,
|
||||||
|
mm_projector_type=cfg.mm_projector_type or "linear",
|
||||||
|
mm_use_im_start_end=cfg.mm_use_im_start_end or False,
|
||||||
|
mm_use_im_patch_token=cfg.mm_use_im_patch_token,
|
||||||
|
mm_vision_select_feature=cfg.mm_vision_select_feature or "patch",
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.mm_vision_tower is not None:
|
||||||
|
model.get_model().initialize_vision_modules(
|
||||||
|
model_args=model_args, fsdp=cfg.fsdp
|
||||||
|
)
|
||||||
|
|
||||||
|
vision_tower = model.get_vision_tower()
|
||||||
|
vision_tower.to(dtype=cfg.torch_dtype, device=cfg.device)
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
data_args = DataArguments(
|
||||||
|
data_path=cfg.datasets[0]["path"],
|
||||||
|
lazy_preprocess=cfg.mm_lazy_preprocess
|
||||||
|
if cfg.mm_lazy_preprocess is not None
|
||||||
|
else True,
|
||||||
|
is_multimodal=True,
|
||||||
|
image_folder=cfg.mm_image_folder or None,
|
||||||
|
image_aspect_ratio=cfg.mm_image_aspect_ratio or "square",
|
||||||
|
image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None,
|
||||||
|
)
|
||||||
|
data_args.image_processor = vision_tower.image_processor
|
||||||
|
model.config.image_aspect_ratio = data_args.image_aspect_ratio
|
||||||
|
model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
|
||||||
|
model.config.tune_mm_mlp_adapter = cfg.tune_mm_mlp_adapter
|
||||||
|
if cfg.tune_mm_mlp_adapter:
|
||||||
|
model.requires_grad_(False)
|
||||||
|
for (
|
||||||
|
p # pylint: disable=invalid-name
|
||||||
|
) in model.get_model().mm_projector.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
|
||||||
|
model.config.freeze_mm_mlp_adapter = cfg.freeze_mm_mlp_adapter
|
||||||
|
if cfg.freeze_mm_mlp_adapter:
|
||||||
|
for (
|
||||||
|
p # pylint: disable=invalid-name
|
||||||
|
) in model.get_model().mm_projector.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
|
||||||
|
model.config.mm_use_im_start_end = (
|
||||||
|
data_args.mm_use_im_start_end
|
||||||
|
) = cfg.mm_use_im_start_end
|
||||||
|
model.config.mm_use_im_patch_token = cfg.mm_use_im_patch_token
|
||||||
|
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
|
||||||
|
elif cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
config_kwargs = {}
|
config_kwargs = {}
|
||||||
@@ -307,7 +416,7 @@ def load_model(
|
|||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel:
|
elif model_type and not cfg.trust_remote_code:
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -322,17 +431,6 @@ def load_model(
|
|||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif cfg.tensor_parallel:
|
|
||||||
model_kwargs.pop("device_map")
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
base_model,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
offload_state_dict=True,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -383,18 +481,15 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
embeddings_len = (
|
||||||
embeddings_len = (
|
math.ceil(len(tokenizer) / 32) * 32
|
||||||
math.ceil(len(tokenizer) / 32) * 32
|
if cfg.resize_token_embeddings_to_32x
|
||||||
if cfg.resize_token_embeddings_to_32x
|
else len(tokenizer)
|
||||||
else len(tokenizer)
|
)
|
||||||
)
|
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
||||||
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
model.resize_token_embeddings(embeddings_len)
|
||||||
model.resize_token_embeddings(embeddings_len)
|
else:
|
||||||
else:
|
model.tie_weights()
|
||||||
model.tie_weights()
|
|
||||||
except NotImplementedError:
|
|
||||||
LOG.warning("`resize_token_embeddings` not implemented on model")
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model.config, "max_position_embeddings")
|
hasattr(model.config, "max_position_embeddings")
|
||||||
@@ -406,20 +501,6 @@ def load_model(
|
|||||||
)
|
)
|
||||||
model.config.max_position_embeddings = cfg.sequence_len
|
model.config.max_position_embeddings = cfg.sequence_len
|
||||||
|
|
||||||
if (
|
|
||||||
hasattr(model.config, "bos_token_id")
|
|
||||||
and model.config.bos_token_id
|
|
||||||
and model.config.bos_token_id != tokenizer.bos_token_id
|
|
||||||
):
|
|
||||||
model.config.bos_token_id = tokenizer.bos_token_id
|
|
||||||
|
|
||||||
if (
|
|
||||||
hasattr(model.config, "eos_token_id")
|
|
||||||
and model.config.eos_token_id
|
|
||||||
and model.config.eos_token_id != tokenizer.eos_token_id
|
|
||||||
):
|
|
||||||
model.config.eos_token_id = tokenizer.eos_token_id
|
|
||||||
|
|
||||||
if model.device.type == "cuda":
|
if model.device.type == "cuda":
|
||||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||||
|
|
||||||
@@ -497,12 +578,7 @@ def load_adapter(model, cfg, adapter, inference=False):
|
|||||||
if adapter is None:
|
if adapter is None:
|
||||||
return model, None
|
return model, None
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
try:
|
model.enable_input_require_grads()
|
||||||
model.enable_input_require_grads()
|
|
||||||
except NotImplementedError:
|
|
||||||
LOG.warning("enable_input_require_grads not implemented on model")
|
|
||||||
if adapter == "qlora" and cfg.tensor_parallel:
|
|
||||||
model, _ = load_tp_qlora(model)
|
|
||||||
if adapter in ["lora", "qlora"]:
|
if adapter in ["lora", "qlora"]:
|
||||||
return load_lora(model, cfg, inference=inference)
|
return load_lora(model, cfg, inference=inference)
|
||||||
if adapter == "llama-adapter":
|
if adapter == "llama-adapter":
|
||||||
@@ -539,7 +615,14 @@ def load_llama_adapter(model, cfg):
|
|||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
|
multimodal_keywords = [
|
||||||
|
"mm_projector",
|
||||||
|
"vision_tower",
|
||||||
|
"vision_resampler",
|
||||||
|
] # for LLaVA
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
|
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
|
||||||
|
continue
|
||||||
if (
|
if (
|
||||||
isinstance(module, cls)
|
isinstance(module, cls)
|
||||||
or "Linear" in module.__class__.__name__
|
or "Linear" in module.__class__.__name__
|
||||||
@@ -554,25 +637,6 @@ def find_all_linear_names(model):
|
|||||||
return list(lora_module_names)
|
return list(lora_module_names)
|
||||||
|
|
||||||
|
|
||||||
def load_tp_qlora(model):
|
|
||||||
from transformers.utils.bitsandbytes import replace_with_bnb_linear
|
|
||||||
|
|
||||||
model = replace_with_bnb_linear(
|
|
||||||
model,
|
|
||||||
quantization_config=BitsAndBytesConfig(
|
|
||||||
load_in_4bit=True,
|
|
||||||
llm_int8_threshold=6.0,
|
|
||||||
llm_int8_has_fp16_weight=False,
|
|
||||||
bnb_4bit_compute_dtype=torch.float16,
|
|
||||||
bnb_4bit_use_double_quant=True,
|
|
||||||
bnb_4bit_quant_type="nf4",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
model.is_loaded_in_4bit = True
|
|
||||||
|
|
||||||
return model, None
|
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model, cfg, inference=False):
|
def load_lora(model, cfg, inference=False):
|
||||||
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import torch.distributed as dist
|
|||||||
from datasets import set_caching_enabled
|
from datasets import set_caching_enabled
|
||||||
from torch.utils.data import DistributedSampler, RandomSampler
|
from torch.utils.data import DistributedSampler, RandomSampler
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder
|
from axolotl.core.trainer_builder import AxolotlTrainer, HFCausalTrainerBuilder
|
||||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
@@ -216,7 +216,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|||||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
num_epochs=cfg.num_epochs,
|
|
||||||
)
|
)
|
||||||
data_loader_len = data_loader.len_w_stats()
|
data_loader_len = data_loader.len_w_stats()
|
||||||
actual_eff = data_loader.efficiency()
|
actual_eff = data_loader.efficiency()
|
||||||
@@ -260,7 +259,9 @@ def setup_fsdp_envs(cfg):
|
|||||||
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(
|
||||||
|
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||||
|
) -> AxolotlTrainer:
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
|
|||||||
Reference in New Issue
Block a user