Compare commits

..

14 Commits

Author SHA1 Message Date
kallewoof
450e04d3c4 fix: remove excessive newlines in system prompt(s) for alpaca (#936) 2023-12-13 16:40:02 +09:00
Juraj Bednar
b0cf397ecb More hints on what to do with CUDA Out of memory errors (#925) 2023-12-13 16:38:38 +09:00
Wing Lian
5f79b8242f new evals_per_epoch and saves_per_epoch to make things cleaner (#944)
* new evals_per_epoch and saves_per_epoch to make things cleaner

* update per PR feedback
2023-12-12 15:35:23 -05:00
Hamel Husain
f1de29dd1e Respect sequence_len in config for type: llama2_chat (#926)
* Respect sequence_len in config for `type: llama2_chat`

It was hardcoded to `4096` I am not sure why?  This updates it to pull from the config. 

cc: @winglian

* Update llama2_chat.py

* apply black formatting

* fix tokenizer

* update test data

* lint fixtures
2023-12-12 09:39:22 -08:00
Wing Lian
7fabc4d95e Mixtral official (#942)
* multipack support for official mixtral implementation

* fix patch to load multipack for mixtral

* chore: lint
2023-12-11 23:44:33 -05:00
Motoki Wu
9a5eb3990c Update requirements.txt (#940) 2023-12-11 22:57:28 -05:00
Casper
86487c2e96 Mixtral: More correct MoE, lower loss (#932)
* More correct MoE

* Fix formatting
2023-12-10 10:34:25 -05:00
Wing Lian
35f9b0f149 update to latest transformers for mixstral support (#929)
* update to latest transformers for mixstral support

* pin transformers

* fix typo
2023-12-10 10:32:27 -05:00
Wing Lian
68b227a7d8 Mixtral multipack (#928)
* mixtral multipack

* use mixtral model

* sample yml

* calculate cu_seqlens properly

* use updated flash ettention setting

* attn var checks

* force use of flash attention 2 for packing

* lint

* disable future fix for now

* update support table
2023-12-09 21:26:30 -05:00
Timothy Lim
03c6318ba3 fixing prompt template of chatml by removal of linebreak (#922)
Co-authored-by: Timothy  Lim <timothyyonglee.lim@kxrdev.com>
2023-12-09 13:07:44 -05:00
Wing Lian
40a6362c92 support for mamba (#915)
* support for mamba

* more mamba fixes

* use fork for mamba kwargs fix

* grad checkpointing doesn't work

* fix extras for mamaba

* mamba loss fix

* use fp32 and remove verbose logging

* mamba fixes

* fix collator for mamba

* set model_type on training_args

* don't save safetensors for mamba

* update mamba config to disable safetensor checkpooints, install for tests

* no evals for mamba tests

* handle save_pretrained

* handle unused safetensors arg
2023-12-09 12:10:41 -05:00
NanoCode012
d339beb9d9 chore: clarify Readme on sharegpt system role 2023-12-08 11:35:53 +09:00
NanoCode012
fde091cb12 fix(tokenizer): handle fast tokenizer properly for bos/eos (#914) 2023-12-08 11:31:13 +09:00
Casper
06ae39200b Pin flash-attn to 2.3.3 (#919) 2023-12-07 07:36:52 +01:00
59 changed files with 1107 additions and 364 deletions

View File

@@ -73,7 +73,7 @@ jobs:
run: |
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
pip3 uninstall -y transformers accelerate
pip3 install -U -e .[flash-attn]
pip3 install -U -e .[flash-attn,mamba-ssm]
pip3 install -r requirements-tests.txt
- name: Run e2e tests

View File

@@ -8,6 +8,9 @@ ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*]
ignore_errors = True
[mypy-axolotl.models.mixtral.*]
ignore_errors = True
[mypy-axolotl.models.phi.*]
ignore_errors = True

View File

@@ -65,19 +65,21 @@ Features:
## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ✅ | | | | |
| cerebras | ✅ | ✅ | ✅ | | | | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | | | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | | | ❌ | ❌ | | ❓ |
| XGen | ✅ | | ✅ | | | | |
| phi | ✅ | ✅ | ✅ | | | ❓ | ❓ |
| RWKV | ✅ | ❓ | | ❓ | ❓ | ❓ | |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | | | | |
| Mixtral-MoE | ✅ | ✅ | ✅ | | | | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | | | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | | | ❌ | ❌ | | ❓ |
| falcon | ✅ | | ✅ | | | | |
| gpt-j | ✅ | ✅ | ✅ | | | ❓ | ❓ |
| XGen | ✅ | ❓ | | ❓ | ❓ | ❓ | |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
## Quickstart ⚡
@@ -245,7 +247,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations where `from` is `human`/`gpt`
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -689,9 +691,11 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
save_strategy: # Set to `no` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed.
@@ -1018,6 +1022,10 @@ Please reduce any below
- `gradient_accumulation_steps`
- `sequence_len`
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
Using adamw_bnb_8bit might also save you some memory.
> `failed (exitcode: -9)`
Usually means your system has run out of system memory.

View File

@@ -4,6 +4,7 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh

View File

@@ -72,8 +72,8 @@ gptq_groupsize:
gptq_model_v1:
warmup_steps: 32
eval_steps:
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
save_total_limit:
debug:

View File

@@ -49,8 +49,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -51,8 +51,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 40
eval_steps: 5
save_steps: 43
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -80,8 +80,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 5
save_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.000001

View File

@@ -51,8 +51,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 40
eval_steps: 5
save_steps: 43
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -46,8 +46,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -42,8 +42,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 110
save_steps: 660
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -58,9 +58,9 @@ flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
save_steps:
saves_per_epoch: 1
debug:
deepspeed: #deepspeed/zero2.json # multi-gpu only
weight_decay: 0.1

View File

@@ -62,8 +62,8 @@ flash_attention:
sdp_attention:
flash_optimum:
warmup_steps: 100
eval_steps:
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -54,10 +54,10 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -56,9 +56,9 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -60,8 +60,8 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps: 50
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -54,9 +54,9 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

61
examples/mamba/config.yml Normal file
View File

@@ -0,0 +1,61 @@
base_model: state-spaces/mamba-2.8b
model_type: MambaLMHeadModel
tokenizer_type: AutoTokenizer
tokenizer_config: EleutherAI/gpt-neox-20b
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./out
sequence_len: 2048
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
tokens:
save_safetensors: False

View File

@@ -46,10 +46,10 @@ xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -0,0 +1,79 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
#lora_target_modules:
# - gate
# - q_proj
# - k_proj
# - v_proj
# - o_proj
# - w1
# - w2
# - w3
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: deepspeed/zero2.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -66,10 +66,10 @@ loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -44,8 +44,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 110
save_steps: 660
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0001

View File

@@ -49,8 +49,8 @@ flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -54,8 +54,8 @@ flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -48,8 +48,8 @@ flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -59,8 +59,8 @@ xformers_attention:
flash_attention:
warmup_steps: 100
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -59,8 +59,8 @@ xformers_attention:
flash_attention:
warmup_steps: 100
eval_steps: 0.05
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.1

View File

@@ -33,5 +33,5 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
weight_decay: 0.1
eval_steps: 0.05
evals_per_epoch: 4
logging_steps: 1

View File

@@ -56,10 +56,10 @@ xformers_attention:
flash_attention:
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -56,10 +56,10 @@ xformers_attention:
flash_attention:
warmup_steps: 10
eval_steps: 0.05
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -45,8 +45,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 110
save_steps: 660
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0001

View File

@@ -45,8 +45,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20
eval_steps: 50
save_steps:
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0

View File

@@ -78,8 +78,8 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 50
save_steps: 50
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0

View File

@@ -2,7 +2,7 @@
auto-gptq==0.5.1
packaging
peft==0.6.0
transformers==4.35.2
transformers @ git+https://github.com/huggingface/transformers.git@e5079b0b2abcef11ecbdae60ba4a6636c57b725d
tokenizers==0.15.0
bitsandbytes>=0.41.1
accelerate==0.24.1
@@ -29,7 +29,7 @@ scipy
scikit-learn==1.2.2
pynvml
art
fschat==0.2.29
fschat==0.2.34
gradio==3.50.2
tensorboard

View File

@@ -46,10 +46,13 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn>=2.3.0",
"flash-attn==2.3.3",
],
"deepspeed": [
"deepspeed",
],
"mamba-ssm": [
"mamba-ssm==1.0.1",
],
},
)

View File

@@ -31,7 +31,10 @@ from axolotl.utils.callbacks import (
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
MambaDataCollator,
)
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
@@ -49,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
Extend the base TrainingArguments for axolotl helpers
"""
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
@@ -285,6 +291,32 @@ class AxolotlTrainer(Trainer):
return super().compute_loss(model, inputs, return_outputs=return_outputs)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""
def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return lm_loss
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
@@ -462,6 +494,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
return AxolotlTrainer
def build(self, total_num_steps):
@@ -529,7 +563,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors:
if self.cfg.save_safetensors is not None:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.sample_packing_eff_est:
@@ -677,6 +711,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
@@ -731,11 +766,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
data_collator=self.build_collator(**data_collator_kwargs),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
@@ -755,3 +786,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.micro_batch_size
return trainer
def build_collator(self, **kwargs):
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)
return BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**kwargs,
)

View File

@@ -0,0 +1,12 @@
"""
Modeling module for Mamba models
"""
def fix_mamba_attn_for_loss():
from mamba_ssm.models import mixer_seq_simple
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name

View File

@@ -0,0 +1,42 @@
"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig
class MambaConfig(PretrainedConfig):
"""
modeling configuration for state space model/mamba
"""
model_type = "mamba"
def __init__(
self,
vocab_size=50280,
d_model=2560,
n_layer=64,
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=8,
pad_token_id=50277,
bos_token_id=0,
eos_token_id=0,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layer = n_layer
self.rms_norm = rms_norm
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.pad_vocab_size_multiple = pad_vocab_size_multiple
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@@ -0,0 +1,128 @@
# pylint: skip-file
import os
from collections import namedtuple
from functools import partial
from typing import Optional, Union
import torch
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from torch import nn
from torch.nn import CrossEntropyLoss
from axolotl.models.mamba.configuration_mamba import MambaConfig
class MambaLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
d_model: int,
n_layer: int,
vocab_size: int,
initializer_cfg=None,
pad_vocab_size_multiple: int = 1,
device=None,
dtype=None,
**backbone_kwargs,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (
vocab_size % pad_vocab_size_multiple
)
self.config = MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layer,
pad_vocab_size_multiple=pad_vocab_size_multiple,
)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
vocab_size=vocab_size,
initializer_cfg=initializer_cfg,
**backbone_kwargs,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
# Initialize weights and apply final processing
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()
def tie_weights(self):
self.lm_head.weight = self.backbone.embedding.weight
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(
batch_size, max_seqlen, dtype=dtype, **kwargs
)
def forward(
self,
input_ids,
position_ids=None,
inference_params=None,
num_last_tokens=0,
labels=None,
**kwargs,
):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, inference_params=inference_params)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
loss = None
if labels is not None:
logits = lm_logits
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
print(loss)
return CausalLMOutput(logits=lm_logits, loss=loss)
else:
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
):
if state_dict is None:
state_dict = self.state_dict()
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config = load_config_hf(pretrained_model_name)
model = cls(**config, device=device, dtype=dtype, **kwargs)
model.load_state_dict(
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
)
return model

View File

@@ -1,168 +0,0 @@
# Adapted from Unsloth
# https://github.com/unslothai/unsloth/blob/4b97a810b509c93f44be4c037c7aa18fb8922884/unsloth/kernels/cross_entropy_loss.py
import triton
import triton.language as tl
import torch
MAX_FUSED_SIZE = 65536
def calculate_settings(n):
BLOCK_SIZE = triton.next_power_of_2(n)
# CUDA only supports 65536 - 2^16 threads per block
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
num_warps = 4
if BLOCK_SIZE >= 32768: num_warps = 32
elif BLOCK_SIZE >= 8192: num_warps = 16
elif BLOCK_SIZE >= 2048: num_warps = 8
return BLOCK_SIZE, num_warps
pass
@triton.jit
def _cross_entropy_forward(logits_ptr, logits_row_stride,
loss_ptr,
lse_ptr,
labels_ptr,
n_cols,
BLOCK_SIZE: tl.constexpr,):
"""
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
Pi = exp(xi) / sum(exp(xi))
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
= -y [ x - log[sum(exp(x))] ]
= y * (log[sum(exp(x))] - x)
If y == 0: CE_i = 0
If y == 1: CE_i = logsumexp - x
"""
row_idx = tl.program_id(0)
logits_ptr += row_idx * logits_row_stride
loss_ptr += row_idx
lse_ptr += row_idx
labels_ptr += row_idx
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
# TODO: Fixup int32 locations to int64
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
max_logits = tl.max(logits, 0)
# Maximum stops overflow
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
tl.store(lse_ptr, lse)
if label_idx != -100:
logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
loss = lse - logits_label
else:
loss = 0.0
tl.store(loss_ptr, loss)
pass
@triton.jit
def _cross_entropy_backward(logits_ptr, logits_row_stride,
dloss_ptr, dloss_row_stride,
lse_ptr,
labels_ptr,
n_cols,
BLOCK_SIZE: tl.constexpr,):
"""
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
From https://en.wikipedia.org/wiki/LogSumExp
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
If y == 0: dC/dx = 0
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
"""
row_idx = tl.program_id(0)
logits_ptr += row_idx * logits_row_stride
dloss_ptr += row_idx * dloss_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
# TODO: Fixup int32 locations to int64
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
if label_idx != -100:
dloss = tl.load(dloss_ptr)
else:
dloss = 0.0
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
lse = tl.load(lse_ptr + row_idx)
probs = tl.exp(logits - lse)
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)
class CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels):
n_rows, n_cols = logits.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
_cross_entropy_forward[(n_rows,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
n_cols,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(logits, logsumexp, labels)
return losses
pass
@staticmethod
def backward(ctx, dlosses):
logits, logsumexp, labels = ctx.saved_tensors
n_rows, n_cols = logits.shape
_cross_entropy_backward[(n_rows,)](
logits, logits.stride(0),
dlosses, dlosses.stride(0),
logsumexp,
labels,
n_cols,
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
return logits, None, None,
pass
pass
def fast_cross_entropy_loss(logits, labels):
"""
Arguments:
logits: (batch, seq_len, vocab_size)
labels: (batch, seq_len,)
Returns:
losses: float
"""
batch, seq_len, d = logits.shape
assert(labels.shape == (batch, seq_len))
loss = CrossEntropyLoss.apply(
logits.view(batch*seq_len, d),
labels.view(-1),
)
n_items = torch.count_nonzero(labels != -100)
return loss.sum() / n_items
pass

View File

@@ -13,20 +13,16 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
from transformers.models.mistral.modeling_mistral import (
MistralForCausalLM as OriginalMistralForCausalLM,
)
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
@@ -40,9 +36,6 @@ def replace_mistral_attn_with_flash_attn(
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
flashattn_forward
)
transformers.models.mistral.modeling_mistral.MistralForCausalLM.forward = (
mistral_causallm_forward
)
if packed:
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
MistralDecoderLayer
@@ -648,71 +641,3 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer):
outputs += (present_key_value,)
return outputs
def mistral_causallm_forward(
self: OriginalMistralForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device)
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
shift_labels = shift_labels.to(shift_logits.device)
# FAST CROSS ENTROPY
loss = fast_cross_entropy_loss(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -0,0 +1,22 @@
"""
Patches to support multipack for mixtral
"""
import transformers
def replace_mixtral_attn_with_multipack_flash_attn():
from .modeling_mixtral import (
MixtralMultipackFlashAttention2,
mixtral_decoder_layer_forward,
mixtral_model_forward,
)
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
mixtral_decoder_layer_forward
)
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
mixtral_model_forward
)
transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
"flash_attention_2"
] = MixtralMultipackFlashAttention2

View File

@@ -0,0 +1,379 @@
"""
Mixtral modeling for multipack
"""
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
import logging
import warnings
from typing import List, Optional, Tuple, Union
import torch
from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from transformers import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral.modeling_mixtral import (
MixtralFlashAttention2,
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
"""
Custom multipack implementation w flash attention 2
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._flash_attn_uses_top_left_mask = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
attn_output = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=self.attention_dropout,
softmax_scale=None,
causal=True,
)
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def mixtral_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_router_logits (`bool`, *optional*):
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
should not be returned during inference.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
if output_router_logits:
outputs += (router_logits,)
return outputs
def mixtral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MoeModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = (
attention_mask
if (attention_mask is not None and 0 in attention_mask)
else None
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
LOG.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
output_router_logits,
use_cache,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if output_router_logits:
all_router_logits += (layer_outputs[-1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_router_logits,
]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)

View File

@@ -81,8 +81,9 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sequence_len = 4096
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
self.tokenizer.add_special_tokens(
{"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")}
)
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
def tokenize_prompt(self, prompt):

View File

@@ -13,7 +13,7 @@ register_conv_template(
system_message="You are a helpful assistant.",
roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>\n",
sep="<|im_end|>",
)
)

View File

@@ -33,8 +33,8 @@ class AlpacaPrompter(Prompter):
Base class for alpaca prompters
"""
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
system_format: str = "{system}"
turn_format: str
turn_no_input_format: str

View File

@@ -82,7 +82,8 @@ def train(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
if hasattr(model, "config"):
model.config.use_cache = False
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
@@ -92,7 +93,8 @@ def train(
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
model.config.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:

View File

@@ -2,12 +2,16 @@
DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union
import numpy as np
import torch
import transformers
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
IGNORE_INDEX = -100
@dataclass
class DataCollatorForSeq2Seq:
@@ -146,3 +150,31 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
@dataclass
class MambaDataCollator:
"""
Collator for State Space Models (Mamba)
"""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[torch.LongTensor(instance[key]) for instance in instances]
for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return {
"input_ids": input_ids,
"labels": labels,
}

View File

@@ -77,6 +77,15 @@ def normalize_config(cfg):
else:
cfg.torch_dtype = torch.float32
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step
cfg.save_steps = save_steps
if cfg.evals_per_epoch:
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
if eval_steps < 1.0: # prevent evals on every step
cfg.eval_steps = eval_steps
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
if not cfg.base_model_config:
@@ -352,6 +361,27 @@ def validate_config(cfg):
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
"sharegpt_simple", "sharegpt"
)
if cfg.saves_per_epoch and cfg.save_steps:
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
)
if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
)
if (
cfg.evals_per_epoch
and cfg.evaluation_strategy
and cfg.evaluation_strategy != "steps"
):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."

View File

@@ -4,6 +4,7 @@ import math
import os
from typing import Optional, Tuple # noqa: F401
import addict
import bitsandbytes as bnb
import torch
import transformers
@@ -21,6 +22,7 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase,
)
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
@@ -52,9 +54,20 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True
model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
try:
model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
except ValueError as err:
if "mamba" in model_config_name:
return addict.Dict(
{
"model_type": "mamba",
}
)
raise err
if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)
@@ -92,6 +105,7 @@ def load_tokenizer(cfg):
"LlamaTokenizer",
"LlamaTokenizerFast",
"CodeLlamaTokenizer",
"CodeLlamaTokenizerFast",
]
and hasattr(tokenizer, "pad_token")
and not tokenizer.pad_token
@@ -124,6 +138,23 @@ def load_tokenizer(cfg):
tokenizer.add_special_tokens(
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
)
# If we add bos_token and eos_token, we need to update the post processor to
# handle them correctly.
# https://github.com/huggingface/transformers/pull/24132
bos_or_eos_in_special_tokens = (
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
)
if (
tokenizer.__class__.__name__
in (
"LlamaTokenizerFast",
"CodeLlamaTokenizerFast",
)
and bos_or_eos_in_special_tokens
):
tokenizer.update_post_processor()
if cfg.tokens:
tokenizer.add_tokens(
[
@@ -218,6 +249,18 @@ def load_model(
LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if (
cfg.model_config_type == "mixtral"
and cfg.flash_attention
and cfg.sample_packing
):
from axolotl.monkeypatch.mixtral import (
replace_mixtral_attn_with_multipack_flash_attn,
)
LOG.info("patching with flash attention")
replace_mixtral_attn_with_multipack_flash_attn()
if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope,
@@ -265,13 +308,22 @@ def load_model(
bnb_4bit_quant_type="nf4",
)
# sample packing uses custom FA2 patch
if cfg.flash_attention and not cfg.sample_packing:
if (
cfg.is_llama_derived_model
or cfg.is_falcon_derived_model
or cfg.is_mistral_derived_model
):
model_kwargs["use_flash_attention_2"] = True
if cfg.flash_attention:
if not cfg.sample_packing:
if (
cfg.is_llama_derived_model
or cfg.is_falcon_derived_model
or cfg.is_mistral_derived_model
or model_config.model_type == "mixtral"
):
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
if model_config.model_type == "mixtral":
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
@@ -333,6 +385,20 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)
elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
model_kwargs["dtype"] = model_kwargs["torch_dtype"]
model_kwargs["device"] = torch.cuda.current_device()
del model_kwargs["torch_dtype"]
del model_kwargs["device_map"]
del model_kwargs["max_memory"]
model = MambaLMHeadModel.from_pretrained(
base_model,
**model_kwargs,
)
elif model_type and not cfg.trust_remote_code:
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
@@ -392,13 +458,17 @@ def load_model(
if cfg.resize_token_embeddings_to_32x
else len(tokenizer)
)
if model.get_input_embeddings().num_embeddings < embeddings_len:
if (
hasattr(model, "get_input_embeddings")
and model.get_input_embeddings().num_embeddings < embeddings_len
):
model.resize_token_embeddings(embeddings_len)
else:
model.tie_weights()
if (
hasattr(model.config, "max_position_embeddings")
hasattr(model, "config")
and hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings
):
@@ -408,20 +478,22 @@ def load_model(
model.config.max_position_embeddings = cfg.sequence_len
if (
hasattr(model.config, "bos_token_id")
hasattr(model, "config")
and hasattr(model.config, "bos_token_id")
and model.config.bos_token_id
and model.config.bos_token_id != tokenizer.bos_token_id
):
model.config.bos_token_id = tokenizer.bos_token_id
if (
hasattr(model.config, "eos_token_id")
hasattr(model, "config")
and hasattr(model.config, "eos_token_id")
and model.config.eos_token_id
and model.config.eos_token_id != tokenizer.eos_token_id
):
model.config.eos_token_id = tokenizer.eos_token_id
if model.device.type == "cuda":
if hasattr(model, "device") and model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device)
# make sure these are fp32 per Ramesh et al. (2021)
@@ -480,7 +552,8 @@ def load_model(
requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0:
LOG.warning("there are no parameters that require gradient updates")
model.config.use_cache = False
if hasattr(model, "config"):
model.config.use_cache = False
if cfg.flash_optimum:
model = BetterTransformer.transform(model)

View File

@@ -131,8 +131,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
)
# Phi doesn't want the attention_mask feature when training
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
cfg.is_mistral_derived_model and cfg.flash_attention
if (
"CodeGenTokenizer" in tokenizer.__class__.__name__
or (cfg.is_mistral_derived_model and cfg.flash_attention)
or cfg.model_config_type == "mamba"
):
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
@@ -153,7 +155,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
if update:
cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens:
skip_estimates = cfg.model_config_type == "mamba"
if not skip_estimates and not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
@@ -167,7 +171,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
if update:
cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing:
if not skip_estimates and cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails

65
tests/e2e/test_mamba.py Normal file
View File

@@ -0,0 +1,65 @@
"""
E2E tests for lora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMistral(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_fft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "state-spaces/mamba-130m",
"model_type": "MambaLMHeadModel",
"tokenizer_type": "AutoTokenizer",
"tokenizer_config": "EleutherAI/gpt-neox-20b",
"flash_attention": False,
"sequence_len": 1024,
"load_in_8bit": False,
"val_set_size": 0.0,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"gradient_checkpointing": False,
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": None,
"save_safetensors": False,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

File diff suppressed because one or more lines are too long