Compare commits
1 Commits
mixtral_sw
...
multipack-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
effb281b24 |
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -73,7 +73,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
||||||
pip3 uninstall -y transformers accelerate
|
pip3 uninstall -y transformers accelerate
|
||||||
pip3 install -U -e .[flash-attn,mamba-ssm]
|
pip3 install -U -e .[flash-attn]
|
||||||
pip3 install -r requirements-tests.txt
|
pip3 install -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Run e2e tests
|
- name: Run e2e tests
|
||||||
|
|||||||
@@ -8,9 +8,6 @@ ignore_missing_imports = True
|
|||||||
[mypy-axolotl.monkeypatch.*]
|
[mypy-axolotl.monkeypatch.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
[mypy-axolotl.models.mixtral.*]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-axolotl.models.phi.*]
|
[mypy-axolotl.models.phi.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
|||||||
44
README.md
44
README.md
@@ -65,21 +65,19 @@ Features:
|
|||||||
|
|
||||||
## Axolotl supports
|
## Axolotl supports
|
||||||
|
|
||||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||||
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
|
||||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
|
||||||
|
|
||||||
|
|
||||||
## Quickstart ⚡
|
## Quickstart ⚡
|
||||||
@@ -247,7 +245,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"instruction": "...", "input": "...", "output": "..."}
|
{"instruction": "...", "input": "...", "output": "..."}
|
||||||
```
|
```
|
||||||
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
|
- `sharegpt`: conversations where `from` is `human`/`gpt`
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
@@ -614,12 +612,6 @@ eval_sample_packing:
|
|||||||
sample_packing_eff_est:
|
sample_packing_eff_est:
|
||||||
total_num_tokens:
|
total_num_tokens:
|
||||||
|
|
||||||
# Passed through to transformers when loading the model when launched without accelerate
|
|
||||||
# Use `sequential` when training w/ model parallelism to limit memory
|
|
||||||
device_map:
|
|
||||||
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
|
|
||||||
max_memory:
|
|
||||||
|
|
||||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||||
adapter: lora
|
adapter: lora
|
||||||
# If you already have a lora model trained that you want to load, put that here.
|
# If you already have a lora model trained that you want to load, put that here.
|
||||||
@@ -667,8 +659,7 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
|
|||||||
wandb_project: # Your wandb project name
|
wandb_project: # Your wandb project name
|
||||||
wandb_entity: # A wandb Team name if using a Team
|
wandb_entity: # A wandb Team name if using a Team
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name: # Set the name of your wandb run
|
wandb_run_id: # Set the name of your wandb run
|
||||||
wandb_run_id: # Set the ID of your wandb run
|
|
||||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||||
|
|
||||||
# Where to save the full-finetuned model to
|
# Where to save the full-finetuned model to
|
||||||
@@ -703,9 +694,6 @@ max_steps:
|
|||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
|
|
||||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
|
||||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
|
||||||
|
|
||||||
# Save model as safetensors (require safetensors package)
|
# Save model as safetensors (require safetensors package)
|
||||||
save_safetensors:
|
save_safetensors:
|
||||||
|
|
||||||
@@ -964,7 +952,7 @@ wandb_mode:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,16 @@
|
|||||||
"weight_decay": "auto"
|
"weight_decay": "auto"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupDecayLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": "auto",
|
||||||
|
"warmup_max_lr": "auto",
|
||||||
|
"warmup_num_steps": "auto",
|
||||||
|
"warmup_type": "linear",
|
||||||
|
"total_num_steps": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -28,6 +28,16 @@
|
|||||||
"weight_decay": "auto"
|
"weight_decay": "auto"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupDecayLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": "auto",
|
||||||
|
"warmup_max_lr": "auto",
|
||||||
|
"warmup_num_steps": "auto",
|
||||||
|
"warmup_type": "linear",
|
||||||
|
"total_num_steps": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -32,6 +32,16 @@
|
|||||||
"weight_decay": "auto"
|
"weight_decay": "auto"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupDecayLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": "auto",
|
||||||
|
"warmup_max_lr": "auto",
|
||||||
|
"warmup_num_steps": "auto",
|
||||||
|
"warmup_type": "linear",
|
||||||
|
"total_num_steps": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ FROM winglian/axolotl:$BASE_TAG
|
|||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
|
||||||
|
|
||||||
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
output_dir: btlm-out
|
output_dir: btlm-out
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./falcon-7b
|
output_dir: ./falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./falcon-7b
|
output_dir: ./falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./jeopardy-bot-7b
|
output_dir: ./jeopardy-bot-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ lora_target_linear:
|
|||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./model-out
|
output_dir: ./model-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ relora_cpu_offload: false
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
base_model: state-spaces/mamba-2.8b
|
|
||||||
model_type: MambaLMHeadModel
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
tokenizer_config: EleutherAI/gpt-neox-20b
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./out
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 2
|
|
||||||
optimizer: paged_adamw_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 5e-5
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: true
|
|
||||||
|
|
||||||
bf16: true
|
|
||||||
fp16: false
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: false
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention:
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
eval_steps:
|
|
||||||
eval_table_size:
|
|
||||||
eval_table_max_new_tokens: 128
|
|
||||||
save_steps: 0.25
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
tokens:
|
|
||||||
save_safetensors: False
|
|
||||||
@@ -21,7 +21,7 @@ pad_to_sequence_len: true
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -1,79 +0,0 @@
|
|||||||
base_model: DiscoResearch/mixtral-7b-8expert
|
|
||||||
model_type: MixtralForCausalLM
|
|
||||||
tokenizer_type: LlamaTokenizer
|
|
||||||
trust_remote_code: true
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: tatsu-lab/alpaca
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./qlora-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
#lora_target_modules:
|
|
||||||
# - gate
|
|
||||||
# - q_proj
|
|
||||||
# - k_proj
|
|
||||||
# - v_proj
|
|
||||||
# - o_proj
|
|
||||||
# - w1
|
|
||||||
# - w2
|
|
||||||
# - w3
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16: false
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
|
||||||
loss_watchdog_patience: 3
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
eval_steps:
|
|
||||||
eval_table_size:
|
|
||||||
eval_table_max_new_tokens: 128
|
|
||||||
save_steps:
|
|
||||||
debug:
|
|
||||||
deepspeed: deepspeed/zero2.json
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -38,7 +38,7 @@ lora_target_modules:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
@@ -62,9 +62,6 @@ logging_steps: 1
|
|||||||
xformers_attention:
|
xformers_attention:
|
||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
|
||||||
loss_watchdog_patience: 3
|
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 0.05
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project: mpt-alpaca-7b
|
wandb_project: mpt-alpaca-7b
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./mpt-alpaca-7b
|
output_dir: ./mpt-alpaca-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./openllama-out
|
output_dir: ./openllama-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./pythia-12b
|
output_dir: ./pythia-12b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-alpaca-pythia
|
output_dir: ./lora-alpaca-pythia
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
@@ -53,7 +53,7 @@ resume_from_checkpoint:
|
|||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention:
|
xformers_attention:
|
||||||
flash_attention:
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 0.05
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
@@ -53,7 +53,7 @@ resume_from_checkpoint:
|
|||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention:
|
xformers_attention:
|
||||||
flash_attention:
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 0.05
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project: redpajama-alpaca-3b
|
wandb_project: redpajama-alpaca-3b
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./redpajama-alpaca-3b
|
output_dir: ./redpajama-alpaca-3b
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project: lora-replit
|
wandb_project: lora-replit
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-replit
|
output_dir: ./lora-replit
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
auto-gptq==0.5.1
|
auto-gptq==0.5.1
|
||||||
packaging
|
packaging
|
||||||
peft==0.6.0
|
peft==0.6.0
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@df5c5c62ae253055336f5bb0828ca8e3e15ab6bd
|
transformers==4.35.2
|
||||||
tokenizers==0.15.0
|
tokenizers==0.15.0
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate==0.24.1
|
accelerate==0.24.1
|
||||||
|
|||||||
5
setup.py
5
setup.py
@@ -46,13 +46,10 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.3.3",
|
"flash-attn>=2.3.0",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed",
|
"deepspeed",
|
||||||
],
|
],
|
||||||
"mamba-ssm": [
|
|
||||||
"mamba-ssm==1.0.1",
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import prepare_optim_env
|
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -72,7 +71,7 @@ def do_merge_lora(
|
|||||||
|
|
||||||
LOG.info("running merge of LoRA with base model")
|
LOG.info("running merge of LoRA with base model")
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
model.to(dtype=cfg.torch_dtype)
|
model.to(dtype=torch.float16)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||||
@@ -297,8 +296,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
prepare_optim_env(cfg)
|
|
||||||
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
|
|||||||
@@ -25,16 +25,12 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
|||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
LossWatchDogCallback,
|
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators import (
|
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
|
||||||
MambaDataCollator,
|
|
||||||
)
|
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler
|
from axolotl.utils.samplers import MultipackBatchSampler
|
||||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||||
|
|
||||||
@@ -52,9 +48,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
Extend the base TrainingArguments for axolotl helpers
|
Extend the base TrainingArguments for axolotl helpers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_type: Optional[str] = field(
|
|
||||||
default=None, metadata={"help": "HF model configuration model_type."}
|
|
||||||
)
|
|
||||||
lr_quadratic_warmup: bool = field(
|
lr_quadratic_warmup: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||||
@@ -291,32 +284,6 @@ class AxolotlTrainer(Trainer):
|
|||||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
|
||||||
"""
|
|
||||||
Mamba specific trainer to handle loss calculation
|
|
||||||
"""
|
|
||||||
|
|
||||||
def compute_loss(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
inputs,
|
|
||||||
return_outputs=False, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
input_ids = inputs.pop("input_ids")
|
|
||||||
lm_logits = model(input_ids).logits
|
|
||||||
|
|
||||||
labels = input_ids.to(lm_logits.device)
|
|
||||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
|
||||||
labels = labels[:, 1:].contiguous()
|
|
||||||
|
|
||||||
loss_fct = torch.nn.CrossEntropyLoss()
|
|
||||||
lm_loss = loss_fct(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
return lm_loss
|
|
||||||
|
|
||||||
|
|
||||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
Trainer subclass that uses the OneCycleLR scheduler
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
@@ -463,9 +430,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.loss_watchdog_threshold is not None:
|
|
||||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
@@ -494,8 +458,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return OneCycleLRSchedulerTrainer
|
return OneCycleLRSchedulerTrainer
|
||||||
if self.cfg.relora_steps:
|
if self.cfg.relora_steps:
|
||||||
return ReLoRATrainer
|
return ReLoRATrainer
|
||||||
if self.cfg.model_config_type == "mamba":
|
|
||||||
return AxolotlMambaTrainer
|
|
||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
@@ -563,7 +525,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.hub_strategy:
|
if self.cfg.hub_strategy:
|
||||||
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||||
|
|
||||||
if self.cfg.save_safetensors is not None:
|
if self.cfg.save_safetensors:
|
||||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||||
|
|
||||||
if self.cfg.sample_packing_eff_est:
|
if self.cfg.sample_packing_eff_est:
|
||||||
@@ -681,7 +643,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
||||||
training_arguments_kwargs["run_name"] = (
|
training_arguments_kwargs["run_name"] = (
|
||||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
self.cfg.wandb_run_id if self.cfg.use_wandb else None
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["optim"] = (
|
training_arguments_kwargs["optim"] = (
|
||||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||||
@@ -711,7 +673,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||||
training_arguments_kwargs
|
training_arguments_kwargs
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
|
||||||
training_args = (
|
training_args = (
|
||||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
**training_arguments_kwargs,
|
**training_arguments_kwargs,
|
||||||
@@ -766,7 +727,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
eval_dataset=self.eval_dataset,
|
eval_dataset=self.eval_dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
data_collator=self.build_collator(**data_collator_kwargs),
|
data_collator=BatchSamplerDataCollatorForSeq2Seq(
|
||||||
|
self.tokenizer,
|
||||||
|
return_tensors="pt",
|
||||||
|
**data_collator_kwargs,
|
||||||
|
),
|
||||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
@@ -786,13 +751,3 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
] = self.cfg.micro_batch_size
|
] = self.cfg.micro_batch_size
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
def build_collator(self, **kwargs):
|
|
||||||
if self.cfg.model_config_type == "mamba":
|
|
||||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
|
||||||
|
|
||||||
return BatchSamplerDataCollatorForSeq2Seq(
|
|
||||||
self.tokenizer,
|
|
||||||
return_tensors="pt",
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
"""
|
|
||||||
Modeling module for Mamba models
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def fix_mamba_attn_for_loss():
|
|
||||||
from mamba_ssm.models import mixer_seq_simple
|
|
||||||
|
|
||||||
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|
|
||||||
|
|
||||||
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
|
|
||||||
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
"""
|
|
||||||
HF Transformers MambaConfig
|
|
||||||
"""
|
|
||||||
from transformers import PretrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
class MambaConfig(PretrainedConfig):
|
|
||||||
"""
|
|
||||||
modeling configuration for state space model/mamba
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_type = "mamba"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size=50280,
|
|
||||||
d_model=2560,
|
|
||||||
n_layer=64,
|
|
||||||
rms_norm=True,
|
|
||||||
residual_in_fp32=True,
|
|
||||||
fused_add_norm=True,
|
|
||||||
pad_vocab_size_multiple=8,
|
|
||||||
pad_token_id=50277,
|
|
||||||
bos_token_id=0,
|
|
||||||
eos_token_id=0,
|
|
||||||
tie_word_embeddings=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.d_model = d_model
|
|
||||||
self.n_layer = n_layer
|
|
||||||
self.rms_norm = rms_norm
|
|
||||||
self.residual_in_fp32 = residual_in_fp32
|
|
||||||
self.fused_add_norm = fused_add_norm
|
|
||||||
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
|
||||||
super().__init__(
|
|
||||||
pad_token_id=pad_token_id,
|
|
||||||
bos_token_id=bos_token_id,
|
|
||||||
eos_token_id=eos_token_id,
|
|
||||||
tie_word_embeddings=tie_word_embeddings,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
import os
|
|
||||||
from collections import namedtuple
|
|
||||||
from functools import partial
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
|
|
||||||
from mamba_ssm.utils.generation import GenerationMixin
|
|
||||||
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import CrossEntropyLoss
|
|
||||||
|
|
||||||
from axolotl.models.mamba.configuration_mamba import MambaConfig
|
|
||||||
|
|
||||||
|
|
||||||
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model: int,
|
|
||||||
n_layer: int,
|
|
||||||
vocab_size: int,
|
|
||||||
initializer_cfg=None,
|
|
||||||
pad_vocab_size_multiple: int = 1,
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
**backbone_kwargs,
|
|
||||||
) -> None:
|
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
|
||||||
super().__init__()
|
|
||||||
if vocab_size % pad_vocab_size_multiple != 0:
|
|
||||||
vocab_size += pad_vocab_size_multiple - (
|
|
||||||
vocab_size % pad_vocab_size_multiple
|
|
||||||
)
|
|
||||||
self.config = MambaConfig(
|
|
||||||
vocab_size=vocab_size,
|
|
||||||
d_model=d_model,
|
|
||||||
n_layer=n_layer,
|
|
||||||
pad_vocab_size_multiple=pad_vocab_size_multiple,
|
|
||||||
)
|
|
||||||
self.backbone = MixerModel(
|
|
||||||
d_model=d_model,
|
|
||||||
n_layer=n_layer,
|
|
||||||
vocab_size=vocab_size,
|
|
||||||
initializer_cfg=initializer_cfg,
|
|
||||||
**backbone_kwargs,
|
|
||||||
**factory_kwargs,
|
|
||||||
)
|
|
||||||
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.apply(
|
|
||||||
partial(
|
|
||||||
_init_weights,
|
|
||||||
n_layer=n_layer,
|
|
||||||
**(initializer_cfg if initializer_cfg is not None else {}),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.tie_weights()
|
|
||||||
|
|
||||||
def tie_weights(self):
|
|
||||||
self.lm_head.weight = self.backbone.embedding.weight
|
|
||||||
|
|
||||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
|
||||||
return self.backbone.allocate_inference_cache(
|
|
||||||
batch_size, max_seqlen, dtype=dtype, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids,
|
|
||||||
position_ids=None,
|
|
||||||
inference_params=None,
|
|
||||||
num_last_tokens=0,
|
|
||||||
labels=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
|
||||||
num_last_tokens: if > 0, only return the logits for the last n tokens
|
|
||||||
"""
|
|
||||||
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
|
||||||
if num_last_tokens > 0:
|
|
||||||
hidden_states = hidden_states[:, -num_last_tokens:]
|
|
||||||
lm_logits = self.lm_head(hidden_states)
|
|
||||||
|
|
||||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
|
||||||
return CausalLMOutput(logits=lm_logits)
|
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
logits = lm_logits
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
||||||
shift_labels = shift_labels.view(-1)
|
|
||||||
# Enable model parallelism
|
|
||||||
shift_labels = shift_labels.to(shift_logits.device)
|
|
||||||
loss = loss_fct(shift_logits, shift_labels)
|
|
||||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
|
|
||||||
print(loss)
|
|
||||||
return CausalLMOutput(logits=lm_logits, loss=loss)
|
|
||||||
|
|
||||||
else:
|
|
||||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
|
||||||
return CausalLMOutput(logits=lm_logits)
|
|
||||||
|
|
||||||
def save_pretrained(
|
|
||||||
self,
|
|
||||||
save_directory: Union[str, os.PathLike],
|
|
||||||
state_dict: Optional[dict] = None,
|
|
||||||
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
if state_dict is None:
|
|
||||||
state_dict = self.state_dict()
|
|
||||||
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
|
||||||
config = load_config_hf(pretrained_model_name)
|
|
||||||
model = cls(**config, device=device, dtype=dtype, **kwargs)
|
|
||||||
model.load_state_dict(
|
|
||||||
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
"""
|
|
||||||
Custom modeling code for mixtral
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .configuration_moe_mistral import MixtralConfig # noqa
|
|
||||||
from .modeling_moe_mistral import ( # noqa
|
|
||||||
MixtralForCausalLM,
|
|
||||||
replace_mixtral_mlp_with_swiglu,
|
|
||||||
)
|
|
||||||
@@ -1,154 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
""" Mistral model configuration"""
|
|
||||||
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
|
||||||
"mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
|
|
||||||
"mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MixtralConfig(PretrainedConfig):
|
|
||||||
r"""
|
|
||||||
This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
|
|
||||||
Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
|
||||||
with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
|
|
||||||
|
|
||||||
[mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
|
|
||||||
[mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
|
||||||
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vocab_size (`int`, *optional*, defaults to 32000):
|
|
||||||
Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
|
|
||||||
`inputs_ids` passed when calling [`MistralModel`]
|
|
||||||
hidden_size (`int`, *optional*, defaults to 4096):
|
|
||||||
Dimension of the hidden representations.
|
|
||||||
intermediate_size (`int`, *optional*, defaults to 14336):
|
|
||||||
Dimension of the MLP representations.
|
|
||||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
|
||||||
Number of hidden layers in the Transformer encoder.
|
|
||||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
|
||||||
Number of attention heads for each attention layer in the Transformer encoder.
|
|
||||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
|
||||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
|
||||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
|
||||||
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
|
||||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
|
||||||
by meanpooling all the original heads within that group. For more details checkout [this
|
|
||||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
|
|
||||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
|
||||||
The non-linear activation function (function or string) in the decoder.
|
|
||||||
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
|
|
||||||
The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
|
|
||||||
allows sequence of up to 4096*32 tokens.
|
|
||||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
|
||||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
|
||||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
|
||||||
The epsilon used by the rms normalization layers.
|
|
||||||
use_cache (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
|
||||||
relevant if `config.is_decoder=True`.
|
|
||||||
pad_token_id (`int`, *optional*):
|
|
||||||
The id of the padding token.
|
|
||||||
bos_token_id (`int`, *optional*, defaults to 1):
|
|
||||||
The id of the "beginning-of-sequence" token.
|
|
||||||
eos_token_id (`int`, *optional*, defaults to 2):
|
|
||||||
The id of the "end-of-sequence" token.
|
|
||||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether the model's input and output word embeddings should be tied.
|
|
||||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
|
||||||
The base period of the RoPE embeddings.
|
|
||||||
sliding_window (`int`, *optional*, defaults to 4096):
|
|
||||||
Sliding window attention window size. If not specified, will default to `4096`.
|
|
||||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
|
||||||
The dropout ratio for the attention probabilities.
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import MistralModel, MistralConfig
|
|
||||||
|
|
||||||
>>> # Initializing a Mistral 7B style configuration
|
|
||||||
>>> configuration = MixtralConfig()
|
|
||||||
|
|
||||||
>>> # Initializing a model from the Mistral 7B style configuration
|
|
||||||
>>> model = MixtralModel(configuration)
|
|
||||||
|
|
||||||
>>> # Accessing the model configuration
|
|
||||||
>>> configuration = model.config
|
|
||||||
```"""
|
|
||||||
|
|
||||||
model_type = "mistral"
|
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size=32000,
|
|
||||||
hidden_size=4096,
|
|
||||||
intermediate_size=14336,
|
|
||||||
num_hidden_layers=32,
|
|
||||||
num_attention_heads=32,
|
|
||||||
num_key_value_heads=8,
|
|
||||||
hidden_act="silu",
|
|
||||||
max_position_embeddings=4096 * 32,
|
|
||||||
initializer_range=0.02,
|
|
||||||
rms_norm_eps=1e-6,
|
|
||||||
use_cache=True,
|
|
||||||
pad_token_id=None,
|
|
||||||
bos_token_id=1,
|
|
||||||
eos_token_id=2,
|
|
||||||
tie_word_embeddings=False,
|
|
||||||
rope_theta=10000.0,
|
|
||||||
attention_dropout=0.0,
|
|
||||||
num_experts_per_token=2,
|
|
||||||
num_experts=8,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.num_hidden_layers = num_hidden_layers
|
|
||||||
self.num_attention_heads = num_attention_heads
|
|
||||||
|
|
||||||
# for backward compatibility
|
|
||||||
if num_key_value_heads is None:
|
|
||||||
num_key_value_heads = num_attention_heads
|
|
||||||
|
|
||||||
self.num_key_value_heads = num_key_value_heads
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
self.rms_norm_eps = rms_norm_eps
|
|
||||||
self.use_cache = use_cache
|
|
||||||
self.rope_theta = rope_theta
|
|
||||||
self.attention_dropout = attention_dropout
|
|
||||||
self.num_experts = num_experts
|
|
||||||
self.num_experts_per_token = num_experts_per_token
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
super().__init__(
|
|
||||||
pad_token_id=pad_token_id,
|
|
||||||
bos_token_id=bos_token_id,
|
|
||||||
eos_token_id=eos_token_id,
|
|
||||||
tie_word_embeddings=tie_word_embeddings,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -13,7 +13,7 @@ register_conv_template(
|
|||||||
system_message="You are a helpful assistant.",
|
system_message="You are a helpful assistant.",
|
||||||
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
||||||
sep_style=SeparatorStyle.CHATML,
|
sep_style=SeparatorStyle.CHATML,
|
||||||
sep="<|im_end|>",
|
sep="<|im_end|>\n",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -82,8 +82,7 @@ def train(
|
|||||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(model, "config"):
|
model.config.use_cache = False
|
||||||
model.config.use_cache = False
|
|
||||||
|
|
||||||
# go ahead and presave, so we have the adapter config available to inspect
|
# go ahead and presave, so we have the adapter config available to inspect
|
||||||
if peft_config:
|
if peft_config:
|
||||||
@@ -93,8 +92,7 @@ def train(
|
|||||||
if not Path(cfg.output_dir).is_dir():
|
if not Path(cfg.output_dir).is_dir():
|
||||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
||||||
if hasattr(model, "config"):
|
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
|
||||||
|
|
||||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
|
|||||||
@@ -124,36 +124,6 @@ class GPUStatsCallback(
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class LossWatchDogCallback(TrainerCallback):
|
|
||||||
"""Callback to track loss and stop training if loss is too high"""
|
|
||||||
|
|
||||||
def __init__(self, cfg):
|
|
||||||
self.cfg = cfg
|
|
||||||
self.logged = False
|
|
||||||
self.violations = 0
|
|
||||||
self.threshold = cfg.loss_watchdog_threshold
|
|
||||||
self.patience = cfg.loss_watchdog_patience or 3
|
|
||||||
|
|
||||||
def on_step_end(
|
|
||||||
self,
|
|
||||||
_args: TrainingArguments,
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
**_kwargs,
|
|
||||||
):
|
|
||||||
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
|
|
||||||
if state.log_history[-1]["loss"] > self.threshold:
|
|
||||||
self.violations += 1
|
|
||||||
if self.violations >= self.patience:
|
|
||||||
LOG.warning(
|
|
||||||
"Loss is too high, stopping training (loss_watchdog_threshold)"
|
|
||||||
)
|
|
||||||
control.should_training_stop = True
|
|
||||||
else:
|
|
||||||
self.violations = 0
|
|
||||||
return control
|
|
||||||
|
|
||||||
|
|
||||||
def bench_eval_callback_factory(trainer, tokenizer):
|
def bench_eval_callback_factory(trainer, tokenizer):
|
||||||
accuracy = evaluate.load("accuracy")
|
accuracy = evaluate.load("accuracy")
|
||||||
abcd_idx = [
|
abcd_idx = [
|
||||||
|
|||||||
@@ -2,16 +2,12 @@
|
|||||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||||
"""
|
"""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Sequence, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -150,31 +146,3 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
chunked_data[feature] = np.concatenate(arrays)
|
chunked_data[feature] = np.concatenate(arrays)
|
||||||
features = [chunked_data]
|
features = [chunked_data]
|
||||||
return super().__call__(features, return_tensors=return_tensors)
|
return super().__call__(features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MambaDataCollator:
|
|
||||||
"""
|
|
||||||
Collator for State Space Models (Mamba)
|
|
||||||
"""
|
|
||||||
|
|
||||||
tokenizer: transformers.PreTrainedTokenizer
|
|
||||||
|
|
||||||
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
|
||||||
input_ids, labels = tuple(
|
|
||||||
[torch.LongTensor(instance[key]) for instance in instances]
|
|
||||||
for key in ("input_ids", "labels")
|
|
||||||
)
|
|
||||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
|
||||||
input_ids,
|
|
||||||
batch_first=True,
|
|
||||||
padding_value=self.tokenizer.pad_token_id,
|
|
||||||
)
|
|
||||||
labels = torch.nn.utils.rnn.pad_sequence(
|
|
||||||
labels, batch_first=True, padding_value=IGNORE_INDEX
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"labels": labels,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def choose_device(cfg):
|
|||||||
|
|
||||||
cfg.device = get_device()
|
cfg.device = get_device()
|
||||||
if cfg.world_size == 1:
|
if cfg.world_size == 1:
|
||||||
cfg.device_map = cfg.device_map or "auto"
|
cfg.device_map = "auto"
|
||||||
else:
|
else:
|
||||||
if cfg.device.startswith("cuda"):
|
if cfg.device.startswith("cuda"):
|
||||||
cfg.device_map = {"": torch.cuda.current_device()}
|
cfg.device_map = {"": torch.cuda.current_device()}
|
||||||
@@ -397,13 +397,6 @@ def validate_config(cfg):
|
|||||||
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
|
||||||
cfg.wandb_name = cfg.wandb_run_id
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -698,6 +698,24 @@ def get_dataset_wrapper(
|
|||||||
return dataset_wrapper, dataset_prompter
|
return dataset_wrapper, dataset_prompter
|
||||||
|
|
||||||
|
|
||||||
|
def encode_packed_pretraining(
|
||||||
|
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
|
||||||
|
):
|
||||||
|
# tokenize all the examples
|
||||||
|
# rows get split with stride (overlap)
|
||||||
|
res = tokenizer(
|
||||||
|
examples,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_tokens,
|
||||||
|
add_special_tokens=True,
|
||||||
|
return_overflowing_tokens=True,
|
||||||
|
stride=256,
|
||||||
|
)
|
||||||
|
# convert to a dataset.from_list
|
||||||
|
# use a dataloader and multipack batch sampler to pack the data
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def encode_pretraining(
|
def encode_pretraining(
|
||||||
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
|
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
|
||||||
) -> Dict[str, List]:
|
) -> Dict[str, List]:
|
||||||
@@ -813,6 +831,7 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
|
|||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
encode,
|
encode,
|
||||||
batched=True,
|
batched=True,
|
||||||
|
batch_size=10_000,
|
||||||
input_columns="text",
|
input_columns="text",
|
||||||
# remove all the existing columns after mapping since they end up having
|
# remove all the existing columns after mapping since they end up having
|
||||||
# a different length than the encoded/tokenized column
|
# a different length than the encoded/tokenized column
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import math
|
|||||||
import os
|
import os
|
||||||
from typing import Optional, Tuple # noqa: F401
|
from typing import Optional, Tuple # noqa: F401
|
||||||
|
|
||||||
import addict
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@@ -22,7 +21,6 @@ from transformers import ( # noqa: F401
|
|||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -30,56 +28,16 @@ from axolotl.utils.dict import DictDefault
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
|
||||||
quant_config_exists = hasattr(model_config, "quantization_config")
|
|
||||||
quant_config_method_is_gptq = (
|
|
||||||
quant_config_exists
|
|
||||||
and "quant_method" in model_config.quantization_config
|
|
||||||
and model_config.quantization_config["quant_method"] == "gptq"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.gptq and not quant_config_method_is_gptq:
|
|
||||||
raise ValueError(
|
|
||||||
"model_config.quantization_config is not set or quant_method is not set to gptq. "
|
|
||||||
"Please make sure to point to a GPTQ model."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not cfg.gptq and quant_config_exists:
|
|
||||||
raise ValueError(
|
|
||||||
"model_config.quantization_config is set but `gptq` flag is not. "
|
|
||||||
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(cfg):
|
def load_model_config(cfg):
|
||||||
model_config_name = cfg.base_model_config or cfg.base_model
|
model_config_name = cfg.base_model_config or cfg.base_model
|
||||||
trust_remote_code = cfg.trust_remote_code is True
|
trust_remote_code = cfg.trust_remote_code is True
|
||||||
model_type = cfg.model_type
|
model_config = AutoConfig.from_pretrained(
|
||||||
|
model_config_name, trust_remote_code=trust_remote_code
|
||||||
if model_type == "MixtralForCausalLM":
|
)
|
||||||
from axolotl.models.mixtral.configuration_moe_mistral import MixtralConfig
|
|
||||||
|
|
||||||
model_config = MixtralConfig.from_pretrained(model_config_name)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
model_config = AutoConfig.from_pretrained(
|
|
||||||
model_config_name, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
except ValueError as err:
|
|
||||||
if "mamba" in model_config_name:
|
|
||||||
return addict.Dict(
|
|
||||||
{
|
|
||||||
"model_type": "mamba",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
raise err
|
|
||||||
|
|
||||||
if cfg.model_config:
|
if cfg.model_config:
|
||||||
for key, val in cfg.model_config.items():
|
for key, val in cfg.model_config.items():
|
||||||
setattr(model_config, key, val)
|
setattr(model_config, key, val)
|
||||||
|
|
||||||
check_model_config(cfg, model_config)
|
|
||||||
|
|
||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
|
||||||
@@ -111,7 +69,6 @@ def load_tokenizer(cfg):
|
|||||||
"LlamaTokenizer",
|
"LlamaTokenizer",
|
||||||
"LlamaTokenizerFast",
|
"LlamaTokenizerFast",
|
||||||
"CodeLlamaTokenizer",
|
"CodeLlamaTokenizer",
|
||||||
"CodeLlamaTokenizerFast",
|
|
||||||
]
|
]
|
||||||
and hasattr(tokenizer, "pad_token")
|
and hasattr(tokenizer, "pad_token")
|
||||||
and not tokenizer.pad_token
|
and not tokenizer.pad_token
|
||||||
@@ -144,23 +101,6 @@ def load_tokenizer(cfg):
|
|||||||
tokenizer.add_special_tokens(
|
tokenizer.add_special_tokens(
|
||||||
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||||
)
|
)
|
||||||
|
|
||||||
# If we add bos_token and eos_token, we need to update the post processor to
|
|
||||||
# handle them correctly.
|
|
||||||
# https://github.com/huggingface/transformers/pull/24132
|
|
||||||
bos_or_eos_in_special_tokens = (
|
|
||||||
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
tokenizer.__class__.__name__
|
|
||||||
in (
|
|
||||||
"LlamaTokenizerFast",
|
|
||||||
"CodeLlamaTokenizerFast",
|
|
||||||
)
|
|
||||||
and bos_or_eos_in_special_tokens
|
|
||||||
):
|
|
||||||
tokenizer.update_post_processor()
|
|
||||||
|
|
||||||
if cfg.tokens:
|
if cfg.tokens:
|
||||||
tokenizer.add_tokens(
|
tokenizer.add_tokens(
|
||||||
[
|
[
|
||||||
@@ -276,7 +216,6 @@ def load_model(
|
|||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
|
|
||||||
model_kwargs["device_map"] = cfg.device_map
|
model_kwargs["device_map"] = cfg.device_map
|
||||||
model_kwargs["max_memory"] = cfg.max_memory
|
|
||||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||||
|
|
||||||
if cfg.model_revision:
|
if cfg.model_revision:
|
||||||
@@ -308,9 +247,7 @@ def load_model(
|
|||||||
or cfg.is_falcon_derived_model
|
or cfg.is_falcon_derived_model
|
||||||
or cfg.is_mistral_derived_model
|
or cfg.is_mistral_derived_model
|
||||||
):
|
):
|
||||||
# TODO enable once properly supported in transformers
|
model_kwargs["use_flash_attention_2"] = True
|
||||||
# model_kwargs["attn_implementation"] = "flash_attention_2"
|
|
||||||
model_kwargs["use_flash_attention_2"] = True # legacy, to be deprecated
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||||
@@ -372,37 +309,6 @@ def load_model(
|
|||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif model_type == "MixtralForCausalLM":
|
|
||||||
from axolotl.models.mixtral import (
|
|
||||||
MixtralForCausalLM,
|
|
||||||
replace_mixtral_mlp_with_swiglu,
|
|
||||||
)
|
|
||||||
|
|
||||||
model = MixtralForCausalLM.from_pretrained(
|
|
||||||
base_model,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.flash_attn_fuse_mlp:
|
|
||||||
LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
|
|
||||||
replace_mixtral_mlp_with_swiglu(model)
|
|
||||||
|
|
||||||
elif model_type == "MambaLMHeadModel":
|
|
||||||
# FIXME this is janky at best and hacked together to make it work
|
|
||||||
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
model_kwargs["dtype"] = model_kwargs["torch_dtype"]
|
|
||||||
model_kwargs["device"] = torch.cuda.current_device()
|
|
||||||
del model_kwargs["torch_dtype"]
|
|
||||||
del model_kwargs["device_map"]
|
|
||||||
del model_kwargs["max_memory"]
|
|
||||||
|
|
||||||
model = MambaLMHeadModel.from_pretrained(
|
|
||||||
base_model,
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
elif model_type and not cfg.trust_remote_code:
|
elif model_type and not cfg.trust_remote_code:
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
@@ -462,17 +368,13 @@ def load_model(
|
|||||||
if cfg.resize_token_embeddings_to_32x
|
if cfg.resize_token_embeddings_to_32x
|
||||||
else len(tokenizer)
|
else len(tokenizer)
|
||||||
)
|
)
|
||||||
if (
|
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
||||||
hasattr(model, "get_input_embeddings")
|
|
||||||
and model.get_input_embeddings().num_embeddings < embeddings_len
|
|
||||||
):
|
|
||||||
model.resize_token_embeddings(embeddings_len)
|
model.resize_token_embeddings(embeddings_len)
|
||||||
else:
|
else:
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model, "config")
|
hasattr(model.config, "max_position_embeddings")
|
||||||
and hasattr(model.config, "max_position_embeddings")
|
|
||||||
and model.config.max_position_embeddings
|
and model.config.max_position_embeddings
|
||||||
and cfg.sequence_len > model.config.max_position_embeddings
|
and cfg.sequence_len > model.config.max_position_embeddings
|
||||||
):
|
):
|
||||||
@@ -482,22 +384,20 @@ def load_model(
|
|||||||
model.config.max_position_embeddings = cfg.sequence_len
|
model.config.max_position_embeddings = cfg.sequence_len
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model, "config")
|
hasattr(model.config, "bos_token_id")
|
||||||
and hasattr(model.config, "bos_token_id")
|
|
||||||
and model.config.bos_token_id
|
and model.config.bos_token_id
|
||||||
and model.config.bos_token_id != tokenizer.bos_token_id
|
and model.config.bos_token_id != tokenizer.bos_token_id
|
||||||
):
|
):
|
||||||
model.config.bos_token_id = tokenizer.bos_token_id
|
model.config.bos_token_id = tokenizer.bos_token_id
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model, "config")
|
hasattr(model.config, "eos_token_id")
|
||||||
and hasattr(model.config, "eos_token_id")
|
|
||||||
and model.config.eos_token_id
|
and model.config.eos_token_id
|
||||||
and model.config.eos_token_id != tokenizer.eos_token_id
|
and model.config.eos_token_id != tokenizer.eos_token_id
|
||||||
):
|
):
|
||||||
model.config.eos_token_id = tokenizer.eos_token_id
|
model.config.eos_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
if hasattr(model, "device") and model.device.type == "cuda":
|
if model.device.type == "cuda":
|
||||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||||
|
|
||||||
# make sure these are fp32 per Ramesh et al. (2021)
|
# make sure these are fp32 per Ramesh et al. (2021)
|
||||||
@@ -512,22 +412,15 @@ def load_model(
|
|||||||
module.to(torch.float32)
|
module.to(torch.float32)
|
||||||
|
|
||||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||||
skip_prepare_model_for_kbit_training = False
|
|
||||||
|
|
||||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
|
||||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
|
||||||
skip_prepare_model_for_kbit_training = True
|
|
||||||
|
|
||||||
if (cfg.adapter == "lora" and load_in_8bit) or (
|
if (cfg.adapter == "lora" and load_in_8bit) or (
|
||||||
cfg.adapter == "qlora" and cfg.load_in_4bit
|
cfg.adapter == "qlora" and cfg.load_in_4bit
|
||||||
):
|
):
|
||||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
if not skip_prepare_model_for_kbit_training:
|
model = prepare_model_for_kbit_training(
|
||||||
model = prepare_model_for_kbit_training(
|
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
)
|
||||||
)
|
|
||||||
needs_fa2_dtype = True
|
needs_fa2_dtype = True
|
||||||
|
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||||
@@ -556,8 +449,7 @@ def load_model(
|
|||||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||||
if len(requires_grad) == 0:
|
if len(requires_grad) == 0:
|
||||||
LOG.warning("there are no parameters that require gradient updates")
|
LOG.warning("there are no parameters that require gradient updates")
|
||||||
if hasattr(model, "config"):
|
model.config.use_cache = False
|
||||||
model.config.use_cache = False
|
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.transform(model)
|
model = BetterTransformer.transform(model)
|
||||||
|
|||||||
@@ -131,10 +131,8 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Phi doesn't want the attention_mask feature when training
|
# Phi doesn't want the attention_mask feature when training
|
||||||
if (
|
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
|
||||||
"CodeGenTokenizer" in tokenizer.__class__.__name__
|
cfg.is_mistral_derived_model and cfg.flash_attention
|
||||||
or (cfg.is_mistral_derived_model and cfg.flash_attention)
|
|
||||||
or cfg.model_config_type == "mamba"
|
|
||||||
):
|
):
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
@@ -155,9 +153,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
if update:
|
if update:
|
||||||
cfg.total_num_tokens = total_num_tokens
|
cfg.total_num_tokens = total_num_tokens
|
||||||
|
|
||||||
skip_estimates = cfg.model_config_type == "mamba"
|
if not cfg.total_supervised_tokens:
|
||||||
|
|
||||||
if not skip_estimates and not cfg.total_supervised_tokens:
|
|
||||||
total_supervised_tokens = (
|
total_supervised_tokens = (
|
||||||
train_dataset.data.column("labels")
|
train_dataset.data.column("labels")
|
||||||
.to_pandas()
|
.to_pandas()
|
||||||
@@ -171,7 +167,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
if update:
|
if update:
|
||||||
cfg.total_supervised_tokens = total_supervised_tokens
|
cfg.total_supervised_tokens = total_supervised_tokens
|
||||||
|
|
||||||
if not skip_estimates and cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
# we have to drop anything longer then sequence len otherwise
|
# we have to drop anything longer then sequence len otherwise
|
||||||
# flash attention with position ids fails
|
# flash attention with position ids fails
|
||||||
|
|
||||||
@@ -271,14 +267,12 @@ def setup_fsdp_envs(cfg):
|
|||||||
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||||
|
|
||||||
|
|
||||||
def prepare_optim_env(cfg):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
|
||||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
||||||
trainer_builder.train_dataset = train_dataset
|
trainer_builder.train_dataset = train_dataset
|
||||||
trainer_builder.eval_dataset = eval_dataset
|
trainer_builder.eval_dataset = eval_dataset
|
||||||
|
|||||||
@@ -2,20 +2,20 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
def setup_wandb_env_vars(cfg):
|
||||||
def setup_wandb_env_vars(cfg: DictDefault):
|
if cfg.wandb_mode and cfg.wandb_mode == "offline":
|
||||||
for key in cfg.keys():
|
os.environ["WANDB_MODE"] = cfg.wandb_mode
|
||||||
if key.startswith("wandb_"):
|
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||||
value = cfg.get(key, "")
|
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
||||||
|
|
||||||
if value and isinstance(value, str) and len(value) > 0:
|
|
||||||
os.environ[key.upper()] = value
|
|
||||||
|
|
||||||
# Enable wandb if project name is present
|
|
||||||
if cfg.wandb_project and len(cfg.wandb_project) > 0:
|
|
||||||
cfg.use_wandb = True
|
cfg.use_wandb = True
|
||||||
os.environ.pop("WANDB_DISABLED", None) # Remove if present
|
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
|
||||||
|
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
|
||||||
|
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
||||||
|
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
||||||
|
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
||||||
|
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
||||||
|
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
||||||
|
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
||||||
else:
|
else:
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|||||||
@@ -1,65 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E tests for lora llama
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from .utils import with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMistral(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Test case for Llama models using LoRA
|
|
||||||
"""
|
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_fft(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "state-spaces/mamba-130m",
|
|
||||||
"model_type": "MambaLMHeadModel",
|
|
||||||
"tokenizer_type": "AutoTokenizer",
|
|
||||||
"tokenizer_config": "EleutherAI/gpt-neox-20b",
|
|
||||||
"flash_attention": False,
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"load_in_8bit": False,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"gradient_checkpointing": False,
|
|
||||||
"num_epochs": 2,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 20,
|
|
||||||
"save_steps": 10,
|
|
||||||
"eval_steps": None,
|
|
||||||
"save_safetensors": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Module for testing the validation module"""
|
"""Module for testing the validation module"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -9,7 +8,6 @@ import pytest
|
|||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationTest(unittest.TestCase):
|
class ValidationTest(unittest.TestCase):
|
||||||
@@ -681,83 +679,3 @@ class ValidationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
|
||||||
class ValidationWandbTest(ValidationTest):
|
|
||||||
"""
|
|
||||||
Validation test for wandb
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_wandb_set_run_id_to_name(self):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"wandb_run_id": "foo",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
validate_config(cfg)
|
|
||||||
assert any(
|
|
||||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
|
||||||
in record.message
|
|
||||||
for record in self._caplog.records
|
|
||||||
)
|
|
||||||
|
|
||||||
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"wandb_name": "foo",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
|
|
||||||
|
|
||||||
def test_wandb_sets_env(self):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"wandb_project": "foo",
|
|
||||||
"wandb_name": "bar",
|
|
||||||
"wandb_run_id": "bat",
|
|
||||||
"wandb_entity": "baz",
|
|
||||||
"wandb_mode": "online",
|
|
||||||
"wandb_watch": "false",
|
|
||||||
"wandb_log_model": "checkpoint",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
|
||||||
|
|
||||||
assert os.environ.get("WANDB_PROJECT", "") == "foo"
|
|
||||||
assert os.environ.get("WANDB_NAME", "") == "bar"
|
|
||||||
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
|
|
||||||
assert os.environ.get("WANDB_ENTITY", "") == "baz"
|
|
||||||
assert os.environ.get("WANDB_MODE", "") == "online"
|
|
||||||
assert os.environ.get("WANDB_WATCH", "") == "false"
|
|
||||||
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
|
|
||||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
|
||||||
|
|
||||||
def test_wandb_set_disabled(self):
|
|
||||||
cfg = DictDefault({})
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
|
||||||
|
|
||||||
assert os.environ.get("WANDB_DISABLED", "") == "true"
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"wandb_project": "foo",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
|
||||||
|
|
||||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
|
||||||
|
|||||||
Reference in New Issue
Block a user