Compare commits
14 Commits
unsloth_mo
...
mixtral_op
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
450e04d3c4 | ||
|
|
b0cf397ecb | ||
|
|
5f79b8242f | ||
|
|
f1de29dd1e | ||
|
|
7fabc4d95e | ||
|
|
9a5eb3990c | ||
|
|
86487c2e96 | ||
|
|
35f9b0f149 | ||
|
|
68b227a7d8 | ||
|
|
03c6318ba3 | ||
|
|
40a6362c92 | ||
|
|
d339beb9d9 | ||
|
|
fde091cb12 | ||
|
|
06ae39200b |
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -73,7 +73,7 @@ jobs:
|
||||
run: |
|
||||
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
||||
pip3 uninstall -y transformers accelerate
|
||||
pip3 install -U -e .[flash-attn]
|
||||
pip3 install -U -e .[flash-attn,mamba-ssm]
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Run e2e tests
|
||||
|
||||
@@ -8,6 +8,9 @@ ignore_missing_imports = True
|
||||
[mypy-axolotl.monkeypatch.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-axolotl.models.mixtral.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-axolotl.models.phi.*]
|
||||
ignore_errors = True
|
||||
|
||||
|
||||
38
README.md
38
README.md
@@ -65,19 +65,21 @@ Features:
|
||||
|
||||
## Axolotl supports
|
||||
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||
|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
|
||||
|
||||
## Quickstart ⚡
|
||||
@@ -245,7 +247,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "output": "..."}
|
||||
```
|
||||
- `sharegpt`: conversations where `from` is `human`/`gpt`
|
||||
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
@@ -689,9 +691,11 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
|
||||
save_strategy: # Set to `no` to skip checkpoint saves
|
||||
save_steps: # Leave empty to save at each epoch
|
||||
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
@@ -1018,6 +1022,10 @@ Please reduce any below
|
||||
- `gradient_accumulation_steps`
|
||||
- `sequence_len`
|
||||
|
||||
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
|
||||
|
||||
Using adamw_bnb_8bit might also save you some memory.
|
||||
|
||||
> `failed (exitcode: -9)`
|
||||
|
||||
Usually means your system has run out of system memory.
|
||||
|
||||
@@ -4,6 +4,7 @@ FROM winglian/axolotl:$BASE_TAG
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
|
||||
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
||||
|
||||
|
||||
@@ -72,8 +72,8 @@ gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
|
||||
warmup_steps: 32
|
||||
eval_steps:
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
save_total_limit:
|
||||
|
||||
debug:
|
||||
|
||||
@@ -49,8 +49,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -54,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -56,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -54,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -56,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -54,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -56,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -51,8 +51,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 40
|
||||
eval_steps: 5
|
||||
save_steps: 43
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -80,8 +80,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 5
|
||||
save_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.000001
|
||||
|
||||
@@ -51,8 +51,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 40
|
||||
eval_steps: 5
|
||||
save_steps: 43
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -46,8 +46,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -42,8 +42,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -58,9 +58,9 @@ flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_steps: 100
|
||||
eval_steps: 0.05
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed: #deepspeed/zero2.json # multi-gpu only
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -62,8 +62,8 @@ flash_attention:
|
||||
sdp_attention:
|
||||
flash_optimum:
|
||||
warmup_steps: 100
|
||||
eval_steps:
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -54,10 +54,10 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -56,9 +56,9 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -60,8 +60,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
save_steps: 50
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -54,9 +54,9 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
61
examples/mamba/config.yml
Normal file
61
examples/mamba/config.yml
Normal file
@@ -0,0 +1,61 @@
|
||||
base_model: state-spaces/mamba-2.8b
|
||||
model_type: MambaLMHeadModel
|
||||
tokenizer_type: AutoTokenizer
|
||||
tokenizer_config: EleutherAI/gpt-neox-20b
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
tokens:
|
||||
save_safetensors: False
|
||||
@@ -46,10 +46,10 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
79
examples/mistral/mixtral.yml
Normal file
79
examples/mistral/mixtral.yml
Normal file
@@ -0,0 +1,79 @@
|
||||
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
#lora_target_modules:
|
||||
# - gate
|
||||
# - q_proj
|
||||
# - k_proj
|
||||
# - v_proj
|
||||
# - o_proj
|
||||
# - w1
|
||||
# - w2
|
||||
# - w3
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed: deepspeed/zero2.json
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -66,10 +66,10 @@ loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -44,8 +44,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
|
||||
@@ -49,8 +49,8 @@ flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -54,8 +54,8 @@ flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -48,8 +48,8 @@ flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -59,8 +59,8 @@ xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 100
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -59,8 +59,8 @@ xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 100
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -33,5 +33,5 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
weight_decay: 0.1
|
||||
eval_steps: 0.05
|
||||
evals_per_epoch: 4
|
||||
logging_steps: 1
|
||||
|
||||
@@ -56,10 +56,10 @@ xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -56,10 +56,10 @@ xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -45,8 +45,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
|
||||
@@ -45,8 +45,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 50
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0
|
||||
|
||||
@@ -78,8 +78,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 50
|
||||
save_steps: 50
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
auto-gptq==0.5.1
|
||||
packaging
|
||||
peft==0.6.0
|
||||
transformers==4.35.2
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@e5079b0b2abcef11ecbdae60ba4a6636c57b725d
|
||||
tokenizers==0.15.0
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate==0.24.1
|
||||
@@ -29,7 +29,7 @@ scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat==0.2.29
|
||||
fschat==0.2.34
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
|
||||
|
||||
5
setup.py
5
setup.py
@@ -46,10 +46,13 @@ setup(
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn>=2.3.0",
|
||||
"flash-attn==2.3.3",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.0.1",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -31,7 +31,10 @@ from axolotl.utils.callbacks import (
|
||||
bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
MambaDataCollator,
|
||||
)
|
||||
from axolotl.utils.samplers import MultipackBatchSampler
|
||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||
|
||||
@@ -49,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
"""
|
||||
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
@@ -285,6 +291,32 @@ class AxolotlTrainer(Trainer):
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Mamba specific trainer to handle loss calculation
|
||||
"""
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
|
||||
labels = input_ids.to(lm_logits.device)
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
lm_loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||
)
|
||||
|
||||
return lm_loss
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
@@ -462,6 +494,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return OneCycleLRSchedulerTrainer
|
||||
if self.cfg.relora_steps:
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
@@ -529,7 +563,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.hub_strategy:
|
||||
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
if self.cfg.save_safetensors:
|
||||
if self.cfg.save_safetensors is not None:
|
||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.cfg.sample_packing_eff_est:
|
||||
@@ -677,6 +711,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||
training_arguments_kwargs
|
||||
)
|
||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||
training_args = (
|
||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
@@ -731,11 +766,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=BatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
data_collator=self.build_collator(**data_collator_kwargs),
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
@@ -755,3 +786,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
] = self.cfg.micro_batch_size
|
||||
|
||||
return trainer
|
||||
|
||||
def build_collator(self, **kwargs):
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||
|
||||
return BatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
12
src/axolotl/models/mamba/__init__.py
Normal file
12
src/axolotl/models/mamba/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Modeling module for Mamba models
|
||||
"""
|
||||
|
||||
|
||||
def fix_mamba_attn_for_loss():
|
||||
from mamba_ssm.models import mixer_seq_simple
|
||||
|
||||
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|
||||
|
||||
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
|
||||
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name
|
||||
42
src/axolotl/models/mamba/configuration_mamba.py
Normal file
42
src/axolotl/models/mamba/configuration_mamba.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
HF Transformers MambaConfig
|
||||
"""
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class MambaConfig(PretrainedConfig):
|
||||
"""
|
||||
modeling configuration for state space model/mamba
|
||||
"""
|
||||
|
||||
model_type = "mamba"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50280,
|
||||
d_model=2560,
|
||||
n_layer=64,
|
||||
rms_norm=True,
|
||||
residual_in_fp32=True,
|
||||
fused_add_norm=True,
|
||||
pad_vocab_size_multiple=8,
|
||||
pad_token_id=50277,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model
|
||||
self.n_layer = n_layer
|
||||
self.rms_norm = rms_norm
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
self.fused_add_norm = fused_add_norm
|
||||
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
128
src/axolotl/models/mamba/modeling_mamba.py
Normal file
128
src/axolotl/models/mamba/modeling_mamba.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# pylint: skip-file
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
|
||||
from mamba_ssm.utils.generation import GenerationMixin
|
||||
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from axolotl.models.mamba.configuration_mamba import MambaConfig
|
||||
|
||||
|
||||
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_layer: int,
|
||||
vocab_size: int,
|
||||
initializer_cfg=None,
|
||||
pad_vocab_size_multiple: int = 1,
|
||||
device=None,
|
||||
dtype=None,
|
||||
**backbone_kwargs,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
if vocab_size % pad_vocab_size_multiple != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (
|
||||
vocab_size % pad_vocab_size_multiple
|
||||
)
|
||||
self.config = MambaConfig(
|
||||
vocab_size=vocab_size,
|
||||
d_model=d_model,
|
||||
n_layer=n_layer,
|
||||
pad_vocab_size_multiple=pad_vocab_size_multiple,
|
||||
)
|
||||
self.backbone = MixerModel(
|
||||
d_model=d_model,
|
||||
n_layer=n_layer,
|
||||
vocab_size=vocab_size,
|
||||
initializer_cfg=initializer_cfg,
|
||||
**backbone_kwargs,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.apply(
|
||||
partial(
|
||||
_init_weights,
|
||||
n_layer=n_layer,
|
||||
**(initializer_cfg if initializer_cfg is not None else {}),
|
||||
)
|
||||
)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
self.lm_head.weight = self.backbone.embedding.weight
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.backbone.allocate_inference_cache(
|
||||
batch_size, max_seqlen, dtype=dtype, **kwargs
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids=None,
|
||||
inference_params=None,
|
||||
num_last_tokens=0,
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
||||
num_last_tokens: if > 0, only return the logits for the last n tokens
|
||||
"""
|
||||
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
||||
if num_last_tokens > 0:
|
||||
hidden_states = hidden_states[:, -num_last_tokens:]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
logits = lm_logits
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
|
||||
print(loss)
|
||||
return CausalLMOutput(logits=lm_logits, loss=loss)
|
||||
|
||||
else:
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
state_dict: Optional[dict] = None,
|
||||
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
|
||||
):
|
||||
if state_dict is None:
|
||||
state_dict = self.state_dict()
|
||||
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
||||
config = load_config_hf(pretrained_model_name)
|
||||
model = cls(**config, device=device, dtype=dtype, **kwargs)
|
||||
model.load_state_dict(
|
||||
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
|
||||
)
|
||||
return model
|
||||
@@ -1,168 +0,0 @@
|
||||
# Adapted from Unsloth
|
||||
# https://github.com/unslothai/unsloth/blob/4b97a810b509c93f44be4c037c7aa18fb8922884/unsloth/kernels/cross_entropy_loss.py
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
MAX_FUSED_SIZE = 65536
|
||||
|
||||
def calculate_settings(n):
|
||||
BLOCK_SIZE = triton.next_power_of_2(n)
|
||||
# CUDA only supports 65536 - 2^16 threads per block
|
||||
if BLOCK_SIZE > MAX_FUSED_SIZE:
|
||||
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
|
||||
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
|
||||
num_warps = 4
|
||||
if BLOCK_SIZE >= 32768: num_warps = 32
|
||||
elif BLOCK_SIZE >= 8192: num_warps = 16
|
||||
elif BLOCK_SIZE >= 2048: num_warps = 8
|
||||
return BLOCK_SIZE, num_warps
|
||||
pass
|
||||
|
||||
@triton.jit
|
||||
def _cross_entropy_forward(logits_ptr, logits_row_stride,
|
||||
loss_ptr,
|
||||
lse_ptr,
|
||||
labels_ptr,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,):
|
||||
"""
|
||||
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
|
||||
Pi = exp(xi) / sum(exp(xi))
|
||||
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
|
||||
= -y [ x - log[sum(exp(x))] ]
|
||||
= y * (log[sum(exp(x))] - x)
|
||||
If y == 0: CE_i = 0
|
||||
If y == 1: CE_i = logsumexp - x
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
logits_ptr += row_idx * logits_row_stride
|
||||
loss_ptr += row_idx
|
||||
lse_ptr += row_idx
|
||||
labels_ptr += row_idx
|
||||
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
# TODO: Fixup int32 locations to int64
|
||||
label_idx = tl.load(labels_ptr).to(tl.int32)
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
||||
max_logits = tl.max(logits, 0)
|
||||
# Maximum stops overflow
|
||||
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
|
||||
tl.store(lse_ptr, lse)
|
||||
|
||||
if label_idx != -100:
|
||||
logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
|
||||
loss = lse - logits_label
|
||||
else:
|
||||
loss = 0.0
|
||||
tl.store(loss_ptr, loss)
|
||||
pass
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _cross_entropy_backward(logits_ptr, logits_row_stride,
|
||||
dloss_ptr, dloss_row_stride,
|
||||
lse_ptr,
|
||||
labels_ptr,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,):
|
||||
"""
|
||||
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
|
||||
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
|
||||
|
||||
From https://en.wikipedia.org/wiki/LogSumExp
|
||||
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
|
||||
|
||||
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
|
||||
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
|
||||
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
|
||||
|
||||
If y == 0: dC/dx = 0
|
||||
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
|
||||
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
logits_ptr += row_idx * logits_row_stride
|
||||
dloss_ptr += row_idx * dloss_row_stride
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
# TODO: Fixup int32 locations to int64
|
||||
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
|
||||
|
||||
if label_idx != -100:
|
||||
dloss = tl.load(dloss_ptr)
|
||||
else:
|
||||
dloss = 0.0
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
lse = tl.load(lse_ptr + row_idx)
|
||||
probs = tl.exp(logits - lse)
|
||||
|
||||
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
|
||||
tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)
|
||||
|
||||
|
||||
|
||||
class CrossEntropyLoss(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, logits, labels):
|
||||
n_rows, n_cols = logits.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
|
||||
_cross_entropy_forward[(n_rows,)](
|
||||
logits, logits.stride(0),
|
||||
losses,
|
||||
logsumexp,
|
||||
labels,
|
||||
n_cols,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
num_warps = num_warps,
|
||||
)
|
||||
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.save_for_backward(logits, logsumexp, labels)
|
||||
return losses
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dlosses):
|
||||
logits, logsumexp, labels = ctx.saved_tensors
|
||||
n_rows, n_cols = logits.shape
|
||||
|
||||
_cross_entropy_backward[(n_rows,)](
|
||||
logits, logits.stride(0),
|
||||
dlosses, dlosses.stride(0),
|
||||
logsumexp,
|
||||
labels,
|
||||
n_cols,
|
||||
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
||||
num_warps = ctx.num_warps,
|
||||
)
|
||||
return logits, None, None,
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def fast_cross_entropy_loss(logits, labels):
|
||||
"""
|
||||
Arguments:
|
||||
logits: (batch, seq_len, vocab_size)
|
||||
labels: (batch, seq_len,)
|
||||
Returns:
|
||||
losses: float
|
||||
"""
|
||||
batch, seq_len, d = logits.shape
|
||||
assert(labels.shape == (batch, seq_len))
|
||||
|
||||
loss = CrossEntropyLoss.apply(
|
||||
logits.view(batch*seq_len, d),
|
||||
labels.view(-1),
|
||||
)
|
||||
n_items = torch.count_nonzero(labels != -100)
|
||||
return loss.sum() / n_items
|
||||
pass
|
||||
@@ -13,20 +13,16 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention as OriginalMistralAttention,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralForCausalLM as OriginalMistralForCausalLM,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||
|
||||
@@ -40,9 +36,6 @@ def replace_mistral_attn_with_flash_attn(
|
||||
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||
flashattn_forward
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralForCausalLM.forward = (
|
||||
mistral_causallm_forward
|
||||
)
|
||||
if packed:
|
||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||
MistralDecoderLayer
|
||||
@@ -648,71 +641,3 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def mistral_causallm_forward(
|
||||
self: OriginalMistralForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
*args, **kwargs
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
shift_logits = logits
|
||||
if not hasattr(self, "extra_ignored_labels"):
|
||||
self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device)
|
||||
|
||||
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
|
||||
# FAST CROSS ENTROPY
|
||||
loss = fast_cross_entropy_loss(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
22
src/axolotl/monkeypatch/mixtral/__init__.py
Normal file
22
src/axolotl/monkeypatch/mixtral/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Patches to support multipack for mixtral
|
||||
"""
|
||||
import transformers
|
||||
|
||||
|
||||
def replace_mixtral_attn_with_multipack_flash_attn():
|
||||
from .modeling_mixtral import (
|
||||
MixtralMultipackFlashAttention2,
|
||||
mixtral_decoder_layer_forward,
|
||||
mixtral_model_forward,
|
||||
)
|
||||
|
||||
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
|
||||
mixtral_decoder_layer_forward
|
||||
)
|
||||
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
|
||||
mixtral_model_forward
|
||||
)
|
||||
transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
|
||||
"flash_attention_2"
|
||||
] = MixtralMultipackFlashAttention2
|
||||
379
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
Normal file
379
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""
|
||||
Mixtral modeling for multipack
|
||||
"""
|
||||
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
|
||||
import logging
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||
from transformers import Cache, DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import MoeModelOutputWithPast
|
||||
from transformers.models.mixtral.modeling_mixtral import (
|
||||
MixtralFlashAttention2,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
|
||||
|
||||
|
||||
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
|
||||
"""
|
||||
Custom multipack implementation w flash attention 2
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._flash_attn_uses_top_left_mask = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
attn_output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=self.attention_dropout,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def mixtral_decoder_layer_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_router_logits (`bool`, *optional*):
|
||||
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
||||
should not be returned during inference.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
if output_router_logits:
|
||||
outputs += (router_logits,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def mixtral_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, MoeModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = (
|
||||
attention_mask
|
||||
if (attention_mask is not None and 0 in attention_mask)
|
||||
else None
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
LOG.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_router_logits = () if output_router_logits else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
output_router_logits,
|
||||
use_cache,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
output_router_logits=output_router_logits,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if output_router_logits:
|
||||
all_router_logits += (layer_outputs[-1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache()
|
||||
if use_legacy_cache
|
||||
else next_decoder_cache
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
all_router_logits,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return MoeModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
@@ -81,8 +81,9 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.sequence_len = 4096
|
||||
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
self.tokenizer.add_special_tokens(
|
||||
{"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")}
|
||||
)
|
||||
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
|
||||
@@ -13,7 +13,7 @@ register_conv_template(
|
||||
system_message="You are a helpful assistant.",
|
||||
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>\n",
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -33,8 +33,8 @@ class AlpacaPrompter(Prompter):
|
||||
Base class for alpaca prompters
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||
system_format: str = "{system}"
|
||||
turn_format: str
|
||||
turn_no_input_format: str
|
||||
|
||||
@@ -82,7 +82,8 @@ def train(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||
)
|
||||
|
||||
model.config.use_cache = False
|
||||
if hasattr(model, "config"):
|
||||
model.config.use_cache = False
|
||||
|
||||
# go ahead and presave, so we have the adapter config available to inspect
|
||||
if peft_config:
|
||||
@@ -92,7 +93,8 @@ def train(
|
||||
if not Path(cfg.output_dir).is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||
if hasattr(model, "config"):
|
||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
@@ -2,12 +2,16 @@
|
||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Dict, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSeq2Seq:
|
||||
@@ -146,3 +150,31 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaDataCollator:
|
||||
"""
|
||||
Collator for State Space Models (Mamba)
|
||||
"""
|
||||
|
||||
tokenizer: transformers.PreTrainedTokenizer
|
||||
|
||||
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
||||
input_ids, labels = tuple(
|
||||
[torch.LongTensor(instance[key]) for instance in instances]
|
||||
for key in ("input_ids", "labels")
|
||||
)
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id,
|
||||
)
|
||||
labels = torch.nn.utils.rnn.pad_sequence(
|
||||
labels, batch_first=True, padding_value=IGNORE_INDEX
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
@@ -77,6 +77,15 @@ def normalize_config(cfg):
|
||||
else:
|
||||
cfg.torch_dtype = torch.float32
|
||||
|
||||
if cfg.saves_per_epoch:
|
||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||
if save_steps < 1.0: # prevent saves on every step
|
||||
cfg.save_steps = save_steps
|
||||
if cfg.evals_per_epoch:
|
||||
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
|
||||
if eval_steps < 1.0: # prevent evals on every step
|
||||
cfg.eval_steps = eval_steps
|
||||
|
||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
||||
|
||||
if not cfg.base_model_config:
|
||||
@@ -352,6 +361,27 @@ def validate_config(cfg):
|
||||
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||
"sharegpt_simple", "sharegpt"
|
||||
)
|
||||
|
||||
if cfg.saves_per_epoch and cfg.save_steps:
|
||||
raise ValueError(
|
||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||
)
|
||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if (
|
||||
cfg.evals_per_epoch
|
||||
and cfg.evaluation_strategy
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
|
||||
@@ -4,6 +4,7 @@ import math
|
||||
import os
|
||||
from typing import Optional, Tuple # noqa: F401
|
||||
|
||||
import addict
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
@@ -21,6 +22,7 @@ from transformers import ( # noqa: F401
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -52,9 +54,20 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
||||
def load_model_config(cfg):
|
||||
model_config_name = cfg.base_model_config or cfg.base_model
|
||||
trust_remote_code = cfg.trust_remote_code is True
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_config_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
try:
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_config_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
except ValueError as err:
|
||||
if "mamba" in model_config_name:
|
||||
return addict.Dict(
|
||||
{
|
||||
"model_type": "mamba",
|
||||
}
|
||||
)
|
||||
raise err
|
||||
|
||||
if cfg.model_config:
|
||||
for key, val in cfg.model_config.items():
|
||||
setattr(model_config, key, val)
|
||||
@@ -92,6 +105,7 @@ def load_tokenizer(cfg):
|
||||
"LlamaTokenizer",
|
||||
"LlamaTokenizerFast",
|
||||
"CodeLlamaTokenizer",
|
||||
"CodeLlamaTokenizerFast",
|
||||
]
|
||||
and hasattr(tokenizer, "pad_token")
|
||||
and not tokenizer.pad_token
|
||||
@@ -124,6 +138,23 @@ def load_tokenizer(cfg):
|
||||
tokenizer.add_special_tokens(
|
||||
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||
)
|
||||
|
||||
# If we add bos_token and eos_token, we need to update the post processor to
|
||||
# handle them correctly.
|
||||
# https://github.com/huggingface/transformers/pull/24132
|
||||
bos_or_eos_in_special_tokens = (
|
||||
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
|
||||
)
|
||||
if (
|
||||
tokenizer.__class__.__name__
|
||||
in (
|
||||
"LlamaTokenizerFast",
|
||||
"CodeLlamaTokenizerFast",
|
||||
)
|
||||
and bos_or_eos_in_special_tokens
|
||||
):
|
||||
tokenizer.update_post_processor()
|
||||
|
||||
if cfg.tokens:
|
||||
tokenizer.add_tokens(
|
||||
[
|
||||
@@ -218,6 +249,18 @@ def load_model(
|
||||
LOG.info("patching with flash attention")
|
||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
|
||||
if (
|
||||
cfg.model_config_type == "mixtral"
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
from axolotl.monkeypatch.mixtral import (
|
||||
replace_mixtral_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching with flash attention")
|
||||
replace_mixtral_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||
replace_llama_rope_with_xpos_rope,
|
||||
@@ -265,13 +308,22 @@ def load_model(
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
# sample packing uses custom FA2 patch
|
||||
if cfg.flash_attention and not cfg.sample_packing:
|
||||
if (
|
||||
cfg.is_llama_derived_model
|
||||
or cfg.is_falcon_derived_model
|
||||
or cfg.is_mistral_derived_model
|
||||
):
|
||||
model_kwargs["use_flash_attention_2"] = True
|
||||
if cfg.flash_attention:
|
||||
if not cfg.sample_packing:
|
||||
if (
|
||||
cfg.is_llama_derived_model
|
||||
or cfg.is_falcon_derived_model
|
||||
or cfg.is_mistral_derived_model
|
||||
or model_config.model_type == "mixtral"
|
||||
):
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flash_attention_2"
|
||||
)
|
||||
else:
|
||||
if model_config.model_type == "mixtral":
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flash_attention_2"
|
||||
)
|
||||
|
||||
try:
|
||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||
@@ -333,6 +385,20 @@ def load_model(
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
**model_kwargs,
|
||||
)
|
||||
elif model_type == "MambaLMHeadModel":
|
||||
# FIXME this is janky at best and hacked together to make it work
|
||||
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
||||
|
||||
model_kwargs["dtype"] = model_kwargs["torch_dtype"]
|
||||
model_kwargs["device"] = torch.cuda.current_device()
|
||||
del model_kwargs["torch_dtype"]
|
||||
del model_kwargs["device_map"]
|
||||
del model_kwargs["max_memory"]
|
||||
|
||||
model = MambaLMHeadModel.from_pretrained(
|
||||
base_model,
|
||||
**model_kwargs,
|
||||
)
|
||||
elif model_type and not cfg.trust_remote_code:
|
||||
if cfg.gptq:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -392,13 +458,17 @@ def load_model(
|
||||
if cfg.resize_token_embeddings_to_32x
|
||||
else len(tokenizer)
|
||||
)
|
||||
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
||||
if (
|
||||
hasattr(model, "get_input_embeddings")
|
||||
and model.get_input_embeddings().num_embeddings < embeddings_len
|
||||
):
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
else:
|
||||
model.tie_weights()
|
||||
|
||||
if (
|
||||
hasattr(model.config, "max_position_embeddings")
|
||||
hasattr(model, "config")
|
||||
and hasattr(model.config, "max_position_embeddings")
|
||||
and model.config.max_position_embeddings
|
||||
and cfg.sequence_len > model.config.max_position_embeddings
|
||||
):
|
||||
@@ -408,20 +478,22 @@ def load_model(
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if (
|
||||
hasattr(model.config, "bos_token_id")
|
||||
hasattr(model, "config")
|
||||
and hasattr(model.config, "bos_token_id")
|
||||
and model.config.bos_token_id
|
||||
and model.config.bos_token_id != tokenizer.bos_token_id
|
||||
):
|
||||
model.config.bos_token_id = tokenizer.bos_token_id
|
||||
|
||||
if (
|
||||
hasattr(model.config, "eos_token_id")
|
||||
hasattr(model, "config")
|
||||
and hasattr(model.config, "eos_token_id")
|
||||
and model.config.eos_token_id
|
||||
and model.config.eos_token_id != tokenizer.eos_token_id
|
||||
):
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if model.device.type == "cuda":
|
||||
if hasattr(model, "device") and model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
# make sure these are fp32 per Ramesh et al. (2021)
|
||||
@@ -480,7 +552,8 @@ def load_model(
|
||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||
if len(requires_grad) == 0:
|
||||
LOG.warning("there are no parameters that require gradient updates")
|
||||
model.config.use_cache = False
|
||||
if hasattr(model, "config"):
|
||||
model.config.use_cache = False
|
||||
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.transform(model)
|
||||
|
||||
@@ -131,8 +131,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||
)
|
||||
|
||||
# Phi doesn't want the attention_mask feature when training
|
||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
|
||||
cfg.is_mistral_derived_model and cfg.flash_attention
|
||||
if (
|
||||
"CodeGenTokenizer" in tokenizer.__class__.__name__
|
||||
or (cfg.is_mistral_derived_model and cfg.flash_attention)
|
||||
or cfg.model_config_type == "mamba"
|
||||
):
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
@@ -153,7 +155,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
if update:
|
||||
cfg.total_num_tokens = total_num_tokens
|
||||
|
||||
if not cfg.total_supervised_tokens:
|
||||
skip_estimates = cfg.model_config_type == "mamba"
|
||||
|
||||
if not skip_estimates and not cfg.total_supervised_tokens:
|
||||
total_supervised_tokens = (
|
||||
train_dataset.data.column("labels")
|
||||
.to_pandas()
|
||||
@@ -167,7 +171,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
if update:
|
||||
cfg.total_supervised_tokens = total_supervised_tokens
|
||||
|
||||
if cfg.sample_packing:
|
||||
if not skip_estimates and cfg.sample_packing:
|
||||
# we have to drop anything longer then sequence len otherwise
|
||||
# flash attention with position ids fails
|
||||
|
||||
|
||||
65
tests/e2e/test_mamba.py
Normal file
65
tests/e2e/test_mamba.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestMistral(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "state-spaces/mamba-130m",
|
||||
"model_type": "MambaLMHeadModel",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"tokenizer_config": "EleutherAI/gpt-neox-20b",
|
||||
"flash_attention": False,
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": False,
|
||||
"val_set_size": 0.0,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"gradient_checkpointing": False,
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": None,
|
||||
"save_safetensors": False,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user