Compare commits: hamelsmu-p...completion
1 commit

| Author | SHA1 | Date |
|---|---|---|
|  | da154e6d56 |  |

.github/workflows/tests.yml (vendored, 2 changes)
@@ -73,7 +73,7 @@ jobs:
      run: |
        pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
        pip3 uninstall -y transformers accelerate
        pip3 install -U -e .[flash-attn,mamba-ssm]
        pip3 install -U -e .[flash-attn]
        pip3 install -r requirements-tests.txt

    - name: Run e2e tests

@@ -8,9 +8,6 @@ ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*]
ignore_errors = True

[mypy-axolotl.models.mixtral.*]
ignore_errors = True

[mypy-axolotl.models.phi.*]
ignore_errors = True

README.md (95 changes)

@@ -36,9 +36,7 @@ Features:
- [Train](#train)
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Need Help?](#need-help-)
- [Badge](#badge-)
- [Community Showcase](#community-showcase)
@@ -67,21 +65,19 @@ Features:

## Axolotl supports

| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |


## Quickstart ⚡
@@ -102,7 +98,7 @@ pip3 install -e '.[flash-attn,deepspeed]'
```

### Usage
```bashtet
```bash
# finetune lora
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml

@@ -249,17 +245,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
- `sharegpt`: conversations where `from` is `human`/`gpt`
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
```yml
datasets:
  - path: <your-path>
    type: sharegpt
    conversation: llama-2
```
- `completion`: raw corpus
```json
{"text": "..."}
@@ -623,12 +612,6 @@ eval_sample_packing:
sample_packing_eff_est:
total_num_tokens:

# Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory
device_map:
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
max_memory:

# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora
# If you already have a lora model trained that you want to load, put that here.
@@ -676,8 +659,7 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
wandb_project: # Your wandb project name
wandb_entity: # A wandb Team name if using a Team
wandb_watch:
wandb_name: # Set the name of your wandb run
wandb_run_id: # Set the ID of your wandb run
wandb_run_id: # Set the name of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training

# Where to save the full-finetuned model to
@@ -700,11 +682,9 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
save_strategy: # Set to `no` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
save_total_limit: # Maximum number of checkpoints kept at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed.
@@ -714,9 +694,6 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

# Save model as safetensors (requires safetensors package)
save_safetensors:

@@ -783,7 +760,7 @@ max_grad_norm:
# Augmentation techniques
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# currently only supported on Llama and Mistral
neftune_noise_alpha:
noisy_embedding_alpha:

# Whether to use bettertransformers
flash_optimum:
@@ -975,26 +952,10 @@ wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
```

##### Special Tokens

It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:

```yml
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
tokens: # these are delimiters
  - "<|im_start|>"
  - "<|im_end|>"
```

When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
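As a quick, illustrative sanity check (a minimal sketch, not part of this change set), you can reload the tokenizer saved alongside your run output and confirm that each delimiter maps to a single token id; the output path below is a placeholder.

```python
from transformers import AutoTokenizer

# Placeholder path: the tokenizer saved alongside your training output
tokenizer = AutoTokenizer.from_pretrained("./lora-out")

for token in ["<|im_start|>", "<|im_end|>"]:
    ids = tokenizer.encode(token, add_special_tokens=False)
    # An added delimiter should come back as exactly one id
    print(token, ids, "OK" if len(ids) == 1 else "split into multiple tokens")
```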

### Inference

Pass the appropriate flag to the train command:
@@ -1047,10 +1008,6 @@ Please reduce any below
- `gradient_accumulation_steps`
- `sequence_len`

If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.

Using adamw_bnb_8bit might also save you some memory.
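As an illustration only (not part of this change), dialing those knobs down in a config might look like the sketch below; the values are arbitrary examples, not recommendations from this diff.

```yml
gradient_accumulation_steps: 1  # fewer accumulation steps, as suggested above
sequence_len: 2048              # shorter sequences reduce activation memory
micro_batch_size: 1             # smaller per-device batch (key shown in the example configs below)
optimizer: adamw_bnb_8bit       # 8-bit optimizer states
```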

> `failed (exitcode: -9)`

Usually means your system has run out of system memory.
@@ -1073,20 +1030,6 @@ It's safe to ignore it.

See the [NCCL](docs/nccl.md) guide.


### Tokenization Mismatch b/w Inference & Training

For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.

If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following:

1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine-tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.

Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
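A rough sketch of that check (not part of this change set) is below. It assumes you have already collected the token ids from a materialized training row (step 1) and from your inference server (step 2); the tokenizer name is only an example and should be replaced with your model's tokenizer.

```python
from transformers import AutoTokenizer

# Replace with your model's tokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Fill these in: ids from a materialized training row (step 1)
# and ids your inference server is about to send to the model (step 2)
train_ids = []
infer_ids = []

train_text = tokenizer.decode(train_ids)
infer_text = tokenizer.decode(infer_ids)

# Step 3: the strings should match exactly, including spaces and newlines
print("strings match:", train_text == infer_text)
# Step 4: the raw ids should be identical as well
print("ids match:", train_ids == infer_ids)
```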

## Need help? 🙋♂️

Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you

@@ -24,6 +24,16 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
@@ -28,6 +28,16 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
@@ -32,6 +32,16 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 0,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"stage3_max_live_parameters": 0,
|
||||
"stage3_max_reuse_distance": 0,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -10,7 +10,7 @@ ARG PYTORCH_VERSION="2.0.1"
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||
apt-get install -y vim curl
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ FROM winglian/axolotl:$BASE_TAG
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
|
||||
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
output_dir: btlm-out
|
||||
@@ -72,8 +72,8 @@ gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
|
||||
warmup_steps: 32
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps:
|
||||
save_steps:
|
||||
save_total_limit:
|
||||
|
||||
debug:
|
||||
|
||||
@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
batch_size: 4
|
||||
@@ -49,8 +49,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -54,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -56,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -54,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -56,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -54,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -56,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./falcon-7b
|
||||
batch_size: 2
|
||||
@@ -51,8 +51,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 40
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 5
|
||||
save_steps: 43
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -80,8 +80,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 5
|
||||
save_steps: 10
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.000001
|
||||
|
||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./falcon-7b
|
||||
batch_size: 2
|
||||
@@ -51,8 +51,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 40
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 5
|
||||
save_steps: 43
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 2
|
||||
@@ -46,8 +46,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./jeopardy-bot-7b
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -42,8 +42,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -58,9 +58,9 @@ flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 4
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed: #deepspeed/zero2.json # multi-gpu only
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -32,7 +32,7 @@ lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./model-out
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -62,8 +62,8 @@ flash_attention:
|
||||
sdp_attention:
|
||||
flash_optimum:
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps:
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -54,10 +54,10 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -56,9 +56,9 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -35,7 +35,7 @@ relora_cpu_offload: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -60,8 +60,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps: 50
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -54,9 +54,9 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
base_model: state-spaces/mamba-2.8b
|
||||
model_type: MambaLMHeadModel
|
||||
tokenizer_type: AutoTokenizer
|
||||
tokenizer_config: EleutherAI/gpt-neox-20b
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
tokens:
|
||||
save_safetensors: False
|
||||
@@ -21,7 +21,7 @@ pad_to_sequence_len: true
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -46,10 +46,10 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./qlora-out
|
||||
|
||||
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||
unfrozen_parameters:
|
||||
# - lm_head.*
|
||||
# - model.embed_tokens.*
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
#lora_target_modules:
|
||||
# - gate
|
||||
# - q_proj
|
||||
# - k_proj
|
||||
# - v_proj
|
||||
# - o_proj
|
||||
# - w1
|
||||
# - w2
|
||||
# - w3
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed: deepspeed/zero2.json
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -38,7 +38,7 @@ lora_target_modules:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -62,14 +62,11 @@ logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: mpt-alpaca-7b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./mpt-alpaca-7b
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -44,8 +44,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
|
||||
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./openllama-out
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -49,8 +49,8 @@ flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-out
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -54,8 +54,8 @@ flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -48,8 +48,8 @@ flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -59,8 +59,8 @@ xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -59,8 +59,8 @@ xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./pythia-12b
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-alpaca-pythia
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -33,5 +33,5 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
weight_decay: 0.1
|
||||
evals_per_epoch: 4
|
||||
eval_steps: 0.05
|
||||
logging_steps: 1
|
||||
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -53,13 +53,13 @@ resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -53,13 +53,13 @@ resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: redpajama-alpaca-3b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./redpajama-alpaca-3b
|
||||
batch_size: 4
|
||||
@@ -45,8 +45,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
|
||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project: lora-replit
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-replit
|
||||
batch_size: 8
|
||||
@@ -45,8 +45,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 50
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0
|
||||
|
||||
@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -78,8 +78,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
eval_steps: 50
|
||||
save_steps: 50
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -2,15 +2,14 @@
|
||||
auto-gptq==0.5.1
|
||||
packaging
|
||||
peft==0.6.0
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@ebfdb9ca62205279d5019ef1403877461b3b2da4
|
||||
tokenizers==0.15.0
|
||||
transformers==4.35.1
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate==0.24.1
|
||||
deepspeed
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
datasets>=2.15.0
|
||||
datasets>=2.14.0
|
||||
flash-attn==2.3.3
|
||||
sentencepiece
|
||||
wandb
|
||||
@@ -29,8 +28,8 @@ scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat==0.2.34
|
||||
gradio==3.50.2
|
||||
fschat==0.2.29
|
||||
gradio
|
||||
tensorboard
|
||||
|
||||
# remote filesystems
|
||||
|
||||
setup.py (5 changes)
@@ -46,13 +46,10 @@ setup(
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn==2.3.3",
|
||||
"flash-attn>=2.3.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.0.1",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -29,7 +29,6 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
@@ -72,7 +71,7 @@ def do_merge_lora(
|
||||
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
model = model.merge_and_unload()
|
||||
model.to(dtype=cfg.torch_dtype)
|
||||
model.to(dtype=torch.float16)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||
@@ -297,8 +296,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
@@ -22,8 +22,8 @@ LOG = logging.getLogger("axolotl.cli.train")
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
|
||||
@@ -25,16 +25,12 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
MambaDataCollator,
|
||||
)
|
||||
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.samplers import MultipackBatchSampler
|
||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||
|
||||
@@ -52,9 +48,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
"""
|
||||
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
@@ -291,32 +284,6 @@ class AxolotlTrainer(Trainer):
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Mamba specific trainer to handle loss calculation
|
||||
"""
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
|
||||
labels = input_ids.to(lm_logits.device)
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
lm_loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||
)
|
||||
|
||||
return lm_loss
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
@@ -463,9 +430,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -494,8 +458,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return OneCycleLRSchedulerTrainer
|
||||
if self.cfg.relora_steps:
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
@@ -563,7 +525,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.hub_strategy:
|
||||
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
if self.cfg.save_safetensors is not None:
|
||||
if self.cfg.save_safetensors:
|
||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.cfg.sample_packing_eff_est:
|
||||
@@ -681,7 +643,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
||||
training_arguments_kwargs["run_name"] = (
|
||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||
self.cfg.wandb_run_id if self.cfg.use_wandb else None
|
||||
)
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
@@ -692,9 +654,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||
else "cosine"
|
||||
)
|
||||
training_arguments_kwargs["lr_scheduler_kwargs"] = (
|
||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||
)
|
||||
training_arguments_kwargs["weight_decay"] = (
|
||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||
)
|
||||
@@ -714,13 +673,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||
training_arguments_kwargs
|
||||
)
|
||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||
|
||||
if self.cfg.neftune_noise_alpha is not None:
|
||||
training_arguments_kwargs[
|
||||
"neftune_noise_alpha"
|
||||
] = self.cfg.neftune_noise_alpha
|
||||
|
||||
training_args = (
|
||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
@@ -775,7 +727,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=self.build_collator(**data_collator_kwargs),
|
||||
data_collator=BatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
@@ -795,13 +751,3 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
] = self.cfg.micro_batch_size
|
||||
|
||||
return trainer
|
||||
|
||||
def build_collator(self, **kwargs):
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||
|
||||
return BatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
"""
|
||||
Modeling module for Mamba models
|
||||
"""
|
||||
|
||||
|
||||
def fix_mamba_attn_for_loss():
|
||||
from mamba_ssm.models import mixer_seq_simple
|
||||
|
||||
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|
||||
|
||||
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
|
||||
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name
|
||||
@@ -1,42 +0,0 @@
|
||||
"""
|
||||
HF Transformers MambaConfig
|
||||
"""
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class MambaConfig(PretrainedConfig):
|
||||
"""
|
||||
modeling configuration for state space model/mamba
|
||||
"""
|
||||
|
||||
model_type = "mamba"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50280,
|
||||
d_model=2560,
|
||||
n_layer=64,
|
||||
rms_norm=True,
|
||||
residual_in_fp32=True,
|
||||
fused_add_norm=True,
|
||||
pad_vocab_size_multiple=8,
|
||||
pad_token_id=50277,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model
|
||||
self.n_layer = n_layer
|
||||
self.rms_norm = rms_norm
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
self.fused_add_norm = fused_add_norm
|
||||
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1,128 +0,0 @@
|
||||
# pylint: skip-file
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
|
||||
from mamba_ssm.utils.generation import GenerationMixin
|
||||
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from axolotl.models.mamba.configuration_mamba import MambaConfig
|
||||
|
||||
|
||||
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_layer: int,
|
||||
vocab_size: int,
|
||||
initializer_cfg=None,
|
||||
pad_vocab_size_multiple: int = 1,
|
||||
device=None,
|
||||
dtype=None,
|
||||
**backbone_kwargs,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
if vocab_size % pad_vocab_size_multiple != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (
|
||||
vocab_size % pad_vocab_size_multiple
|
||||
)
|
||||
self.config = MambaConfig(
|
||||
vocab_size=vocab_size,
|
||||
d_model=d_model,
|
||||
n_layer=n_layer,
|
||||
pad_vocab_size_multiple=pad_vocab_size_multiple,
|
||||
)
|
||||
self.backbone = MixerModel(
|
||||
d_model=d_model,
|
||||
n_layer=n_layer,
|
||||
vocab_size=vocab_size,
|
||||
initializer_cfg=initializer_cfg,
|
||||
**backbone_kwargs,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.apply(
|
||||
partial(
|
||||
_init_weights,
|
||||
n_layer=n_layer,
|
||||
**(initializer_cfg if initializer_cfg is not None else {}),
|
||||
)
|
||||
)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
self.lm_head.weight = self.backbone.embedding.weight
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.backbone.allocate_inference_cache(
|
||||
batch_size, max_seqlen, dtype=dtype, **kwargs
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids=None,
|
||||
inference_params=None,
|
||||
num_last_tokens=0,
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
||||
num_last_tokens: if > 0, only return the logits for the last n tokens
|
||||
"""
|
||||
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
||||
if num_last_tokens > 0:
|
||||
hidden_states = hidden_states[:, -num_last_tokens:]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
logits = lm_logits
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
|
||||
print(loss)
|
||||
return CausalLMOutput(logits=lm_logits, loss=loss)
|
||||
|
||||
else:
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
state_dict: Optional[dict] = None,
|
||||
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
|
||||
):
|
||||
if state_dict is None:
|
||||
state_dict = self.state_dict()
|
||||
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
||||
config = load_config_hf(pretrained_model_name)
|
||||
model = cls(**config, device=device, dtype=dtype, **kwargs)
|
||||
model.load_state_dict(
|
||||
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
|
||||
)
|
||||
return model
|
||||
@@ -83,21 +83,14 @@ def get_turns( # pylint: disable=too-many-return-statements
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||
seps = [self.sep, self.sep2]
|
||||
if self.system_message:
|
||||
if self.messages:
|
||||
# For llama, the system message is incorporated into the first human instruction
|
||||
first_role, first_msg = self.messages[0]
|
||||
if first_role == self.roles[0]:
|
||||
system_prompt += first_msg
|
||||
self.messages.pop(0)
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
else:
|
||||
yield "", "[INST] "
|
||||
for i, (role, message) in enumerate(self.messages[1:]):
|
||||
if message:
|
||||
if (i % 2 == 0 and not self.system_message) or (
|
||||
i % 2 != 0 and self.system_message
|
||||
):
|
||||
role = "<s> " + role
|
||||
yield role + " ", message
|
||||
yield role + " ", message + seps[i % 2]
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
"""
|
||||
Patches to support multipack for mixtral
|
||||
"""
|
||||
import transformers
|
||||
|
||||
|
||||
def replace_mixtral_attn_with_multipack_flash_attn():
|
||||
from .modeling_mixtral import (
|
||||
MixtralMultipackFlashAttention2,
|
||||
mixtral_decoder_layer_forward,
|
||||
mixtral_model_forward,
|
||||
)
|
||||
|
||||
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
|
||||
mixtral_decoder_layer_forward
|
||||
)
|
||||
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
|
||||
mixtral_model_forward
|
||||
)
|
||||
transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
|
||||
"flash_attention_2"
|
||||
] = MixtralMultipackFlashAttention2
|
||||
@@ -1,379 +0,0 @@
|
||||
"""
|
||||
Mixtral modeling for multipack
|
||||
"""
|
||||
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
|
||||
import logging
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||
from transformers import Cache, DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import MoeModelOutputWithPast
|
||||
from transformers.models.mixtral.modeling_mixtral import (
|
||||
MixtralFlashAttention2,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
|
||||
|
||||
|
||||
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
|
||||
"""
|
||||
Custom multipack implementation w flash attention 2
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._flash_attn_uses_top_left_mask = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
attn_output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=self.attention_dropout,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def mixtral_decoder_layer_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_router_logits (`bool`, *optional*):
|
||||
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
||||
should not be returned during inference.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
if output_router_logits:
|
||||
outputs += (router_logits,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def mixtral_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, MoeModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = (
|
||||
attention_mask
|
||||
if (attention_mask is not None and 0 in attention_mask)
|
||||
else None
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
LOG.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_router_logits = () if output_router_logits else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
output_router_logits,
|
||||
use_cache,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
output_router_logits=output_router_logits,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if output_router_logits:
|
||||
all_router_logits += (layer_outputs[-1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache()
|
||||
if use_legacy_cache
|
||||
else next_decoder_cache
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
all_router_logits,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return MoeModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
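For orientation, the `cu_seqlens`/`max_seqlen` pair threaded through the forward passes above marks where each packed sample begins and ends, so `flash_attn_varlen_qkvpacked_func` only attends within a sample. Below is a simplified sketch of how those offsets can be derived from packed position ids; it illustrates the idea only and is not axolotl's exact `get_cu_seqlens_from_pos_ids`.

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor):
    # Each packed sample restarts its position ids at 0, so sample boundaries
    # are the indices where the position id drops back to 0.
    pos = position_ids.view(-1)
    starts = torch.nonzero(pos == 0).view(-1)
    lengths = torch.diff(torch.cat([starts, torch.tensor([pos.numel()])]))
    cu_seqlens = torch.zeros(len(lengths) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return cu_seqlens, int(lengths.max())

# Two samples of lengths 3 and 2 packed into one row:
cu, max_len = cu_seqlens_from_position_ids(torch.tensor([[0, 1, 2, 0, 1]]))
# cu == tensor([0, 3, 5], dtype=torch.int32); max_len == 3
```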
65 src/axolotl/monkeypatch/neft_embeddings.py (new file)
@@ -0,0 +1,65 @@
"""
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
from peft import PeftModel
from transformers import PreTrainedModel


def patch_neft(alpha, model):
    embeddings = None
    if isinstance(model, PreTrainedModel):
        embeddings = model.get_input_embeddings()
    if isinstance(model, PeftModel):
        embeddings = model.base_model.get_input_embeddings()
    if not embeddings:
        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
    embeddings.noisy_embedding_alpha = alpha
    old_forward = embeddings.forward

    # This hack seems to be needed to properly use a custom forward pass
    # all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
    bound_method = neft_forward.__get__(  # pylint: disable=no-value-for-parameter
        embeddings, embeddings.__class__
    )
    setattr(embeddings, "forward", bound_method)

    embeddings._old_forward = old_forward  # pylint: disable=protected-access
    return model


def unpatch_neft(model):
    embeddings = None
    if isinstance(model, PreTrainedModel):
        embeddings = model.get_input_embeddings()
    if isinstance(model, PeftModel):
        embeddings = model.base_model.get_input_embeddings()
    if not embeddings:
        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
    if hasattr(embeddings, "_old_forward"):
        embeddings.forward = embeddings._old_forward  # pylint: disable=protected-access
        del embeddings._old_forward  # pylint: disable=protected-access
        del embeddings.noisy_embedding_alpha


def neft_forward(self, inputs: torch.Tensor):
    embeddings = self._old_forward(inputs)  # pylint: disable=protected-access

    if self.training:
        dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
        mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
        embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
            -mag_norm, mag_norm
        )

    return embeddings


def pretrain_hook(cfg, trainer):
    if cfg.noisy_embedding_alpha:
        trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)


def post_train_hook(cfg, trainer):
    if cfg.noisy_embedding_alpha:
        unpatch_neft(trainer.model)
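The noise added above is uniform, scaled by `alpha / sqrt(seq_len * hidden_dim)` as in the NEFTune paper, and only applied while the model is in training mode. A minimal usage sketch follows; the model name and alpha value are illustrative, not part of this diff.

```python
# Hypothetical usage sketch of the functions defined above.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_3b")
model = patch_neft(5, model)   # noise is only injected while model.training is True
model.train()
# ... run the training loop ...
unpatch_neft(model)            # restores the original embedding forward and drops the stored alpha
```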
@@ -1,6 +1,7 @@
"""
Basic completion text
"""
import json
from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple

@@ -64,6 +65,19 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
        return next(iter(self.prompter.build_prompt(instruction, input, response)))


class CompletionJSONPromptTokenizationStrategy(CompletionPromptTokenizingStrategy):
    """
    Strategy to return the stringified JSON of the entire row as the training data
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        return (
            json.dumps(prompt),
            "",
            "",
        )


class CompletionPrompter:
    """
    Prompter for completion
@@ -82,7 +96,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    strat = CompletionPromptTokenizingStrategy(
        CompletionPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        True,
        cfg.sequence_len,
        max_length=cfg.sequence_len * 64,
    )
@@ -90,3 +104,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
        strat.field = ds_cfg["field"]

    return strat


def load_json(tokenizer, cfg):
    strat = CompletionJSONPromptTokenizationStrategy(
        CompletionPrompter(),
        tokenizer,
        True,
        cfg.sequence_len,
        max_length=cfg.sequence_len * 64,
    )

    return strat
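To make the JSON strategy concrete: the whole dataset row is serialized and used as plain completion text, with empty instruction-input and response fields. The record below is made up for illustration.

```python
# Illustration only: what parse_instruction_fields() returns for a row.
import json

row = {"title": "lorem", "body": "ipsum"}   # hypothetical dataset record
text, inp, resp = json.dumps(row), "", ""
assert text == '{"title": "lorem", "body": "ipsum"}'
```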
@@ -81,9 +81,8 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer.add_special_tokens(
            {"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")}
        )
        self.sequence_len = 4096
        self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
        # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json

    def tokenize_prompt(self, prompt):

@@ -13,7 +13,7 @@ register_conv_template(
        system_message="You are a helpful assistant.",
        roles=["<|im_start|>user", "<|im_start|>assistant"],
        sep_style=SeparatorStyle.CHATML,
        sep="<|im_end|>",
        sep="<|im_end|>\n",
    )
)

@@ -33,8 +33,8 @@ class AlpacaPrompter(Prompter):
    Base class for alpaca prompters
    """

    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    system_format: str = "{system}"
    turn_format: str
    turn_no_input_format: str
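To show why the trailing `\n\n` on the system prompt matters, here is roughly how the no-input Alpaca prompt renders; the turn format used is the usual Alpaca instruction/response layout, assumed here for illustration.

```python
# Illustration only: the turn format below is an assumption, not taken from this diff.
system_no_input_prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
)
turn_no_input_format = "### Instruction:\n{instruction}\n\n### Response:\n"  # assumed Alpaca-style turn

print(system_no_input_prompt + turn_no_input_format.format(instruction="Say hi"))
# Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
# ### Instruction:
# Say hi
#
# ### Response:
```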
@@ -16,8 +16,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled

from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer

@@ -78,15 +78,11 @@ def train(
    )
    resume_from_checkpoint = cfg.resume_from_checkpoint

    if cfg.unfrozen_parameters:
        freeze_parameters_except(model, cfg.unfrozen_parameters)

    trainer = setup_trainer(
        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
    )

    if hasattr(model, "config"):
        model.config.use_cache = False
    model.config.use_cache = False

    # go ahead and presave, so we have the adapter config available to inspect
    if peft_config:
@@ -96,8 +92,7 @@ def train(
    if not Path(cfg.output_dir).is_dir():
        os.makedirs(cfg.output_dir, exist_ok=True)
    tokenizer.save_pretrained(str(Path(cfg.output_dir)))
    if hasattr(model, "config"):
        model.config.save_pretrained(str(Path(cfg.output_dir)))
    model.config.save_pretrained(str(Path(cfg.output_dir)))

    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
    if cfg.local_rank == 0:
@@ -179,19 +174,21 @@ def train(
    return model, tokenizer


def pretrain_hooks(_cfg, _trainer):
def pretrain_hooks(cfg, trainer):
    """
    Run hooks right before kicking off the training
    :param cfg:
    :param trainer:
    :return:
    """
    neft_embeddings.pretrain_hook(cfg, trainer)


def post_train_hooks(_cfg, _trainer):
def post_train_hooks(cfg, trainer):
    """
    Run hooks right after training completes
    :param cfg:
    :param trainer:
    :return:
    """
    neft_embeddings.post_train_hook(cfg, trainer)
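These hooks bracket the actual training call; the sketch below shows roughly where they sit inside `train()`, under the assumption that the surrounding control flow is unchanged by this diff.

```python
# Simplified sketch, not the full train() body.
pretrain_hooks(cfg, trainer)      # e.g. patches NEFT noise onto the input embeddings
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
post_train_hooks(cfg, trainer)    # e.g. removes the NEFT monkeypatch again
```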
@@ -124,36 +124,6 @@ class GPUStatsCallback(
|
||||
return control
|
||||
|
||||
|
||||
class LossWatchDogCallback(TrainerCallback):
|
||||
"""Callback to track loss and stop training if loss is too high"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.logged = False
|
||||
self.violations = 0
|
||||
self.threshold = cfg.loss_watchdog_threshold
|
||||
self.patience = cfg.loss_watchdog_patience or 3
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
_args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**_kwargs,
|
||||
):
|
||||
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
|
||||
if state.log_history[-1]["loss"] > self.threshold:
|
||||
self.violations += 1
|
||||
if self.violations >= self.patience:
|
||||
LOG.warning(
|
||||
"Loss is too high, stopping training (loss_watchdog_threshold)"
|
||||
)
|
||||
control.should_training_stop = True
|
||||
else:
|
||||
self.violations = 0
|
||||
return control
|
||||
|
||||
|
||||
def bench_eval_callback_factory(trainer, tokenizer):
|
||||
accuracy = evaluate.load("accuracy")
|
||||
abcd_idx = [
|
||||
|
||||
@@ -2,16 +2,12 @@
|
||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Sequence, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSeq2Seq:
|
||||
@@ -150,31 +146,3 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaDataCollator:
|
||||
"""
|
||||
Collator for State Space Models (Mamba)
|
||||
"""
|
||||
|
||||
tokenizer: transformers.PreTrainedTokenizer
|
||||
|
||||
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
||||
input_ids, labels = tuple(
|
||||
[torch.LongTensor(instance[key]) for instance in instances]
|
||||
for key in ("input_ids", "labels")
|
||||
)
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id,
|
||||
)
|
||||
labels = torch.nn.utils.rnn.pad_sequence(
|
||||
labels, batch_first=True, padding_value=IGNORE_INDEX
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ def choose_device(cfg):
|
||||
|
||||
cfg.device = get_device()
|
||||
if cfg.world_size == 1:
|
||||
cfg.device_map = cfg.device_map or "auto"
|
||||
cfg.device_map = "auto"
|
||||
else:
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": torch.cuda.current_device()}
|
||||
@@ -77,15 +77,6 @@ def normalize_config(cfg):
|
||||
else:
|
||||
cfg.torch_dtype = torch.float32
|
||||
|
||||
if cfg.saves_per_epoch:
|
||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||
if save_steps < 1.0: # prevent saves on every step
|
||||
cfg.save_steps = save_steps
|
||||
if cfg.evals_per_epoch:
|
||||
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
|
||||
if eval_steps < 1.0: # prevent evals on every step
|
||||
cfg.eval_steps = eval_steps
|
||||
|
||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
||||
|
||||
if not cfg.base_model_config:
|
||||
@@ -361,27 +352,6 @@ def validate_config(cfg):
|
||||
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||
"sharegpt_simple", "sharegpt"
|
||||
)
|
||||
|
||||
if cfg.saves_per_epoch and cfg.save_steps:
|
||||
raise ValueError(
|
||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||
)
|
||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if (
|
||||
cfg.evals_per_epoch
|
||||
and cfg.evaluation_strategy
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
@@ -427,27 +397,6 @@ def validate_config(cfg):
|
||||
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
||||
)
|
||||
|
||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||
cfg.wandb_name = cfg.wandb_run_id
|
||||
|
||||
LOG.warning(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
)
|
||||
|
||||
if cfg.noisy_embedding_alpha is not None:
|
||||
# Deprecated, use neftune_noise_alpha
|
||||
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||
if cfg.neftune_noise_alpha is None:
|
||||
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
||||
else:
|
||||
# User is providing both; bail and have them sort out their settings
|
||||
raise ValueError(
|
||||
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
||||
)
|
||||
|
||||
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
|
||||
@@ -1,38 +0,0 @@
"""
module to freeze/unfreeze parameters by name
"""
import logging
import re

from axolotl.utils.distributed import is_main_process

LOG = logging.getLogger("axolotl.utils.freeze")


def freeze_parameters_except(model, regex_patterns):
    """
    Freezes all layers of the given model except for the layers that match given regex patterns.
    Periods in the patterns are treated as literal periods, not as wildcard characters.

    Parameters:
    - model (nn.Module): The PyTorch model to be modified.
    - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.

    Returns:
    None; the model is modified in place.
    """
    # Escape periods and compile the regex patterns
    compiled_patterns = [
        re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
    ]

    # First, freeze all parameters in the model
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze layers that match the regex patterns
    for name, param in model.named_parameters():
        if any(pattern.match(name) for pattern in compiled_patterns):
            if is_main_process():
                LOG.debug(f"unfreezing {name}")
            param.requires_grad = True
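For readers following the `unfrozen_parameters` option this module backs, here is a small self-contained sketch of the same freezing pattern on a toy module; the module and pattern names are made up for illustration.

```python
import re
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))  # params: "0.weight", "0.bias", "1.weight", "1.bias"
patterns = [re.compile(p.replace(".", "\\.")) for p in ["1.weight"]]  # hypothetical pattern

for param in model.parameters():
    param.requires_grad = False          # freeze everything first
for name, param in model.named_parameters():
    if any(p.match(name) for p in patterns):
        param.requires_grad = True       # then unfreeze only the matches

print([n for n, p in model.named_parameters() if p.requires_grad])  # ['1.weight']
```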
@@ -4,7 +4,6 @@ import math
|
||||
import os
|
||||
from typing import Optional, Tuple # noqa: F401
|
||||
|
||||
import addict
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
@@ -21,9 +20,7 @@ from transformers import ( # noqa: F401
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -31,50 +28,16 @@ from axolotl.utils.dict import DictDefault
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
||||
quant_config_exists = hasattr(model_config, "quantization_config")
|
||||
quant_config_method_is_gptq = (
|
||||
quant_config_exists
|
||||
and "quant_method" in model_config.quantization_config
|
||||
and model_config.quantization_config["quant_method"] == "gptq"
|
||||
)
|
||||
|
||||
if cfg.gptq and not quant_config_method_is_gptq:
|
||||
raise ValueError(
|
||||
"model_config.quantization_config is not set or quant_method is not set to gptq. "
|
||||
"Please make sure to point to a GPTQ model."
|
||||
)
|
||||
|
||||
if not cfg.gptq and quant_config_exists:
|
||||
raise ValueError(
|
||||
"model_config.quantization_config is set but `gptq` flag is not. "
|
||||
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
||||
)
|
||||
|
||||
|
||||
def load_model_config(cfg):
|
||||
model_config_name = cfg.base_model_config or cfg.base_model
|
||||
trust_remote_code = cfg.trust_remote_code is True
|
||||
|
||||
try:
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_config_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
except ValueError as err:
|
||||
if "mamba" in model_config_name:
|
||||
return addict.Dict(
|
||||
{
|
||||
"model_type": "mamba",
|
||||
}
|
||||
)
|
||||
raise err
|
||||
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_config_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
if cfg.model_config:
|
||||
for key, val in cfg.model_config.items():
|
||||
setattr(model_config, key, val)
|
||||
|
||||
check_model_config(cfg, model_config)
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
@@ -106,7 +69,6 @@ def load_tokenizer(cfg):
|
||||
"LlamaTokenizer",
|
||||
"LlamaTokenizerFast",
|
||||
"CodeLlamaTokenizer",
|
||||
"CodeLlamaTokenizerFast",
|
||||
]
|
||||
and hasattr(tokenizer, "pad_token")
|
||||
and not tokenizer.pad_token
|
||||
@@ -139,23 +101,6 @@ def load_tokenizer(cfg):
|
||||
tokenizer.add_special_tokens(
|
||||
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||
)
|
||||
|
||||
# If we add bos_token and eos_token, we need to update the post processor to
|
||||
# handle them correctly.
|
||||
# https://github.com/huggingface/transformers/pull/24132
|
||||
bos_or_eos_in_special_tokens = (
|
||||
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
|
||||
)
|
||||
if (
|
||||
tokenizer.__class__.__name__
|
||||
in (
|
||||
"LlamaTokenizerFast",
|
||||
"CodeLlamaTokenizerFast",
|
||||
)
|
||||
and bos_or_eos_in_special_tokens
|
||||
):
|
||||
tokenizer.update_post_processor()
|
||||
|
||||
if cfg.tokens:
|
||||
tokenizer.add_tokens(
|
||||
[
|
||||
@@ -250,18 +195,6 @@ def load_model(
|
||||
LOG.info("patching with flash attention")
|
||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
|
||||
if (
|
||||
cfg.model_config_type == "mixtral"
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
from axolotl.monkeypatch.mixtral import (
|
||||
replace_mixtral_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching with flash attention")
|
||||
replace_mixtral_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||
replace_llama_rope_with_xpos_rope,
|
||||
@@ -283,12 +216,8 @@ def load_model(
|
||||
model_kwargs = {}
|
||||
|
||||
model_kwargs["device_map"] = cfg.device_map
|
||||
model_kwargs["max_memory"] = cfg.max_memory
|
||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
del model_kwargs["device_map"]
|
||||
|
||||
if cfg.model_revision:
|
||||
model_kwargs["revision"] = cfg.model_revision
|
||||
if cfg.gptq:
|
||||
@@ -312,26 +241,13 @@ def load_model(
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
# sample packing uses custom FA2 patch
|
||||
if cfg.flash_attention:
|
||||
if not cfg.sample_packing:
|
||||
if (
|
||||
cfg.is_llama_derived_model
|
||||
or cfg.is_falcon_derived_model
|
||||
or cfg.is_mistral_derived_model
|
||||
or model_config.model_type == "mixtral"
|
||||
):
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flash_attention_2"
|
||||
)
|
||||
else:
|
||||
if model_config.model_type == "mixtral":
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flash_attention_2"
|
||||
)
|
||||
else:
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"eager"
|
||||
)
|
||||
if cfg.flash_attention and not cfg.sample_packing:
|
||||
if (
|
||||
cfg.is_llama_derived_model
|
||||
or cfg.is_falcon_derived_model
|
||||
or cfg.is_mistral_derived_model
|
||||
):
|
||||
model_kwargs["use_flash_attention_2"] = True
|
||||
|
||||
try:
|
||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||
@@ -393,20 +309,6 @@ def load_model(
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
**model_kwargs,
|
||||
)
|
||||
elif model_type == "MambaLMHeadModel":
|
||||
# FIXME this is janky at best and hacked together to make it work
|
||||
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
||||
|
||||
model_kwargs["dtype"] = model_kwargs["torch_dtype"]
|
||||
model_kwargs["device"] = torch.cuda.current_device()
|
||||
del model_kwargs["torch_dtype"]
|
||||
del model_kwargs["device_map"]
|
||||
del model_kwargs["max_memory"]
|
||||
|
||||
model = MambaLMHeadModel.from_pretrained(
|
||||
base_model,
|
||||
**model_kwargs,
|
||||
)
|
||||
elif model_type and not cfg.trust_remote_code:
|
||||
if cfg.gptq:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -466,17 +368,13 @@ def load_model(
|
||||
if cfg.resize_token_embeddings_to_32x
|
||||
else len(tokenizer)
|
||||
)
|
||||
if (
|
||||
hasattr(model, "get_input_embeddings")
|
||||
and model.get_input_embeddings().num_embeddings < embeddings_len
|
||||
):
|
||||
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
else:
|
||||
model.tie_weights()
|
||||
|
||||
if (
|
||||
hasattr(model, "config")
|
||||
and hasattr(model.config, "max_position_embeddings")
|
||||
hasattr(model.config, "max_position_embeddings")
|
||||
and model.config.max_position_embeddings
|
||||
and cfg.sequence_len > model.config.max_position_embeddings
|
||||
):
|
||||
@@ -486,22 +384,20 @@ def load_model(
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if (
|
||||
hasattr(model, "config")
|
||||
and hasattr(model.config, "bos_token_id")
|
||||
hasattr(model.config, "bos_token_id")
|
||||
and model.config.bos_token_id
|
||||
and model.config.bos_token_id != tokenizer.bos_token_id
|
||||
):
|
||||
model.config.bos_token_id = tokenizer.bos_token_id
|
||||
|
||||
if (
|
||||
hasattr(model, "config")
|
||||
and hasattr(model.config, "eos_token_id")
|
||||
hasattr(model.config, "eos_token_id")
|
||||
and model.config.eos_token_id
|
||||
and model.config.eos_token_id != tokenizer.eos_token_id
|
||||
):
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if hasattr(model, "device") and model.device.type == "cuda":
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
# make sure these are fp32 per Ramesh et al. (2021)
|
||||
@@ -516,22 +412,15 @@ def load_model(
|
||||
module.to(torch.float32)
|
||||
|
||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||
skip_prepare_model_for_kbit_training = False
|
||||
|
||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if (cfg.adapter == "lora" and load_in_8bit) or (
|
||||
cfg.adapter == "qlora" and cfg.load_in_4bit
|
||||
):
|
||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||
if cfg.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
if not skip_prepare_model_for_kbit_training:
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
needs_fa2_dtype = True
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
@@ -560,8 +449,7 @@ def load_model(
|
||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||
if len(requires_grad) == 0:
|
||||
LOG.warning("there are no parameters that require gradient updates")
|
||||
if hasattr(model, "config"):
|
||||
model.config.use_cache = False
|
||||
model.config.use_cache = False
|
||||
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.transform(model)
|
||||
|
||||
@@ -131,10 +131,8 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||
)
|
||||
|
||||
# Phi doesn't want the attention_mask feature when training
|
||||
if (
|
||||
"CodeGenTokenizer" in tokenizer.__class__.__name__
|
||||
or (cfg.is_mistral_derived_model and cfg.flash_attention)
|
||||
or cfg.model_config_type == "mamba"
|
||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
|
||||
cfg.is_mistral_derived_model and cfg.flash_attention
|
||||
):
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
@@ -155,9 +153,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
if update:
|
||||
cfg.total_num_tokens = total_num_tokens
|
||||
|
||||
skip_estimates = cfg.model_config_type == "mamba"
|
||||
|
||||
if not skip_estimates and not cfg.total_supervised_tokens:
|
||||
if not cfg.total_supervised_tokens:
|
||||
total_supervised_tokens = (
|
||||
train_dataset.data.column("labels")
|
||||
.to_pandas()
|
||||
@@ -171,7 +167,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
if update:
|
||||
cfg.total_supervised_tokens = total_supervised_tokens
|
||||
|
||||
if not skip_estimates and cfg.sample_packing:
|
||||
if cfg.sample_packing:
|
||||
# we have to drop anything longer then sequence len otherwise
|
||||
# flash attention with position ids fails
|
||||
|
||||
@@ -271,15 +267,12 @@ def setup_fsdp_envs(cfg):
|
||||
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||
|
||||
|
||||
def prepare_optim_env(cfg):
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
if cfg.fsdp:
|
||||
setup_fsdp_envs(cfg)
|
||||
elif cfg.deepspeed:
|
||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
||||
trainer_builder.train_dataset = train_dataset
|
||||
trainer_builder.eval_dataset = eval_dataset
|
||||
|
||||
@@ -2,20 +2,20 @@

import os

from axolotl.utils.dict import DictDefault


def setup_wandb_env_vars(cfg: DictDefault):
    for key in cfg.keys():
        if key.startswith("wandb_"):
            value = cfg.get(key, "")

            if value and isinstance(value, str) and len(value) > 0:
                os.environ[key.upper()] = value

    # Enable wandb if project name is present
    if cfg.wandb_project and len(cfg.wandb_project) > 0:
def setup_wandb_env_vars(cfg):
    if cfg.wandb_mode and cfg.wandb_mode == "offline":
        os.environ["WANDB_MODE"] = cfg.wandb_mode
    elif cfg.wandb_project and len(cfg.wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = cfg.wandb_project
        cfg.use_wandb = True
        os.environ.pop("WANDB_DISABLED", None)  # Remove if present
        if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
            os.environ["WANDB_ENTITY"] = cfg.wandb_entity
        if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
            os.environ["WANDB_WATCH"] = cfg.wandb_watch
        if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
            os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
        if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
            os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
    else:
        os.environ["WANDB_DISABLED"] = "true"
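As a quick sanity check of the behaviour above, mirroring the assertions in the wandb validation tests further down: a config with only a project name enables wandb, while an empty config disables it. The project name here is illustrative.

```python
# Illustrative check of setup_wandb_env_vars() with a hypothetical project name.
import os
from axolotl.utils.dict import DictDefault

setup_wandb_env_vars(DictDefault({"wandb_project": "my-proj"}))
assert os.environ["WANDB_PROJECT"] == "my-proj"
assert os.environ.get("WANDB_DISABLED", "") != "true"

setup_wandb_env_vars(DictDefault({}))
assert os.environ["WANDB_DISABLED"] == "true"
```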
@@ -1,65 +0,0 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestMistral(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "state-spaces/mamba-130m",
|
||||
"model_type": "MambaLMHeadModel",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"tokenizer_config": "EleutherAI/gpt-neox-20b",
|
||||
"flash_attention": False,
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": False,
|
||||
"val_set_size": 0.0,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"gradient_checkpointing": False,
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": None,
|
||||
"save_safetensors": False,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
@@ -114,76 +114,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_sharegpt_llama(self):
|
||||
"Make sure the sharegpt/llama is tokenized and formatted correctly."
|
||||
prompter = ShareGPTPrompterV2(conversation="llama-2")
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
|
||||
def tokenize(conv):
|
||||
return strat.tokenize_prompt(conv)["input_ids"]
|
||||
|
||||
def decode(ids):
|
||||
return strat.tokenizer.decode(ids)
|
||||
|
||||
# Multi-turn conversations
|
||||
multi_turn_conv = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "123"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
]
|
||||
}
|
||||
# fmt: off
|
||||
mt_ids = tokenize(multi_turn_conv)
|
||||
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||
|
||||
# Single-turn conversations
|
||||
single_turn_conv = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
]
|
||||
}
|
||||
|
||||
st_ids = tokenize(single_turn_conv)
|
||||
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
|
||||
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||
|
||||
# No system message, single-turn
|
||||
no_sys_conv = {
|
||||
"conversations": [
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
]
|
||||
}
|
||||
|
||||
ns_ids = tokenize(no_sys_conv)
|
||||
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
||||
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
||||
|
||||
# No system message, multi-turn
|
||||
no_sys_mt_conv = {
|
||||
"conversations": [
|
||||
{"from": "human", "value": "abc"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "123"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
]
|
||||
}
|
||||
ns_mt_ids = tokenize(no_sys_mt_conv)
|
||||
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||
# fmt: on
|
||||
|
||||
def test_sharegpt_changes_roles(self):
|
||||
conversation = {
|
||||
"roles": ["USER", "CHARACTER"],
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Module for testing the validation module"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
@@ -9,7 +8,6 @@ import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
|
||||
class ValidationTest(unittest.TestCase):
|
||||
@@ -681,83 +679,3 @@ class ValidationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
|
||||
class ValidationWandbTest(ValidationTest):
|
||||
"""
|
||||
Validation test for wandb
|
||||
"""
|
||||
|
||||
def test_wandb_set_run_id_to_name(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_run_id": "foo",
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_name": "foo",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
|
||||
|
||||
def test_wandb_sets_env(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_project": "foo",
|
||||
"wandb_name": "bar",
|
||||
"wandb_run_id": "bat",
|
||||
"wandb_entity": "baz",
|
||||
"wandb_mode": "online",
|
||||
"wandb_watch": "false",
|
||||
"wandb_log_model": "checkpoint",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
assert os.environ.get("WANDB_PROJECT", "") == "foo"
|
||||
assert os.environ.get("WANDB_NAME", "") == "bar"
|
||||
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
|
||||
assert os.environ.get("WANDB_ENTITY", "") == "baz"
|
||||
assert os.environ.get("WANDB_MODE", "") == "online"
|
||||
assert os.environ.get("WANDB_WATCH", "") == "false"
|
||||
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
|
||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
||||
|
||||
def test_wandb_set_disabled(self):
|
||||
cfg = DictDefault({})
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
assert os.environ.get("WANDB_DISABLED", "") == "true"
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_project": "foo",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
||||
|
||||