Compare commits

..

1 Commit

Author SHA1 Message Date
mhenrichsen 9084879861 tinyllama 2023-11-16 13:36:01 +00:00
78 changed files with 317 additions and 2980 deletions


@@ -71,9 +71,8 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
pip3 uninstall -y transformers accelerate pip3 uninstall -y transformers accelerate
pip3 install -U -e .[flash-attn,mamba-ssm] pip3 install -U -e .[flash-attn]
pip3 install -r requirements-tests.txt pip3 install -r requirements-tests.txt
- name: Run e2e tests - name: Run e2e tests


@@ -8,9 +8,6 @@ ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*] [mypy-axolotl.monkeypatch.*]
ignore_errors = True ignore_errors = True
[mypy-axolotl.models.mixtral.*]
ignore_errors = True
[mypy-axolotl.models.phi.*] [mypy-axolotl.models.phi.*]
ignore_errors = True ignore_errors = True

README.md

@@ -36,9 +36,7 @@ Features:
- [Train](#train) - [Train](#train)
- [Inference](#inference) - [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base) - [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
- [Common Errors](#common-errors-) - [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Need Help?](#need-help-) - [Need Help?](#need-help-)
- [Badge](#badge-) - [Badge](#badge-)
- [Community Showcase](#community-showcase) - [Community Showcase](#community-showcase)
@@ -67,21 +65,18 @@ Features:
## Axolotl supports ## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn | | | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------| |----------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | | | | | | Pythia | ✅ | ✅ | ✅ | | | | |
| Mixtral-MoE | ✅ | ✅ | ✅ | | | | ❓ | | cerebras | ✅ | ✅ | ✅ | | | | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ | | btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | | | ❌ | ❌ | ❌ | ❓ | | mpt | ✅ | | | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ | | falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | | | ❌ | ❌ | | ❓ | | gpt-j | ✅ | | | ❌ | ❌ | | ❓ |
| falcon | ✅ | | ✅ | | | | | | XGen | ✅ | | ✅ | | | | |
| gpt-j | ✅ | ✅ | ✅ | | | ❓ | ❓ | | phi | ✅ | ✅ | ✅ | | | ❓ | ❓ |
| XGen | ✅ | ❓ | | ❓ | ❓ | ❓ | | | RWKV | ✅ | ❓ | | ❓ | ❓ | ❓ | |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
## Quickstart ⚡ ## Quickstart ⚡
@@ -90,19 +85,14 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
**Requirements**: Python >=3.9 and Pytorch >=2.0. **Requirements**: Python >=3.9 and Pytorch >=2.0.
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
### For developers
```bash ```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl cd axolotl
pip3 install packaging pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]' pip3 install -e '.[flash-attn,deepspeed]'
``` pip3 install -U git+https://github.com/huggingface/peft.git
### Usage
```bashtet
# finetune lora # finetune lora
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
@@ -249,17 +239,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"instruction": "...", "input": "...", "output": "..."} {"instruction": "...", "input": "...", "output": "..."}
``` ```
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt) - `sharegpt`: conversations where `from` is `human`/`gpt`
```json ```json
{"conversations": [{"from": "...", "value": "..."}]} {"conversations": [{"from": "...", "value": "..."}]}
``` ```
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
```yml
datasets:
- path: <your-path>
type: sharegpt
conversation: llama-2
```
- `completion`: raw corpus - `completion`: raw corpus
```json ```json
{"text": "..."} {"text": "..."}
@@ -511,7 +494,6 @@ is_falcon_derived_model:
is_llama_derived_model: is_llama_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default # Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model: is_mistral_derived_model:
is_qwen_derived_model:
# optional overrides to the base model configuration # optional overrides to the base model configuration
model_config: model_config:
@@ -556,8 +538,6 @@ datasets:
# Optional[str] fastchat conversation type, only used with type: sharegpt # Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Custom user prompt # Custom user prompt
- path: repo - path: repo
@@ -623,12 +603,6 @@ eval_sample_packing:
sample_packing_eff_est: sample_packing_eff_est:
total_num_tokens: total_num_tokens:
# Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory
device_map:
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
max_memory:
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
# If you already have a lora model trained that you want to load, put that here. # If you already have a lora model trained that you want to load, put that here.
@@ -676,8 +650,7 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
wandb_project: # Your wandb project name wandb_project: # Your wandb project name
wandb_entity: # A wandb Team name if using a Team wandb_entity: # A wandb Team name if using a Team
wandb_watch: wandb_watch:
wandb_name: # Set the name of your wandb run wandb_run_id: # Set the name of your wandb run
wandb_run_id: # Set the ID of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# Where to save the full-finetuned model to # Where to save the full-finetuned model to
@@ -695,16 +668,13 @@ gradient_accumulation_steps: 1
micro_batch_size: 2 micro_batch_size: 2
eval_batch_size: eval_batch_size:
num_epochs: 4 num_epochs: 4
warmup_steps: 100 # cannot use with warmup_ratio warmup_steps: 100
warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003 learning_rate: 0.00003
lr_quadratic_warmup: lr_quadratic_warmup:
logging_steps: logging_steps:
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
save_strategy: # Set to `no` to skip checkpoint saves save_strategy: # Set to `no` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch save_steps: # Leave empty to save at each epoch
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
save_total_limit: # Checkpoints saved at a time save_total_limit: # Checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that # Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed. # if both are set, num_epochs will not be guaranteed.
@@ -714,9 +684,6 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (require safetensors package) # Save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:
@@ -783,7 +750,7 @@ max_grad_norm:
# Augmentation techniques # Augmentation techniques
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# currently only supported on Llama and Mistral # currently only supported on Llama and Mistral
neftune_noise_alpha: noisy_embedding_alpha:
# Whether to bettertransformers # Whether to bettertransformers
flash_optimum: flash_optimum:
@@ -975,26 +942,10 @@ wandb_mode:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
``` ```
##### Special Tokens
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
```yml
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
tokens: # these are delimiters
- "<|im_start|>"
- "<|im_end|>"
```
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
### Inference ### Inference
Pass the appropriate flag to the train command: Pass the appropriate flag to the train command:
@@ -1047,10 +998,6 @@ Please reduce any below
- `gradient_accumulation_steps` - `gradient_accumulation_steps`
- `sequence_len` - `sequence_len`
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
Using adamw_bnb_8bit might also save you some memory.
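For illustration only, a config trimmed along those lines might look like the sketch below (the values are placeholders to adapt to your GPU, not recommendations):
```yml
micro_batch_size: 1            # smaller per-device batch
gradient_accumulation_steps: 1 # fewer accumulated steps
sequence_len: 2048             # shorter context, smaller activations
optimizer: adamw_bnb_8bit      # 8-bit optimizer states
gradient_checkpointing: true   # trade compute for memory
```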
> `failed (exitcode: -9)` > `failed (exitcode: -9)`
Usually means your system has run out of system memory. Usually means your system has run out of system memory.
@@ -1073,20 +1020,6 @@ It's safe to ignore it.
See the [NCCL](docs/nccl.md) guide. See the [NCCL](docs/nccl.md) guide.
### Tokenization Mismatch b/w Inference & Training
For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.
If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following:
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine-tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
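As an illustration of steps 1-4 (a sketch only, not part of the codebase; the model path, prepared-dataset path, and prompt below are placeholders):
```python
from datasets import load_from_disk
from transformers import AutoTokenizer

# the same tokenizer your training run used (placeholder path)
tokenizer = AutoTokenizer.from_pretrained("./your-base-model")

# step 1: one row materialized by `python -m axolotl.cli.preprocess your_config.yml --debug`
# (placeholder path; use the prepared-dataset directory that command prints)
train_ids = load_from_disk("last_run_prepared/your-dataset-dir")[0]["input_ids"]

# step 2: the ids your inference stack is about to feed the model (placeholder prompt)
infer_ids = tokenizer("### Instruction:\nhello\n\n### Response:\n")["input_ids"]

# steps 3/4: compare the decoded strings and the ids themselves; any drift in spaces,
# newlines, or special tokens is a tokenization mismatch
print(repr(tokenizer.decode(train_ids)))
print(repr(tokenizer.decode(infer_ids)))
print(train_ids[: len(infer_ids)] == infer_ids)
```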
## Need help? 🙋♂️ ## Need help? 🙋♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you


@@ -24,6 +24,16 @@
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",


@@ -28,6 +28,16 @@
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",


@@ -32,6 +32,16 @@
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",


@@ -1,39 +0,0 @@
{
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 0,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}


@@ -10,7 +10,7 @@ ARG PYTORCH_VERSION="2.0.1"
ENV PYTORCH_VERSION=$PYTORCH_VERSION ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev apt-get install -y vim curl
WORKDIR /workspace WORKDIR /workspace


@@ -4,7 +4,6 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets" ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub" ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub" ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh


@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: btlm-out output_dir: btlm-out
@@ -72,8 +72,8 @@ gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 32 warmup_steps: 32
evals_per_epoch: 4 eval_steps:
saves_per_epoch: 1 save_steps:
save_total_limit: save_total_limit:
debug: debug:


@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
batch_size: 4 batch_size: 4
@@ -49,8 +49,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1


@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./falcon-7b output_dir: ./falcon-7b
batch_size: 2 batch_size: 2
@@ -51,8 +51,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 40 warmup_steps: 40
evals_per_epoch: 4 eval_steps: 5
saves_per_epoch: 1 save_steps: 43
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
@@ -80,8 +80,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 5
saves_per_epoch: 1 save_steps: 10
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.000001 weight_decay: 0.000001


@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./falcon-7b output_dir: ./falcon-7b
batch_size: 2 batch_size: 2
@@ -51,8 +51,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 40 warmup_steps: 40
evals_per_epoch: 4 eval_steps: 5
saves_per_epoch: 1 save_steps: 43
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
@@ -46,8 +46,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1


@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./jeopardy-bot-7b output_dir: ./jeopardy-bot-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
@@ -42,8 +42,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 110
saves_per_epoch: 1 save_steps: 660
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1


@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
@@ -58,9 +58,9 @@ flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true flash_attn_fuse_mlp: true
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: #deepspeed/zero2.json # multi-gpu only deepspeed: #deepspeed/zero2.json # multi-gpu only
weight_decay: 0.1 weight_decay: 0.1


@@ -32,7 +32,7 @@ lora_target_linear:
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./model-out output_dir: ./model-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
@@ -62,8 +62,8 @@ flash_attention:
sdp_attention: sdp_attention:
flash_optimum: flash_optimum:
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 eval_steps:
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1


@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -54,10 +54,10 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -56,9 +56,9 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -35,7 +35,7 @@ relora_cpu_offload: false
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -60,8 +60,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps: 50
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -4,19 +4,20 @@ model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
load_in_8bit: true load_in_8bit: false
load_in_4bit: false load_in_4bit: false
strict: false strict: false
datasets: datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/context-aware-splits-english
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 200
output_dir: ./lora-out output_dir: ./tiny-llama
sequence_len: 4096 sequence_len: 8192
sample_packing: true sample_packing: true
pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
@@ -29,12 +30,12 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 1
micro_batch_size: 2 micro_batch_size: 8
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -53,13 +54,13 @@ logging_steps: 1
xformers_attention: xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 50
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
saves_per_epoch: 1 save_steps: 0.50
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.1
fsdp: fsdp:
fsdp_config: fsdp_config:
special_tokens: special_tokens:


@@ -1,61 +0,0 @@
base_model: state-spaces/mamba-2.8b
model_type: MambaLMHeadModel
tokenizer_type: AutoTokenizer
tokenizer_config: EleutherAI/gpt-neox-20b
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./out
sequence_len: 2048
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
tokens:
save_safetensors: False


@@ -21,7 +21,7 @@ pad_to_sequence_len: true
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -46,10 +46,10 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -1,88 +0,0 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./qlora-out
## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
# - lm_head.*
# - model.embed_tokens.*
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
#lora_target_modules:
# - gate
# - q_proj
# - k_proj
# - v_proj
# - o_proj
# - w1
# - w2
# - w3
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: deepspeed/zero2.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:


@@ -38,7 +38,7 @@ lora_target_modules:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -62,14 +62,11 @@ logging_steps: 1
xformers_attention: xformers_attention:
flash_attention: true flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b wandb_project: mpt-alpaca-7b
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./mpt-alpaca-7b output_dir: ./mpt-alpaca-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
@@ -44,8 +44,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 110
saves_per_epoch: 1 save_steps: 660
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0001 weight_decay: 0.0001


@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./openllama-out output_dir: ./openllama-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
@@ -49,8 +49,8 @@ flash_attention: true
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1


@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./lora-out output_dir: ./lora-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
@@ -54,8 +54,8 @@ flash_attention: true
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1


@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
@@ -48,8 +48,8 @@ flash_attention: true
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1


@@ -1,5 +1,5 @@
base_model: microsoft/phi-1_5 base_model: microsoft/phi-1_5
model_type: PhiForCausalLM model_type: MixFormerSequentialForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_llama_derived_model: false is_llama_derived_model: false
trust_remote_code: true trust_remote_code: true
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
@@ -59,8 +59,8 @@ xformers_attention:
flash_attention: flash_attention:
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1


@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
@@ -59,8 +59,8 @@ xformers_attention:
flash_attention: flash_attention:
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1


@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./pythia-12b output_dir: ./pythia-12b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1


@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./lora-alpaca-pythia output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
@@ -33,5 +33,5 @@ early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
weight_decay: 0.1 weight_decay: 0.1
evals_per_epoch: 4 eval_steps: 0.05
logging_steps: 1 logging_steps: 1


@@ -1,68 +0,0 @@
base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 2048 # supports up to 8192
sample_packing: false
pad_to_sequence_len:
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:


@@ -1,68 +0,0 @@
base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 2048 # supports up to 8192
sample_packing: false
pad_to_sequence_len:
adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:


@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b wandb_project: redpajama-alpaca-3b
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./redpajama-alpaca-3b output_dir: ./redpajama-alpaca-3b
batch_size: 4 batch_size: 4
@@ -45,8 +45,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 110
saves_per_epoch: 1 save_steps: 660
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0001 weight_decay: 0.0001


@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project: lora-replit wandb_project: lora-replit
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./lora-replit output_dir: ./lora-replit
batch_size: 8 batch_size: 8
@@ -45,8 +45,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 50
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0 weight_decay: 0


@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_name: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
@@ -78,8 +78,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 50
saves_per_epoch: 1 save_steps: 50
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0


@@ -1,21 +1,22 @@
--extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
auto-gptq==0.5.1 torch==2.0.1
auto-gptq==0.4.2
packaging packaging
peft==0.6.0 peft==0.6.0
transformers @ git+https://github.com/huggingface/transformers.git@ebfdb9ca62205279d5019ef1403877461b3b2da4 transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
tokenizers==0.15.0
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate==0.24.1 accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
deepspeed deepspeed
addict addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
datasets>=2.15.0 datasets>=2.14.0
flash-attn==2.3.3 flash-attn>=2.3.0
sentencepiece sentencepiece
wandb wandb
einops einops
xformers==0.0.22 xformers>=0.0.22
optimum==1.13.2 optimum==1.13.2
hf_transfer hf_transfer
colorama colorama
@@ -29,8 +30,8 @@ scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml
art art
fschat==0.2.34 fschat==0.2.29
gradio==3.50.2 gradio
tensorboard tensorboard
# remote filesystems # remote filesystems


@@ -46,13 +46,10 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn==2.3.3", "flash-attn>=2.3.0",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed", "deepspeed",
], ],
"mamba-ssm": [
"mamba-ssm==1.0.1",
],
}, },
) )


@@ -29,7 +29,6 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -72,7 +71,7 @@ def do_merge_lora(
LOG.info("running merge of LoRA with base model") LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload() model = model.merge_and_unload()
model.to(dtype=cfg.torch_dtype) model.to(dtype=torch.float16)
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}") LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
@@ -297,8 +296,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
validate_config(cfg) validate_config(cfg)
prepare_optim_env(cfg)
normalize_config(cfg) normalize_config(cfg)
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)


@@ -22,8 +22,8 @@ LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs)) parser = transformers.HfArgumentParser((TrainerCliArgs))


@@ -25,16 +25,12 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GPUStatsCallback, GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
log_prediction_callback_factory, log_prediction_callback_factory,
) )
from axolotl.utils.collators import ( from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
BatchSamplerDataCollatorForSeq2Seq,
MambaDataCollator,
)
from axolotl.utils.samplers import MultipackBatchSampler from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
@@ -52,9 +48,6 @@ class AxolotlTrainingArguments(TrainingArguments):
Extend the base TrainingArguments for axolotl helpers Extend the base TrainingArguments for axolotl helpers
""" """
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field( lr_quadratic_warmup: bool = field(
default=False, default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."}, metadata={"help": "Use quadratic warmup for cosine scheduling."},
@@ -291,32 +284,6 @@ class AxolotlTrainer(Trainer):
return super().compute_loss(model, inputs, return_outputs=return_outputs) return super().compute_loss(model, inputs, return_outputs=return_outputs)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""
def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return lm_loss
class OneCycleLRSchedulerTrainer(AxolotlTrainer): class OneCycleLRSchedulerTrainer(AxolotlTrainer):
""" """
Trainer subclass that uses the OneCycleLR scheduler Trainer subclass that uses the OneCycleLR scheduler
@@ -463,9 +430,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
) )
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
@@ -494,19 +458,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return OneCycleLRSchedulerTrainer return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps: if self.cfg.relora_steps:
return ReLoRATrainer return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
return AxolotlTrainer return AxolotlTrainer
def build(self, total_num_steps): def build(self, total_num_steps):
warmup_steps = None warmup_steps = (
if self.cfg.warmup_steps is not None: self.cfg.warmup_steps
warmup_steps = self.cfg.warmup_steps if self.cfg.warmup_steps is not None
elif self.cfg.warmup_ratio is not None: else min(int(0.03 * total_num_steps), 100)
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) )
else:
warmup_steps = min(int(0.03 * total_num_steps), 100)
logging_steps = ( logging_steps = (
self.cfg.logging_steps self.cfg.logging_steps
if self.cfg.logging_steps is not None if self.cfg.logging_steps is not None
@@ -563,7 +522,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.hub_strategy: if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors is not None: if self.cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.sample_packing_eff_est: if self.cfg.sample_packing_eff_est:
@@ -681,7 +640,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = ( training_arguments_kwargs["run_name"] = (
self.cfg.wandb_name if self.cfg.use_wandb else None self.cfg.wandb_run_id if self.cfg.use_wandb else None
) )
training_arguments_kwargs["optim"] = ( training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
@@ -692,9 +651,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep") and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine" else "cosine"
) )
training_arguments_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_arguments_kwargs["weight_decay"] = ( training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
) )
@@ -702,9 +658,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.sample_packing if self.cfg.sample_packing else False self.cfg.sample_packing if self.cfg.sample_packing else False
) )
training_arguments_kwargs["eval_sample_packing"] = ( training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing self.cfg.sample_packing if self.cfg.sample_packing else False
if self.cfg.eval_sample_packing is not False
else False
) )
training_arguments_kwargs[ training_arguments_kwargs[
"sample_packing_seq_len_multiplier" "sample_packing_seq_len_multiplier"
@@ -714,13 +668,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs = self.hook_pre_create_training_args( training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs training_arguments_kwargs
) )
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
"neftune_noise_alpha"
] = self.cfg.neftune_noise_alpha
training_args = ( training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs, **training_arguments_kwargs,
@@ -775,7 +722,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
train_dataset=self.train_dataset, train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset, eval_dataset=self.eval_dataset,
args=training_args, args=training_args,
data_collator=self.build_collator(**data_collator_kwargs), data_collator=BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq( bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer, self.tokenizer,
return_tensors="pt", return_tensors="pt",
@@ -795,13 +746,3 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.micro_batch_size ] = self.cfg.micro_batch_size
return trainer return trainer
def build_collator(self, **kwargs):
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)
return BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**kwargs,
)


@@ -1,12 +0,0 @@
"""
Modeling module for Mamba models
"""
def fix_mamba_attn_for_loss():
from mamba_ssm.models import mixer_seq_simple
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name


@@ -1,42 +0,0 @@
"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig
class MambaConfig(PretrainedConfig):
"""
modeling configuration for state space model/mamba
"""
model_type = "mamba"
def __init__(
self,
vocab_size=50280,
d_model=2560,
n_layer=64,
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=8,
pad_token_id=50277,
bos_token_id=0,
eos_token_id=0,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layer = n_layer
self.rms_norm = rms_norm
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.pad_vocab_size_multiple = pad_vocab_size_multiple
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)


@@ -1,128 +0,0 @@
# pylint: skip-file
import os
from collections import namedtuple
from functools import partial
from typing import Optional, Union
import torch
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from torch import nn
from torch.nn import CrossEntropyLoss
from axolotl.models.mamba.configuration_mamba import MambaConfig
class MambaLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
d_model: int,
n_layer: int,
vocab_size: int,
initializer_cfg=None,
pad_vocab_size_multiple: int = 1,
device=None,
dtype=None,
**backbone_kwargs,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (
vocab_size % pad_vocab_size_multiple
)
self.config = MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layer,
pad_vocab_size_multiple=pad_vocab_size_multiple,
)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
vocab_size=vocab_size,
initializer_cfg=initializer_cfg,
**backbone_kwargs,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
# Initialize weights and apply final processing
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()
def tie_weights(self):
self.lm_head.weight = self.backbone.embedding.weight
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(
batch_size, max_seqlen, dtype=dtype, **kwargs
)
def forward(
self,
input_ids,
position_ids=None,
inference_params=None,
num_last_tokens=0,
labels=None,
**kwargs,
):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, inference_params=inference_params)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
loss = None
if labels is not None:
logits = lm_logits
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
print(loss)
return CausalLMOutput(logits=lm_logits, loss=loss)
else:
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
):
if state_dict is None:
state_dict = self.state_dict()
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config = load_config_hf(pretrained_model_name)
model = cls(**config, device=device, dtype=dtype, **kwargs)
model.load_state_dict(
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
)
return model


@@ -3,6 +3,4 @@ MixFormers model architecture used for phi models
""" """
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
from .configuration_phi import PhiConfig # noqa
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
from .modeling_phi import PhiForCausalLM # noqa

View File

@@ -1,65 +0,0 @@
# pylint: skip-file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from typing import Optional
from transformers import PretrainedConfig
class PhiConfig(PretrainedConfig):
"""Phi configuration."""
model_type = "phi"
attribute_map = {
"max_position_embeddings": "n_positions",
"hidden_size": "n_embd",
"num_attention_heads": "n_head",
"num_hidden_layers": "n_layer",
}
def __init__(
self,
vocab_size: int = 50304,
n_positions: int = 2048,
n_embd: int = 1024,
n_layer: int = 20,
n_inner: Optional[int] = None,
n_head: int = 16,
n_head_kv: Optional[int] = None,
rotary_dim: Optional[int] = 32,
activation_function: Optional[str] = "gelu_new",
flash_attn: bool = False,
flash_rotary: bool = False,
fused_dense: bool = False,
attn_pdrop: float = 0.0,
embd_pdrop: float = 0.0,
resid_pdrop: float = 0.0,
layer_norm_epsilon: float = 1e-5,
initializer_range: float = 0.02,
tie_word_embeddings: bool = False,
pad_vocab_size_multiple: int = 64,
**kwargs
) -> None:
self.vocab_size = int(
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.n_head_kv = n_head_kv
self.rotary_dim = min(rotary_dim, n_embd // n_head)
self.activation_function = activation_function
self.flash_attn = flash_attn
self.flash_rotary = flash_rotary
self.fused_dense = fused_dense
self.attn_pdrop = attn_pdrop
self.embd_pdrop = embd_pdrop
self.resid_pdrop = resid_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
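A quick illustration of the vocabulary padding performed in `__init__` above (numbers are illustrative):

```python
# vocab_size is rounded up to the next multiple of pad_vocab_size_multiple
cfg = PhiConfig(vocab_size=50295, pad_vocab_size_multiple=64)
assert cfg.vocab_size == 50304  # ceil(50295 / 64) * 64
assert cfg.hidden_size == cfg.n_embd  # resolved through attribute_map
```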

File diff suppressed because it is too large

View File

@@ -83,21 +83,14 @@ def get_turns( # pylint: disable=too-many-return-statements
yield role + ":", "" yield role + ":", ""
return return
if self.sep_style == SeparatorStyle.LLAMA2: if self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message: if self.system_message:
if self.messages:
# For llama, the system message is incorporated into the first human instruction
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt += first_msg
self.messages.pop(0)
yield "", system_prompt yield "", system_prompt
for i, (role, message) in enumerate(self.messages): else:
yield "", "[INST] "
for i, (role, message) in enumerate(self.messages[1:]):
if message: if message:
if (i % 2 == 0 and not self.system_message) or ( yield role + " ", message + seps[i % 2]
i % 2 != 0 and self.system_message
):
role = "<s> " + role
yield role + " ", message
else: else:
yield role, "" yield role, ""
return return

View File

@@ -1,22 +0,0 @@
"""
Patches to support multipack for mixtral
"""
import transformers
def replace_mixtral_attn_with_multipack_flash_attn():
from .modeling_mixtral import (
MixtralMultipackFlashAttention2,
mixtral_decoder_layer_forward,
mixtral_model_forward,
)
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
mixtral_decoder_layer_forward
)
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
mixtral_model_forward
)
transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
"flash_attention_2"
] = MixtralMultipackFlashAttention2

View File

@@ -1,379 +0,0 @@
"""
Mixtral modeling for multipack
"""
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
import logging
import warnings
from typing import List, Optional, Tuple, Union
import torch
from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from transformers import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral.modeling_mixtral import (
MixtralFlashAttention2,
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
"""
Custom multipack implementation w flash attention 2
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._flash_attn_uses_top_left_mask = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
attn_output = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=self.attention_dropout,
softmax_scale=None,
causal=True,
)
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def mixtral_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_router_logits (`bool`, *optional*):
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
should not be returned during inference.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
if output_router_logits:
outputs += (router_logits,)
return outputs
def mixtral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MoeModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = (
attention_mask
if (attention_mask is not None and 0 in attention_mask)
else None
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
LOG.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
output_router_logits,
use_cache,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if output_router_logits:
all_router_logits += (layer_outputs[-1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_router_logits,
]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
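The packing bookkeeping above hinges on `get_cu_seqlens_from_pos_ids`: position ids restart at 0 for every packed sample, and the helper recovers the cumulative sequence-length boundaries that `flash_attn_varlen_qkvpacked_func` expects. A small hedged illustration (the expected values are an assumption about the helper's output, shown only as comments):

```python
import torch
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

# one packed row containing two samples of lengths 3 and 2
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
# expected: boundaries at 0, 3 and 5 (cu_seqlens ~ [0, 3, 5]) with max_seqlen == 3,
# matching the .squeeze() applied in mixtral_model_forward above
```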

View File

@@ -0,0 +1,65 @@
"""
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
from peft import PeftModel
from transformers import PreTrainedModel
def patch_neft(alpha, model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
embeddings.noisy_embedding_alpha = alpha
old_forward = embeddings.forward
# This hack seems to be needed to properly use a custom forward pass
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
embeddings, embeddings.__class__
)
setattr(embeddings, "forward", bound_method)
embeddings._old_forward = old_forward # pylint: disable=protected-access
return model
def unpatch_neft(model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
if hasattr(embeddings, "_old_forward"):
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
del embeddings._old_forward # pylint: disable=protected-access
del embeddings.noisy_embedding_alpha
def neft_forward(self, inputs: torch.Tensor):
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
if self.training:
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
-mag_norm, mag_norm
)
return embeddings
def pretrain_hook(cfg, trainer):
if cfg.noisy_embedding_alpha:
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
def post_train_hook(cfg, trainer):
if cfg.noisy_embedding_alpha:
unpatch_neft(trainer.model)
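A self-contained sketch of the patch/unpatch cycle above, outside the trainer hooks; the tiny checkpoint is purely illustrative and `alpha=5.0` follows the smallest setting in the NEFTune paper:

```python
import torch
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.neft_embeddings import patch_neft, unpatch_neft  # module path per the train.py import

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # illustrative checkpoint
model.train()
model = patch_neft(5.0, model)  # uniform noise scaled by alpha / sqrt(seq_len * hidden_dim)
noisy = model.get_input_embeddings()(torch.tensor([[1, 2, 3]]))  # noise is only added in training mode
unpatch_neft(model)  # restores the original embedding forward
```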

View File

@@ -81,9 +81,8 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.tokenizer.add_special_tokens( self.sequence_len = 4096
{"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")} self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
)
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):

View File

@@ -13,7 +13,7 @@ register_conv_template(
system_message="You are a helpful assistant.", system_message="You are a helpful assistant.",
roles=["<|im_start|>user", "<|im_start|>assistant"], roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML, sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>", sep="<|im_end|>\n",
) )
) )

View File

@@ -33,8 +33,8 @@ class AlpacaPrompter(Prompter):
Base class for alpaca prompters Base class for alpaca prompters
""" """
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request." system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
system_format: str = "{system}" system_format: str = "{system}"
turn_format: str turn_format: str
turn_no_input_format: str turn_no_input_format: str

View File

@@ -16,8 +16,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
@@ -78,15 +78,11 @@ def train(
) )
resume_from_checkpoint = cfg.resume_from_checkpoint resume_from_checkpoint = cfg.resume_from_checkpoint
if cfg.unfrozen_parameters:
freeze_parameters_except(model, cfg.unfrozen_parameters)
trainer = setup_trainer( trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
) )
if hasattr(model, "config"): model.config.use_cache = False
model.config.use_cache = False
# go ahead and presave, so we have the adapter config available to inspect # go ahead and presave, so we have the adapter config available to inspect
if peft_config: if peft_config:
@@ -96,8 +92,7 @@ def train(
if not Path(cfg.output_dir).is_dir(): if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True) os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir))) tokenizer.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"): model.config.save_pretrained(str(Path(cfg.output_dir)))
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0: if cfg.local_rank == 0:
@@ -179,19 +174,21 @@ def train(
return model, tokenizer return model, tokenizer
def pretrain_hooks(_cfg, _trainer): def pretrain_hooks(cfg, trainer):
""" """
Run hooks right before kicking off the training Run hooks right before kicking off the training
:param cfg: :param cfg:
:param trainer: :param trainer:
:return: :return:
""" """
neft_embeddings.pretrain_hook(cfg, trainer)
def post_train_hooks(_cfg, _trainer): def post_train_hooks(cfg, trainer):
""" """
Run hooks right after training completes Run hooks right after training completes
:param cfg: :param cfg:
:param trainer: :param trainer:
:return: :return:
""" """
neft_embeddings.post_train_hook(cfg, trainer)

View File

@@ -124,36 +124,6 @@ class GPUStatsCallback(
return control return control
class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3
def on_step_end(
self,
_args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
if self.violations >= self.patience:
LOG.warning(
"Loss is too high, stopping training (loss_watchdog_threshold)"
)
control.should_training_stop = True
else:
self.violations = 0
return control
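For reference, the two config keys the watchdog reads, written in the `DictDefault` style the tests use (values are illustrative):

```python
from axolotl.utils.dict import DictDefault

# assuming: from axolotl.utils.callbacks import LossWatchDogCallback
cfg = DictDefault(
    {
        "loss_watchdog_threshold": 2.0,  # stop once the logged loss exceeds this value...
        "loss_watchdog_patience": 3,     # ...for this many logged steps (defaults to 3)
    }
)
callback = LossWatchDogCallback(cfg)  # then appended to the trainer's callbacks
```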
def bench_eval_callback_factory(trainer, tokenizer): def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy") accuracy = evaluate.load("accuracy")
abcd_idx = [ abcd_idx = [

View File

@@ -2,16 +2,12 @@
DataCollator for axolotl to pad labels and position_ids for packed sequences DataCollator for axolotl to pad labels and position_ids for packed sequences
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Sequence, Union from typing import Any, Optional, Union
import numpy as np import numpy as np
import torch
import transformers
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy from transformers.utils import PaddingStrategy
IGNORE_INDEX = -100
@dataclass @dataclass
class DataCollatorForSeq2Seq: class DataCollatorForSeq2Seq:
@@ -150,31 +146,3 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
chunked_data[feature] = np.concatenate(arrays) chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data] features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors) return super().__call__(features, return_tensors=return_tensors)
@dataclass
class MambaDataCollator:
"""
Collator for State Space Models (Mamba)
"""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[torch.LongTensor(instance[key]) for instance in instances]
for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return {
"input_ids": input_ids,
"labels": labels,
}
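A hedged illustration of the collator above; the tokenizer is the `EleutherAI/gpt-neox-20b` one used by the Mamba e2e test further down, and pointing its pad token at EOS is an assumption made for the example:

```python
from transformers import AutoTokenizer

# assuming: from axolotl.utils.collators import MambaDataCollator
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token_id = tokenizer.eos_token_id  # assumption: checkpoint ships no pad token

collator = MambaDataCollator(tokenizer=tokenizer)
batch = collator(
    [
        {"input_ids": [5, 6, 7], "labels": [5, 6, 7]},
        {"input_ids": [8, 9], "labels": [8, 9]},
    ]
)
# input_ids are right-padded with pad_token_id, labels with IGNORE_INDEX (-100),
# so padded positions do not contribute to the loss
assert batch["input_ids"].shape == (2, 3)
assert batch["labels"][1, -1] == -100
```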

View File

@@ -27,7 +27,7 @@ def choose_device(cfg):
cfg.device = get_device() cfg.device = get_device()
if cfg.world_size == 1: if cfg.world_size == 1:
cfg.device_map = cfg.device_map or "auto" cfg.device_map = "auto"
else: else:
if cfg.device.startswith("cuda"): if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()} cfg.device_map = {"": torch.cuda.current_device()}
@@ -77,15 +77,6 @@ def normalize_config(cfg):
else: else:
cfg.torch_dtype = torch.float32 cfg.torch_dtype = torch.float32
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step
cfg.save_steps = save_steps
if cfg.evals_per_epoch:
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
if eval_steps < 1.0: # prevent evals on every step
cfg.eval_steps = eval_steps
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count() cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
if not cfg.base_model_config: if not cfg.base_model_config:
@@ -131,19 +122,6 @@ def normalize_config(cfg):
or (cfg.model_type and "mistral" in cfg.model_type.lower()) or (cfg.model_type and "mistral" in cfg.model_type.lower())
) )
cfg.is_qwen_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"qwen",
]
)
or cfg.is_qwen_derived_model
or "qwen" in cfg.base_model.lower()
or (cfg.model_type and "qwen" in cfg.model_type.lower())
)
if isinstance(cfg.learning_rate, str): if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate) cfg.learning_rate = float(cfg.learning_rate)
@@ -187,11 +165,7 @@ def validate_config(cfg):
"batch_size is not recommended. Please use gradient_accumulation_steps instead.", "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.", "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
) )
if ( if cfg.eval_batch_size != cfg.micro_batch_size:
cfg.eval_batch_size
and cfg.micro_batch_size
and cfg.eval_batch_size != cfg.micro_batch_size
):
LOG.warning( LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability." "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
) )
@@ -361,27 +335,6 @@ def validate_config(cfg):
cfg.datasets[idx].type = cfg.datasets[idx].type.replace( cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
"sharegpt_simple", "sharegpt" "sharegpt_simple", "sharegpt"
) )
if cfg.saves_per_epoch and cfg.save_steps:
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
)
if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
)
if (
cfg.evals_per_epoch
and cfg.evaluation_strategy
and cfg.evaluation_strategy != "steps"
):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps": if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError( raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
@@ -419,35 +372,6 @@ def validate_config(cfg):
if cfg.rope_scaling: if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be be a key under `model_config`") LOG.warning("`rope_scaling` should now be be a key under `model_config`")
if cfg.warmup_steps and cfg.warmup_ratio:
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
if cfg.is_qwen_derived_model and cfg.gradient_checkpointing:
LOG.warning(
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
)
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
if cfg.noisy_embedding_alpha is not None:
# Deprecated, use neftune_noise_alpha
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
if cfg.neftune_noise_alpha is None:
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
else:
# User is providing both; bail and have them sort out their settings
raise ValueError(
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
)
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
raise ValueError("neftune_noise_alpha must be > 0.0")
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -79,14 +79,6 @@ def prepare_dataset(cfg, tokenizer):
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset, tokenizer cfg, train_dataset, eval_dataset, tokenizer
) )
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
raise ValueError(
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
)
if cfg.max_steps: if cfg.max_steps:
total_num_steps = min( total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
@@ -242,14 +234,7 @@ def load_tokenized_prepared_datasets(
local_path = Path(config_dataset.path) local_path = Path(config_dataset.path)
if local_path.exists(): if local_path.exists():
if local_path.is_dir(): if local_path.is_dir():
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk` ds = load_from_disk(config_dataset.path)
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
)
elif local_path.is_file(): elif local_path.is_file():
ds_type = get_ds_type(config_dataset) ds_type = get_ds_type(config_dataset)

View File

@@ -1,38 +0,0 @@
"""
module to freeze/unfreeze parameters by name
"""
import logging
import re
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.utils.freeze")
def freeze_parameters_except(model, regex_patterns):
"""
Freezes all layers of the given model except for the layers that match given regex patterns.
Periods in the patterns are treated as literal periods, not as wildcard characters.
Parameters:
- model (nn.Module): The PyTorch model to be modified.
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
Returns:
None; the model is modified in place.
"""
# Escape periods and compile the regex patterns
compiled_patterns = [
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
]
# First, freeze all parameters in the model
for param in model.parameters():
param.requires_grad = False
# Unfreeze layers that match the regex patterns
for name, param in model.named_parameters():
if any(pattern.match(name) for pattern in compiled_patterns):
if is_main_process():
LOG.debug(f"unfreezing {name}")
param.requires_grad = True
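A short sketch of how the helper above pairs with the `unfrozen_parameters` list that train.py passes through (checkpoint and patterns are illustrative):

```python
from transformers import AutoModelForCausalLM

from axolotl.utils.freeze import freeze_parameters_except

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # illustrative checkpoint
# keep only the final layer norm trainable; everything else is frozen
freeze_parameters_except(model, ["transformer.ln_f."])
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
# expected: ['transformer.ln_f.weight', 'transformer.ln_f.bias']
```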

View File

@@ -4,7 +4,6 @@ import math
import os import os
from typing import Optional, Tuple # noqa: F401 from typing import Optional, Tuple # noqa: F401
import addict
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
@@ -21,9 +20,7 @@ from transformers import ( # noqa: F401
PreTrainedModel, PreTrainedModel,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from transformers.deepspeed import is_deepspeed_zero3_enabled
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -31,50 +28,16 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
quant_config_exists = hasattr(model_config, "quantization_config")
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config
and model_config.quantization_config["quant_method"] == "gptq"
)
if cfg.gptq and not quant_config_method_is_gptq:
raise ValueError(
"model_config.quantization_config is not set or quant_method is not set to gptq. "
"Please make sure to point to a GPTQ model."
)
if not cfg.gptq and quant_config_exists:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
)
def load_model_config(cfg): def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True trust_remote_code = cfg.trust_remote_code is True
model_config = AutoConfig.from_pretrained(
try: model_config_name, trust_remote_code=trust_remote_code
model_config = AutoConfig.from_pretrained( )
model_config_name, trust_remote_code=trust_remote_code
)
except ValueError as err:
if "mamba" in model_config_name:
return addict.Dict(
{
"model_type": "mamba",
}
)
raise err
if cfg.model_config: if cfg.model_config:
for key, val in cfg.model_config.items(): for key, val in cfg.model_config.items():
setattr(model_config, key, val) setattr(model_config, key, val)
check_model_config(cfg, model_config)
return model_config return model_config
@@ -106,7 +69,6 @@ def load_tokenizer(cfg):
"LlamaTokenizer", "LlamaTokenizer",
"LlamaTokenizerFast", "LlamaTokenizerFast",
"CodeLlamaTokenizer", "CodeLlamaTokenizer",
"CodeLlamaTokenizerFast",
] ]
and hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "pad_token")
and not tokenizer.pad_token and not tokenizer.pad_token
@@ -122,40 +84,11 @@ def load_tokenizer(cfg):
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing: if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left" tokenizer.padding_side = "left"
# Qwen base only has single token, so we need to set the special tokens
if cfg.is_qwen_derived_model:
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
for attr_name in token_ids:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, tokenizer.eod_id)
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
for attr_name in token_names:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
if cfg.special_tokens: if cfg.special_tokens:
for k, val in cfg.special_tokens.items(): for k, val in cfg.special_tokens.items():
tokenizer.add_special_tokens( tokenizer.add_special_tokens(
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)} {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
) )
# If we add bos_token and eos_token, we need to update the post processor to
# handle them correctly.
# https://github.com/huggingface/transformers/pull/24132
bos_or_eos_in_special_tokens = (
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
)
if (
tokenizer.__class__.__name__
in (
"LlamaTokenizerFast",
"CodeLlamaTokenizerFast",
)
and bos_or_eos_in_special_tokens
):
tokenizer.update_post_processor()
if cfg.tokens: if cfg.tokens:
tokenizer.add_tokens( tokenizer.add_tokens(
[ [
@@ -250,18 +183,6 @@ def load_model(
LOG.info("patching with flash attention") LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if (
cfg.model_config_type == "mixtral"
and cfg.flash_attention
and cfg.sample_packing
):
from axolotl.monkeypatch.mixtral import (
replace_mixtral_attn_with_multipack_flash_attn,
)
LOG.info("patching with flash attention")
replace_mixtral_attn_with_multipack_flash_attn()
if cfg.is_llama_derived_model and cfg.xpos_rope: if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import ( from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope, replace_llama_rope_with_xpos_rope,
@@ -283,12 +204,8 @@ def load_model(
model_kwargs = {} model_kwargs = {}
model_kwargs["device_map"] = cfg.device_map model_kwargs["device_map"] = cfg.device_map
model_kwargs["max_memory"] = cfg.max_memory
model_kwargs["torch_dtype"] = cfg.torch_dtype model_kwargs["torch_dtype"] = cfg.torch_dtype
if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"]
if cfg.model_revision: if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision model_kwargs["revision"] = cfg.model_revision
if cfg.gptq: if cfg.gptq:
@@ -312,26 +229,13 @@ def load_model(
bnb_4bit_quant_type="nf4", bnb_4bit_quant_type="nf4",
) )
# sample packing uses custom FA2 patch # sample packing uses custom FA2 patch
if cfg.flash_attention: if cfg.flash_attention and not cfg.sample_packing:
if not cfg.sample_packing: if (
if ( cfg.is_llama_derived_model
cfg.is_llama_derived_model or cfg.is_falcon_derived_model
or cfg.is_falcon_derived_model or cfg.is_mistral_derived_model
or cfg.is_mistral_derived_model ):
or model_config.model_type == "mixtral" model_kwargs["use_flash_attention_2"] = True
):
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
if model_config.model_type == "mixtral":
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
try: try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
@@ -384,29 +288,15 @@ def load_model(
# device=cfg.device, # device=cfg.device,
# ) # )
# model.train() # sets to train instead of eval mode # model.train() # sets to train instead of eval mode
elif model_type == "PhiForCausalLM": elif model_type == "MixFormerSequentialForCausalLM":
from axolotl.models.phi import PhiForCausalLM from axolotl.models.phi import MixFormerSequentialForCausalLM
model = PhiForCausalLM.from_pretrained( model = MixFormerSequentialForCausalLM.from_pretrained(
base_model, base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs, **model_kwargs,
) )
elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
model_kwargs["dtype"] = model_kwargs["torch_dtype"]
model_kwargs["device"] = torch.cuda.current_device()
del model_kwargs["torch_dtype"]
del model_kwargs["device_map"]
del model_kwargs["max_memory"]
model = MambaLMHeadModel.from_pretrained(
base_model,
**model_kwargs,
)
elif model_type and not cfg.trust_remote_code: elif model_type and not cfg.trust_remote_code:
if cfg.gptq: if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
@@ -466,17 +356,13 @@ def load_model(
if cfg.resize_token_embeddings_to_32x if cfg.resize_token_embeddings_to_32x
else len(tokenizer) else len(tokenizer)
) )
if ( if model.get_input_embeddings().num_embeddings < embeddings_len:
hasattr(model, "get_input_embeddings")
and model.get_input_embeddings().num_embeddings < embeddings_len
):
model.resize_token_embeddings(embeddings_len) model.resize_token_embeddings(embeddings_len)
else: else:
model.tie_weights() model.tie_weights()
if ( if (
hasattr(model, "config") hasattr(model.config, "max_position_embeddings")
and hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings and model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings and cfg.sequence_len > model.config.max_position_embeddings
): ):
@@ -486,22 +372,20 @@ def load_model(
model.config.max_position_embeddings = cfg.sequence_len model.config.max_position_embeddings = cfg.sequence_len
if ( if (
hasattr(model, "config") hasattr(model.config, "bos_token_id")
and hasattr(model.config, "bos_token_id")
and model.config.bos_token_id and model.config.bos_token_id
and model.config.bos_token_id != tokenizer.bos_token_id and model.config.bos_token_id != tokenizer.bos_token_id
): ):
model.config.bos_token_id = tokenizer.bos_token_id model.config.bos_token_id = tokenizer.bos_token_id
if ( if (
hasattr(model, "config") hasattr(model.config, "eos_token_id")
and hasattr(model.config, "eos_token_id")
and model.config.eos_token_id and model.config.eos_token_id
and model.config.eos_token_id != tokenizer.eos_token_id and model.config.eos_token_id != tokenizer.eos_token_id
): ):
model.config.eos_token_id = tokenizer.eos_token_id model.config.eos_token_id = tokenizer.eos_token_id
if hasattr(model, "device") and model.device.type == "cuda": if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device) log_gpu_memory_usage(LOG, "after model load", model.device)
# make sure these are fp32 per Ramesh et al. (2021) # make sure these are fp32 per Ramesh et al. (2021)
@@ -516,22 +400,15 @@ def load_model(
module.to(torch.float32) module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
skip_prepare_model_for_kbit_training = True
if (cfg.adapter == "lora" and load_in_8bit) or ( if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit cfg.adapter == "qlora" and cfg.load_in_4bit
): ):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
if not skip_prepare_model_for_kbit_training: model = prepare_model_for_kbit_training(
model = prepare_model_for_kbit_training( model, use_gradient_checkpointing=cfg.gradient_checkpointing
model, use_gradient_checkpointing=cfg.gradient_checkpointing )
)
needs_fa2_dtype = True needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
@@ -560,8 +437,7 @@ def load_model(
requires_grad.append(f"{name}: {param.requires_grad}") requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0: if len(requires_grad) == 0:
LOG.warning("there are no parameters that require gradient updates") LOG.warning("there are no parameters that require gradient updates")
if hasattr(model, "config"): model.config.use_cache = False
model.config.use_cache = False
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.transform(model) model = BetterTransformer.transform(model)

View File

@@ -182,7 +182,7 @@ class MultipackBatchSampler(BatchSampler):
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return max( return max(
0, 1,
( (
world_size world_size
* math.floor( * math.floor(

View File

@@ -131,10 +131,8 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
) )
# Phi doesn't want the attention_mask feature when training # Phi doesn't want the attention_mask feature when training
if ( if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
"CodeGenTokenizer" in tokenizer.__class__.__name__ cfg.is_mistral_derived_model and cfg.flash_attention
or (cfg.is_mistral_derived_model and cfg.flash_attention)
or cfg.model_config_type == "mamba"
): ):
train_dataset = train_dataset.remove_columns("attention_mask") train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset: if eval_dataset:
@@ -143,7 +141,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
return train_dataset, eval_dataset return train_dataset, eval_dataset
def calculate_total_num_steps(cfg, train_dataset, update=True): def calculate_total_num_steps(cfg, train_dataset):
if not cfg.total_num_tokens: if not cfg.total_num_tokens:
total_num_tokens = np.sum( total_num_tokens = np.sum(
train_dataset.data.column("input_ids") train_dataset.data.column("input_ids")
@@ -152,12 +150,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.values .values
) )
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True) LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
if update: cfg.total_num_tokens = total_num_tokens
cfg.total_num_tokens = total_num_tokens
skip_estimates = cfg.model_config_type == "mamba" if not cfg.total_supervised_tokens:
if not skip_estimates and not cfg.total_supervised_tokens:
total_supervised_tokens = ( total_supervised_tokens = (
train_dataset.data.column("labels") train_dataset.data.column("labels")
.to_pandas() .to_pandas()
@@ -168,10 +163,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
f"`total_supervised_tokens: {total_supervised_tokens}`", f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True, main_process_only=True,
) )
if update: cfg.total_supervised_tokens = total_supervised_tokens
cfg.total_supervised_tokens = total_supervised_tokens
if not skip_estimates and cfg.sample_packing: if cfg.sample_packing:
# we have to drop anything longer than sequence len otherwise # we have to drop anything longer than sequence len otherwise
# flash attention with position ids fails # flash attention with position ids fails
@@ -238,8 +232,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
sample_packing_eff_est = ( sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0 math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
) )
if update: cfg.sample_packing_eff_est = sample_packing_eff_est
cfg.sample_packing_eff_est = sample_packing_eff_est
LOG.debug( LOG.debug(
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}", f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
main_process_only=True, main_process_only=True,
@@ -271,15 +264,12 @@ def setup_fsdp_envs(cfg):
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.fsdp: if cfg.fsdp:
setup_fsdp_envs(cfg) setup_fsdp_envs(cfg)
elif cfg.deepspeed: elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
trainer_builder.train_dataset = train_dataset trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset trainer_builder.eval_dataset = eval_dataset

View File

@@ -2,20 +2,20 @@
import os import os
from axolotl.utils.dict import DictDefault
def setup_wandb_env_vars(cfg):
def setup_wandb_env_vars(cfg: DictDefault): if cfg.wandb_mode and cfg.wandb_mode == "offline":
for key in cfg.keys(): os.environ["WANDB_MODE"] = cfg.wandb_mode
if key.startswith("wandb_"): elif cfg.wandb_project and len(cfg.wandb_project) > 0:
value = cfg.get(key, "") os.environ["WANDB_PROJECT"] = cfg.wandb_project
if value and isinstance(value, str) and len(value) > 0:
os.environ[key.upper()] = value
# Enable wandb if project name is present
if cfg.wandb_project and len(cfg.wandb_project) > 0:
cfg.use_wandb = True cfg.use_wandb = True
os.environ.pop("WANDB_DISABLED", None) # Remove if present if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
else: else:
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -1,65 +0,0 @@
"""
E2E tests for Mamba
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMamba(unittest.TestCase):
"""
Test case for Mamba full fine-tuning
"""
@with_temp_dir
def test_fft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "state-spaces/mamba-130m",
"model_type": "MambaLMHeadModel",
"tokenizer_type": "AutoTokenizer",
"tokenizer_config": "EleutherAI/gpt-neox-20b",
"flash_attention": False,
"sequence_len": 1024,
"load_in_8bit": False,
"val_set_size": 0.0,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"gradient_checkpointing": False,
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": None,
"save_safetensors": False,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -31,7 +31,7 @@ class TestPhi(unittest.TestCase):
{ {
"base_model": "microsoft/phi-1_5", "base_model": "microsoft/phi-1_5",
"trust_remote_code": True, "trust_remote_code": True,
"model_type": "PhiForCausalLM", "model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer", "tokenizer_type": "AutoTokenizer",
"sequence_len": 512, "sequence_len": 512,
"sample_packing": False, "sample_packing": False,
@@ -76,7 +76,7 @@ class TestPhi(unittest.TestCase):
{ {
"base_model": "microsoft/phi-1_5", "base_model": "microsoft/phi-1_5",
"trust_remote_code": True, "trust_remote_code": True,
"model_type": "PhiForCausalLM", "model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer", "tokenizer_type": "AutoTokenizer",
"sequence_len": 512, "sequence_len": 512,
"sample_packing": True, "sample_packing": True,

File diff suppressed because one or more lines are too long

View File

@@ -114,76 +114,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
in self._caplog.records[0].message in self._caplog.records[0].message
) )
def test_sharegpt_llama(self):
"Make sure the sharegpt/llama is tokenized and formatted correctly."
prompter = ShareGPTPrompterV2(conversation="llama-2")
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
def tokenize(conv):
return strat.tokenize_prompt(conv)["input_ids"]
def decode(ids):
return strat.tokenizer.decode(ids)
# Multi-turn conversations
multi_turn_conv = {
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "123"},
{"from": "gpt", "value": "sit"},
]
}
# fmt: off
mt_ids = tokenize(multi_turn_conv)
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
# Single-turn conversations
single_turn_conv = {
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
]
}
st_ids = tokenize(single_turn_conv)
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
# No system message, single-turn
no_sys_conv = {
"conversations": [
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
]
}
ns_ids = tokenize(no_sys_conv)
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
# No system message, multi-turn
no_sys_mt_conv = {
"conversations": [
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "123"},
{"from": "gpt", "value": "sit"},
]
}
ns_mt_ids = tokenize(no_sys_mt_conv)
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
# fmt: on
def test_sharegpt_changes_roles(self): def test_sharegpt_changes_roles(self):
conversation = { conversation = {
"roles": ["USER", "CHARACTER"], "roles": ["USER", "CHARACTER"],

View File

@@ -1,7 +1,6 @@
"""Module for testing the validation module""" """Module for testing the validation module"""
import logging import logging
import os
import unittest import unittest
from typing import Optional from typing import Optional
@@ -9,7 +8,6 @@ import pytest
from axolotl.utils.config import validate_config from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars
class ValidationTest(unittest.TestCase): class ValidationTest(unittest.TestCase):
@@ -651,113 +649,3 @@ class ValidationTest(unittest.TestCase):
) )
validate_config(cfg) validate_config(cfg)
def test_warmup_step_no_conflict(self):
cfg = DictDefault(
{
"warmup_steps": 10,
"warmup_ratio": 0.1,
}
)
with pytest.raises(
ValueError,
match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
):
validate_config(cfg)
cfg = DictDefault(
{
"warmup_steps": 10,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"warmup_ratio": 0.1,
}
)
validate_config(cfg)
class ValidationWandbTest(ValidationTest):
"""
Validation test for wandb
"""
def test_wandb_set_run_id_to_name(self):
cfg = DictDefault(
{
"wandb_run_id": "foo",
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
in record.message
for record in self._caplog.records
)
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
cfg = DictDefault(
{
"wandb_name": "foo",
}
)
validate_config(cfg)
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
def test_wandb_sets_env(self):
cfg = DictDefault(
{
"wandb_project": "foo",
"wandb_name": "bar",
"wandb_run_id": "bat",
"wandb_entity": "baz",
"wandb_mode": "online",
"wandb_watch": "false",
"wandb_log_model": "checkpoint",
}
)
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_PROJECT", "") == "foo"
assert os.environ.get("WANDB_NAME", "") == "bar"
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
assert os.environ.get("WANDB_ENTITY", "") == "baz"
assert os.environ.get("WANDB_MODE", "") == "online"
assert os.environ.get("WANDB_WATCH", "") == "false"
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
assert os.environ.get("WANDB_DISABLED", "") != "true"
def test_wandb_set_disabled(self):
cfg = DictDefault({})
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_DISABLED", "") == "true"
cfg = DictDefault(
{
"wandb_project": "foo",
}
)
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_DISABLED", "") != "true"