Compare commits

...

23 Commits

Author SHA1 Message Date
kallewoof
450e04d3c4 fix: remove excessive newlines in system prompt(s) for alpaca (#936) 2023-12-13 16:40:02 +09:00
Juraj Bednar
b0cf397ecb More hints on what to do with CUDA Out of memory errors (#925) 2023-12-13 16:38:38 +09:00
Wing Lian
5f79b8242f new evals_per_epoch and saves_per_epoch to make things cleaner (#944)
* new evals_per_epoch and saves_per_epoch to make things cleaner

* update per PR feedback
2023-12-12 15:35:23 -05:00
Hamel Husain
f1de29dd1e Respect sequence_len in config for type: llama2_chat (#926)
* Respect sequence_len in config for `type: llama2_chat`

It was hardcoded to `4096` I am not sure why?  This updates it to pull from the config. 

cc: @winglian

* Update llama2_chat.py

* apply black formatting

* fix tokenizer

* update test data

* lint fixtures
2023-12-12 09:39:22 -08:00
Wing Lian
7fabc4d95e Mixtral official (#942)
* multipack support for official mixtral implementation

* fix patch to load multipack for mixtral

* chore: lint
2023-12-11 23:44:33 -05:00
Motoki Wu
9a5eb3990c Update requirements.txt (#940) 2023-12-11 22:57:28 -05:00
Casper
86487c2e96 Mixtral: More correct MoE, lower loss (#932)
* More correct MoE

* Fix formatting
2023-12-10 10:34:25 -05:00
Wing Lian
35f9b0f149 update to latest transformers for mixstral support (#929)
* update to latest transformers for mixstral support

* pin transformers

* fix typo
2023-12-10 10:32:27 -05:00
Wing Lian
68b227a7d8 Mixtral multipack (#928)
* mixtral multipack

* use mixtral model

* sample yml

* calculate cu_seqlens properly

* use updated flash ettention setting

* attn var checks

* force use of flash attention 2 for packing

* lint

* disable future fix for now

* update support table
2023-12-09 21:26:30 -05:00
Timothy Lim
03c6318ba3 fixing prompt template of chatml by removal of linebreak (#922)
Co-authored-by: Timothy  Lim <timothyyonglee.lim@kxrdev.com>
2023-12-09 13:07:44 -05:00
Wing Lian
40a6362c92 support for mamba (#915)
* support for mamba

* more mamba fixes

* use fork for mamba kwargs fix

* grad checkpointing doesn't work

* fix extras for mamaba

* mamba loss fix

* use fp32 and remove verbose logging

* mamba fixes

* fix collator for mamba

* set model_type on training_args

* don't save safetensors for mamba

* update mamba config to disable safetensor checkpooints, install for tests

* no evals for mamba tests

* handle save_pretrained

* handle unused safetensors arg
2023-12-09 12:10:41 -05:00
NanoCode012
d339beb9d9 chore: clarify Readme on sharegpt system role 2023-12-08 11:35:53 +09:00
NanoCode012
fde091cb12 fix(tokenizer): handle fast tokenizer properly for bos/eos (#914) 2023-12-08 11:31:13 +09:00
Casper
06ae39200b Pin flash-attn to 2.3.3 (#919) 2023-12-07 07:36:52 +01:00
NanoCode012
a581e9f8f6 feat: add check for quantized model (#913)
* feat: add check for quantized model

* chore: refactor and add another check

* Update src/axolotl/utils/models.py

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-12-05 01:20:06 +09:00
Bryan Thornbury
992e742cdc Support device_map=sequential & max_memory config parameters (#903)
* Support device_map sequential (and others). Support max_memory in cfg.

* Update documentation in README accordingly.

* Update README.md

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-12-04 09:29:21 -05:00
NanoCode012
a1da39cd48 Feat(wandb): Refactor to be more flexible (#767)
* Feat: Update to handle wandb env better

* chore: rename wandb_run_id to wandb_name

* feat: add new recommendation and update config

* fix: indent and pop disabled env if project passed

* feat: test env set for wandb and recommendation

* feat: update to use wandb_name and allow id

* chore: add info to readme
2023-12-04 22:17:25 +09:00
kallewoof
58ec8b1113 feature: loss watchdog for terminating training runs that are failing (#899)
Co-authored-by: Karl-Johan Alm <kalle@gmail.com>
2023-12-04 07:54:34 -05:00
Haoxiang Wang
476a205cea Remove learning rate scheduler in deepspeed config to avoid conflict (#909) 2023-12-04 05:17:38 -05:00
Wing Lian
3e3229e2d9 fix for qwen w lora (#906) 2023-11-30 12:45:50 -05:00
Wing Lian
1d21aa6b0a ensure merged model matches the training dtype (#902)
* ensure merged model matches the training dtype

* Update src/axolotl/cli/__init__.py

* Update src/axolotl/cli/__init__.py
2023-11-29 09:55:19 -05:00
kallewoof
71b7ea3c05 Determine FSDP/deepspeed settings on device select. (#883)
* Determine FSDP/deepspeed settings on device select.

Without this, the OS env check for accelerate will fail.

* rename and move env setup call

* chore: lint

---------

Co-authored-by: Karl-Johan Alm <kalle@gmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-11-29 08:36:35 -05:00
NanoCode012
a48dbf6561 fix: remove FA for qwen examples (#900)
* fix: remove FA for qwen lora

* fix: remove FA for qlora
2023-11-27 21:23:54 +09:00
65 changed files with 1336 additions and 208 deletions

View File

@@ -73,7 +73,7 @@ jobs:
run: |
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
pip3 uninstall -y transformers accelerate
pip3 install -U -e .[flash-attn]
pip3 install -U -e .[flash-attn,mamba-ssm]
pip3 install -r requirements-tests.txt
- name: Run e2e tests

View File

@@ -8,6 +8,9 @@ ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*]
ignore_errors = True
[mypy-axolotl.models.mixtral.*]
ignore_errors = True
[mypy-axolotl.models.phi.*]
ignore_errors = True

View File

@@ -65,19 +65,21 @@ Features:
## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ✅ | | | | |
| cerebras | ✅ | ✅ | ✅ | | | | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | | | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | | | ❌ | ❌ | | ❓ |
| XGen | ✅ | | ✅ | | | | |
| phi | ✅ | ✅ | ✅ | | | ❓ | ❓ |
| RWKV | ✅ | ❓ | | ❓ | ❓ | ❓ | |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | | | | |
| Mixtral-MoE | ✅ | ✅ | ✅ | | | | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | | | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | | | ❌ | ❌ | | ❓ |
| falcon | ✅ | | ✅ | | | | |
| gpt-j | ✅ | ✅ | ✅ | | | ❓ | ❓ |
| XGen | ✅ | ❓ | | ❓ | ❓ | ❓ | |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
## Quickstart ⚡
@@ -245,7 +247,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations where `from` is `human`/`gpt`
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -612,6 +614,12 @@ eval_sample_packing:
sample_packing_eff_est:
total_num_tokens:
# Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory
device_map:
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
max_memory:
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora
# If you already have a lora model trained that you want to load, put that here.
@@ -659,7 +667,8 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
wandb_project: # Your wandb project name
wandb_entity: # A wandb Team name if using a Team
wandb_watch:
wandb_run_id: # Set the name of your wandb run
wandb_name: # Set the name of your wandb run
wandb_run_id: # Set the ID of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# Where to save the full-finetuned model to
@@ -682,9 +691,11 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
save_strategy: # Set to `no` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed.
@@ -694,6 +705,9 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (require safetensors package)
save_safetensors:
@@ -952,7 +966,7 @@ wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
```
@@ -1008,6 +1022,10 @@ Please reduce any below
- `gradient_accumulation_steps`
- `sequence_len`
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
Using adamw_bnb_8bit might also save you some memory.
> `failed (exitcode: -9)`
Usually means your system has run out of system memory.

View File

@@ -24,16 +24,6 @@
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",

View File

@@ -28,16 +28,6 @@
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",

View File

@@ -32,16 +32,6 @@
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",

View File

@@ -4,6 +4,7 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh

View File

@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: btlm-out
@@ -72,8 +72,8 @@ gptq_groupsize:
gptq_model_v1:
warmup_steps: 32
eval_steps:
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
save_total_limit:
debug:

View File

@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./qlora-out
batch_size: 4
@@ -49,8 +49,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./falcon-7b
batch_size: 2
@@ -51,8 +51,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 40
eval_steps: 5
save_steps: 43
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./qlora-out
@@ -80,8 +80,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 5
save_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.000001

View File

@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./falcon-7b
batch_size: 2
@@ -51,8 +51,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 40
eval_steps: 5
save_steps: 43
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./qlora-out
gradient_accumulation_steps: 2
@@ -46,8 +46,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./jeopardy-bot-7b
gradient_accumulation_steps: 1
@@ -42,8 +42,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 110
save_steps: 660
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
@@ -58,9 +58,9 @@ flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
save_steps:
saves_per_epoch: 1
debug:
deepspeed: #deepspeed/zero2.json # multi-gpu only
weight_decay: 0.1

View File

@@ -32,7 +32,7 @@ lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./model-out
gradient_accumulation_steps: 1
@@ -62,8 +62,8 @@ flash_attention:
sdp_attention:
flash_optimum:
warmup_steps: 100
eval_steps:
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -54,10 +54,10 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -56,9 +56,9 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -35,7 +35,7 @@ relora_cpu_offload: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -60,8 +60,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps: 50
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -54,9 +54,9 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

61
examples/mamba/config.yml Normal file
View File

@@ -0,0 +1,61 @@
base_model: state-spaces/mamba-2.8b
model_type: MambaLMHeadModel
tokenizer_type: AutoTokenizer
tokenizer_config: EleutherAI/gpt-neox-20b
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./out
sequence_len: 2048
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
tokens:
save_safetensors: False

View File

@@ -21,7 +21,7 @@ pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -46,10 +46,10 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -0,0 +1,79 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
#lora_target_modules:
# - gate
# - q_proj
# - k_proj
# - v_proj
# - o_proj
# - w1
# - w2
# - w3
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: deepspeed/zero2.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -38,7 +38,7 @@ lora_target_modules:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -62,11 +62,14 @@ logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./mpt-alpaca-7b
gradient_accumulation_steps: 1
@@ -44,8 +44,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 110
save_steps: 660
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0001

View File

@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./openllama-out
gradient_accumulation_steps: 1
@@ -49,8 +49,8 @@ flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./lora-out
gradient_accumulation_steps: 1
@@ -54,8 +54,8 @@ flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./qlora-out
gradient_accumulation_steps: 1
@@ -48,8 +48,8 @@ flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
@@ -59,8 +59,8 @@ xformers_attention:
flash_attention:
warmup_steps: 100
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
@@ -59,8 +59,8 @@ xformers_attention:
flash_attention:
warmup_steps: 100
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./pythia-12b
gradient_accumulation_steps: 1

View File

@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1
@@ -33,5 +33,5 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
weight_decay: 0.1
eval_steps: 0.05
evals_per_epoch: 4
logging_steps: 1

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -53,13 +53,13 @@ resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attention:
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -53,13 +53,13 @@ resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attention:
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./redpajama-alpaca-3b
batch_size: 4
@@ -45,8 +45,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 110
save_steps: 660
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0001

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project: lora-replit
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./lora-replit
batch_size: 8
@@ -45,8 +45,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 50
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0

View File

@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:
output_dir: ./qlora-out
@@ -78,8 +78,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 50
save_steps: 50
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -2,7 +2,7 @@
auto-gptq==0.5.1
packaging
peft==0.6.0
transformers==4.35.2
transformers @ git+https://github.com/huggingface/transformers.git@e5079b0b2abcef11ecbdae60ba4a6636c57b725d
tokenizers==0.15.0
bitsandbytes>=0.41.1
accelerate==0.24.1
@@ -29,7 +29,7 @@ scipy
scikit-learn==1.2.2
pynvml
art
fschat==0.2.29
fschat==0.2.34
gradio==3.50.2
tensorboard

View File

@@ -46,10 +46,13 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn>=2.3.0",
"flash-attn==2.3.3",
],
"deepspeed": [
"deepspeed",
],
"mamba-ssm": [
"mamba-ssm==1.0.1",
],
},
)

View File

@@ -29,6 +29,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -71,7 +72,7 @@ def do_merge_lora(
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
model.to(dtype=cfg.torch_dtype)
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
@@ -296,6 +297,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
validate_config(cfg)
prepare_optim_env(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)

View File

@@ -25,12 +25,16 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
MambaDataCollator,
)
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
@@ -48,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
Extend the base TrainingArguments for axolotl helpers
"""
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
@@ -284,6 +291,32 @@ class AxolotlTrainer(Trainer):
return super().compute_loss(model, inputs, return_outputs=return_outputs)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""
def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return lm_loss
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
@@ -430,6 +463,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -458,6 +494,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
return AxolotlTrainer
def build(self, total_num_steps):
@@ -525,7 +563,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors:
if self.cfg.save_safetensors is not None:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.sample_packing_eff_est:
@@ -643,7 +681,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_run_id if self.cfg.use_wandb else None
self.cfg.wandb_name if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
@@ -673,6 +711,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
@@ -727,11 +766,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
data_collator=self.build_collator(**data_collator_kwargs),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
@@ -751,3 +786,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.micro_batch_size
return trainer
def build_collator(self, **kwargs):
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)
return BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**kwargs,
)

View File

@@ -0,0 +1,12 @@
"""
Modeling module for Mamba models
"""
def fix_mamba_attn_for_loss():
from mamba_ssm.models import mixer_seq_simple
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name

View File

@@ -0,0 +1,42 @@
"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig
class MambaConfig(PretrainedConfig):
"""
modeling configuration for state space model/mamba
"""
model_type = "mamba"
def __init__(
self,
vocab_size=50280,
d_model=2560,
n_layer=64,
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=8,
pad_token_id=50277,
bos_token_id=0,
eos_token_id=0,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layer = n_layer
self.rms_norm = rms_norm
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.pad_vocab_size_multiple = pad_vocab_size_multiple
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@@ -0,0 +1,128 @@
# pylint: skip-file
import os
from collections import namedtuple
from functools import partial
from typing import Optional, Union
import torch
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from torch import nn
from torch.nn import CrossEntropyLoss
from axolotl.models.mamba.configuration_mamba import MambaConfig
class MambaLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
d_model: int,
n_layer: int,
vocab_size: int,
initializer_cfg=None,
pad_vocab_size_multiple: int = 1,
device=None,
dtype=None,
**backbone_kwargs,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (
vocab_size % pad_vocab_size_multiple
)
self.config = MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layer,
pad_vocab_size_multiple=pad_vocab_size_multiple,
)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
vocab_size=vocab_size,
initializer_cfg=initializer_cfg,
**backbone_kwargs,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
# Initialize weights and apply final processing
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()
def tie_weights(self):
self.lm_head.weight = self.backbone.embedding.weight
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(
batch_size, max_seqlen, dtype=dtype, **kwargs
)
def forward(
self,
input_ids,
position_ids=None,
inference_params=None,
num_last_tokens=0,
labels=None,
**kwargs,
):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, inference_params=inference_params)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
loss = None
if labels is not None:
logits = lm_logits
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
print(loss)
return CausalLMOutput(logits=lm_logits, loss=loss)
else:
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
):
if state_dict is None:
state_dict = self.state_dict()
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config = load_config_hf(pretrained_model_name)
model = cls(**config, device=device, dtype=dtype, **kwargs)
model.load_state_dict(
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
)
return model

View File

@@ -0,0 +1,22 @@
"""
Patches to support multipack for mixtral
"""
import transformers
def replace_mixtral_attn_with_multipack_flash_attn():
from .modeling_mixtral import (
MixtralMultipackFlashAttention2,
mixtral_decoder_layer_forward,
mixtral_model_forward,
)
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
mixtral_decoder_layer_forward
)
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
mixtral_model_forward
)
transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
"flash_attention_2"
] = MixtralMultipackFlashAttention2

View File

@@ -0,0 +1,379 @@
"""
Mixtral modeling for multipack
"""
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
import logging
import warnings
from typing import List, Optional, Tuple, Union
import torch
from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from transformers import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral.modeling_mixtral import (
MixtralFlashAttention2,
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
"""
Custom multipack implementation w flash attention 2
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._flash_attn_uses_top_left_mask = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
attn_output = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=self.attention_dropout,
softmax_scale=None,
causal=True,
)
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def mixtral_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_router_logits (`bool`, *optional*):
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
should not be returned during inference.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
if output_router_logits:
outputs += (router_logits,)
return outputs
def mixtral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MoeModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = (
attention_mask
if (attention_mask is not None and 0 in attention_mask)
else None
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
LOG.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
output_router_logits,
use_cache,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if output_router_logits:
all_router_logits += (layer_outputs[-1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_router_logits,
]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)

View File

@@ -81,8 +81,9 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sequence_len = 4096
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
self.tokenizer.add_special_tokens(
{"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")}
)
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
def tokenize_prompt(self, prompt):

View File

@@ -13,7 +13,7 @@ register_conv_template(
system_message="You are a helpful assistant.",
roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>\n",
sep="<|im_end|>",
)
)

View File

@@ -33,8 +33,8 @@ class AlpacaPrompter(Prompter):
Base class for alpaca prompters
"""
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
system_format: str = "{system}"
turn_format: str
turn_no_input_format: str

View File

@@ -82,7 +82,8 @@ def train(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
if hasattr(model, "config"):
model.config.use_cache = False
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
@@ -92,7 +93,8 @@ def train(
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
model.config.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:

View File

@@ -124,6 +124,36 @@ class GPUStatsCallback(
return control
class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3
def on_step_end(
self,
_args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
if self.violations >= self.patience:
LOG.warning(
"Loss is too high, stopping training (loss_watchdog_threshold)"
)
control.should_training_stop = True
else:
self.violations = 0
return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [

View File

@@ -2,12 +2,16 @@
DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union
import numpy as np
import torch
import transformers
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
IGNORE_INDEX = -100
@dataclass
class DataCollatorForSeq2Seq:
@@ -146,3 +150,31 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
@dataclass
class MambaDataCollator:
"""
Collator for State Space Models (Mamba)
"""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[torch.LongTensor(instance[key]) for instance in instances]
for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return {
"input_ids": input_ids,
"labels": labels,
}

View File

@@ -27,7 +27,7 @@ def choose_device(cfg):
cfg.device = get_device()
if cfg.world_size == 1:
cfg.device_map = "auto"
cfg.device_map = cfg.device_map or "auto"
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()}
@@ -77,6 +77,15 @@ def normalize_config(cfg):
else:
cfg.torch_dtype = torch.float32
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step
cfg.save_steps = save_steps
if cfg.evals_per_epoch:
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
if eval_steps < 1.0: # prevent evals on every step
cfg.eval_steps = eval_steps
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
if not cfg.base_model_config:
@@ -352,6 +361,27 @@ def validate_config(cfg):
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
"sharegpt_simple", "sharegpt"
)
if cfg.saves_per_epoch and cfg.save_steps:
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
)
if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
)
if (
cfg.evals_per_epoch
and cfg.evaluation_strategy
and cfg.evaluation_strategy != "steps"
):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
@@ -397,6 +427,13 @@ def validate_config(cfg):
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
)
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -4,6 +4,7 @@ import math
import os
from typing import Optional, Tuple # noqa: F401
import addict
import bitsandbytes as bnb
import torch
import transformers
@@ -21,6 +22,7 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase,
)
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
@@ -28,16 +30,50 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl")
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
quant_config_exists = hasattr(model_config, "quantization_config")
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config
and model_config.quantization_config["quant_method"] == "gptq"
)
if cfg.gptq and not quant_config_method_is_gptq:
raise ValueError(
"model_config.quantization_config is not set or quant_method is not set to gptq. "
"Please make sure to point to a GPTQ model."
)
if not cfg.gptq and quant_config_exists:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
)
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True
model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
try:
model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
except ValueError as err:
if "mamba" in model_config_name:
return addict.Dict(
{
"model_type": "mamba",
}
)
raise err
if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)
check_model_config(cfg, model_config)
return model_config
@@ -69,6 +105,7 @@ def load_tokenizer(cfg):
"LlamaTokenizer",
"LlamaTokenizerFast",
"CodeLlamaTokenizer",
"CodeLlamaTokenizerFast",
]
and hasattr(tokenizer, "pad_token")
and not tokenizer.pad_token
@@ -101,6 +138,23 @@ def load_tokenizer(cfg):
tokenizer.add_special_tokens(
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
)
# If we add bos_token and eos_token, we need to update the post processor to
# handle them correctly.
# https://github.com/huggingface/transformers/pull/24132
bos_or_eos_in_special_tokens = (
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
)
if (
tokenizer.__class__.__name__
in (
"LlamaTokenizerFast",
"CodeLlamaTokenizerFast",
)
and bos_or_eos_in_special_tokens
):
tokenizer.update_post_processor()
if cfg.tokens:
tokenizer.add_tokens(
[
@@ -195,6 +249,18 @@ def load_model(
LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if (
cfg.model_config_type == "mixtral"
and cfg.flash_attention
and cfg.sample_packing
):
from axolotl.monkeypatch.mixtral import (
replace_mixtral_attn_with_multipack_flash_attn,
)
LOG.info("patching with flash attention")
replace_mixtral_attn_with_multipack_flash_attn()
if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope,
@@ -216,6 +282,7 @@ def load_model(
model_kwargs = {}
model_kwargs["device_map"] = cfg.device_map
model_kwargs["max_memory"] = cfg.max_memory
model_kwargs["torch_dtype"] = cfg.torch_dtype
if cfg.model_revision:
@@ -241,13 +308,22 @@ def load_model(
bnb_4bit_quant_type="nf4",
)
# sample packing uses custom FA2 patch
if cfg.flash_attention and not cfg.sample_packing:
if (
cfg.is_llama_derived_model
or cfg.is_falcon_derived_model
or cfg.is_mistral_derived_model
):
model_kwargs["use_flash_attention_2"] = True
if cfg.flash_attention:
if not cfg.sample_packing:
if (
cfg.is_llama_derived_model
or cfg.is_falcon_derived_model
or cfg.is_mistral_derived_model
or model_config.model_type == "mixtral"
):
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
if model_config.model_type == "mixtral":
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
@@ -309,6 +385,20 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)
elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
model_kwargs["dtype"] = model_kwargs["torch_dtype"]
model_kwargs["device"] = torch.cuda.current_device()
del model_kwargs["torch_dtype"]
del model_kwargs["device_map"]
del model_kwargs["max_memory"]
model = MambaLMHeadModel.from_pretrained(
base_model,
**model_kwargs,
)
elif model_type and not cfg.trust_remote_code:
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
@@ -368,13 +458,17 @@ def load_model(
if cfg.resize_token_embeddings_to_32x
else len(tokenizer)
)
if model.get_input_embeddings().num_embeddings < embeddings_len:
if (
hasattr(model, "get_input_embeddings")
and model.get_input_embeddings().num_embeddings < embeddings_len
):
model.resize_token_embeddings(embeddings_len)
else:
model.tie_weights()
if (
hasattr(model.config, "max_position_embeddings")
hasattr(model, "config")
and hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings
):
@@ -384,20 +478,22 @@ def load_model(
model.config.max_position_embeddings = cfg.sequence_len
if (
hasattr(model.config, "bos_token_id")
hasattr(model, "config")
and hasattr(model.config, "bos_token_id")
and model.config.bos_token_id
and model.config.bos_token_id != tokenizer.bos_token_id
):
model.config.bos_token_id = tokenizer.bos_token_id
if (
hasattr(model.config, "eos_token_id")
hasattr(model, "config")
and hasattr(model.config, "eos_token_id")
and model.config.eos_token_id
and model.config.eos_token_id != tokenizer.eos_token_id
):
model.config.eos_token_id = tokenizer.eos_token_id
if model.device.type == "cuda":
if hasattr(model, "device") and model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device)
# make sure these are fp32 per Ramesh et al. (2021)
@@ -412,15 +508,22 @@ def load_model(
module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
skip_prepare_model_for_kbit_training = True
if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
if not skip_prepare_model_for_kbit_training:
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
@@ -449,7 +552,8 @@ def load_model(
requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0:
LOG.warning("there are no parameters that require gradient updates")
model.config.use_cache = False
if hasattr(model, "config"):
model.config.use_cache = False
if cfg.flash_optimum:
model = BetterTransformer.transform(model)

View File

@@ -131,8 +131,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
)
# Phi doesn't want the attention_mask feature when training
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
cfg.is_mistral_derived_model and cfg.flash_attention
if (
"CodeGenTokenizer" in tokenizer.__class__.__name__
or (cfg.is_mistral_derived_model and cfg.flash_attention)
or cfg.model_config_type == "mamba"
):
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
@@ -153,7 +155,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
if update:
cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens:
skip_estimates = cfg.model_config_type == "mamba"
if not skip_estimates and not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
@@ -167,7 +171,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
if update:
cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing:
if not skip_estimates and cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails
@@ -267,12 +271,14 @@ def setup_fsdp_envs(cfg):
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
def prepare_optim_env(cfg):
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset

View File

@@ -2,20 +2,20 @@
import os
from axolotl.utils.dict import DictDefault
def setup_wandb_env_vars(cfg):
if cfg.wandb_mode and cfg.wandb_mode == "offline":
os.environ["WANDB_MODE"] = cfg.wandb_mode
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project
def setup_wandb_env_vars(cfg: DictDefault):
for key in cfg.keys():
if key.startswith("wandb_"):
value = cfg.get(key, "")
if value and isinstance(value, str) and len(value) > 0:
os.environ[key.upper()] = value
# Enable wandb if project name is present
if cfg.wandb_project and len(cfg.wandb_project) > 0:
cfg.use_wandb = True
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
os.environ.pop("WANDB_DISABLED", None) # Remove if present
else:
os.environ["WANDB_DISABLED"] = "true"

65
tests/e2e/test_mamba.py Normal file
View File

@@ -0,0 +1,65 @@
"""
E2E tests for lora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMistral(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_fft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "state-spaces/mamba-130m",
"model_type": "MambaLMHeadModel",
"tokenizer_type": "AutoTokenizer",
"tokenizer_config": "EleutherAI/gpt-neox-20b",
"flash_attention": False,
"sequence_len": 1024,
"load_in_8bit": False,
"val_set_size": 0.0,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"gradient_checkpointing": False,
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": None,
"save_safetensors": False,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,7 @@
"""Module for testing the validation module"""
import logging
import os
import unittest
from typing import Optional
@@ -8,6 +9,7 @@ import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars
class ValidationTest(unittest.TestCase):
@@ -679,3 +681,83 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
class ValidationWandbTest(ValidationTest):
"""
Validation test for wandb
"""
def test_wandb_set_run_id_to_name(self):
cfg = DictDefault(
{
"wandb_run_id": "foo",
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
in record.message
for record in self._caplog.records
)
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
cfg = DictDefault(
{
"wandb_name": "foo",
}
)
validate_config(cfg)
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
def test_wandb_sets_env(self):
cfg = DictDefault(
{
"wandb_project": "foo",
"wandb_name": "bar",
"wandb_run_id": "bat",
"wandb_entity": "baz",
"wandb_mode": "online",
"wandb_watch": "false",
"wandb_log_model": "checkpoint",
}
)
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_PROJECT", "") == "foo"
assert os.environ.get("WANDB_NAME", "") == "bar"
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
assert os.environ.get("WANDB_ENTITY", "") == "baz"
assert os.environ.get("WANDB_MODE", "") == "online"
assert os.environ.get("WANDB_WATCH", "") == "false"
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
assert os.environ.get("WANDB_DISABLED", "") != "true"
def test_wandb_set_disabled(self):
cfg = DictDefault({})
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_DISABLED", "") == "true"
cfg = DictDefault(
{
"wandb_project": "foo",
}
)
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_DISABLED", "") != "true"