Compare commits

..

1 Commits

Author SHA1 Message Date
mhenrichsen
9084879861 tinyllama 2023-11-16 13:36:01 +00:00
69 changed files with 167 additions and 3895 deletions

View File

@@ -71,9 +71,8 @@ jobs:
- name: Install dependencies
run: |
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
pip3 uninstall -y transformers accelerate
pip3 install -U -e .[flash-attn,mamba-ssm]
pip3 install -U -e .[flash-attn]
pip3 install -r requirements-tests.txt
- name: Run e2e tests

View File

@@ -8,9 +8,6 @@ ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*]
ignore_errors = True
[mypy-axolotl.models.mixtral.*]
ignore_errors = True
[mypy-axolotl.models.phi.*]
ignore_errors = True

View File

@@ -65,21 +65,18 @@ Features:
## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | | | | |
| Mixtral-MoE | ✅ | ✅ | ✅ | | | | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | | | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | | | ❌ | ❌ | | ❓ |
| falcon | ✅ | | ✅ | | | | |
| gpt-j | ✅ | ✅ | ✅ | | | ❓ | ❓ |
| XGen | ✅ | ❓ | | ❓ | ❓ | ❓ | |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ✅ | | | | |
| cerebras | ✅ | ✅ | ✅ | | | | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | | | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | | | ❌ | ❌ | | ❓ |
| XGen | ✅ | | ✅ | | | | |
| phi | ✅ | ✅ | ✅ | | | ❓ | ❓ |
| RWKV | ✅ | ❓ | | ❓ | ❓ | ❓ | |
## Quickstart ⚡
@@ -88,19 +85,14 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
**Requirements**: Python >=3.9 and Pytorch >=2.0.
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
### For developers
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]'
```
pip3 install -U git+https://github.com/huggingface/peft.git
### Usage
```bash
# finetune lora
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
@@ -247,7 +239,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
- `sharegpt`: conversations where `from` is `human`/`gpt`
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -502,7 +494,6 @@ is_falcon_derived_model:
is_llama_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model:
is_qwen_derived_model:
# optional overrides to the base model configuration
model_config:
@@ -547,8 +538,6 @@ datasets:
# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Custom user prompt
- path: repo
@@ -614,12 +603,6 @@ eval_sample_packing:
sample_packing_eff_est:
total_num_tokens:
# Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory
device_map:
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
max_memory:
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora
# If you already have a lora model trained that you want to load, put that here.
@@ -667,8 +650,7 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
wandb_project: # Your wandb project name
wandb_entity: # A wandb Team name if using a Team
wandb_watch:
wandb_name: # Set the name of your wandb run
wandb_run_id: # Set the ID of your wandb run
wandb_run_id: # Set the name of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# Where to save the full-finetuned model to
@@ -686,8 +668,7 @@ gradient_accumulation_steps: 1
micro_batch_size: 2
eval_batch_size:
num_epochs: 4
warmup_steps: 100 # cannot use with warmup_ratio
warmup_ratio: 0.05 # cannot use with warmup_steps
warmup_steps: 100
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
@@ -703,9 +684,6 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (require safetensors package)
save_safetensors:
@@ -964,7 +942,7 @@ wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
```

View File

@@ -24,6 +24,16 @@
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",

View File

@@ -28,6 +28,16 @@
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",

View File

@@ -32,6 +32,16 @@
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",

View File

@@ -4,7 +4,6 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh

View File

@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: btlm-out

View File

@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
batch_size: 4

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./falcon-7b
batch_size: 2

View File

@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out

View File

@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./falcon-7b
batch_size: 2

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
gradient_accumulation_steps: 2

View File

@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./jeopardy-bot-7b
gradient_accumulation_steps: 1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1

View File

@@ -32,7 +32,7 @@ lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./model-out
gradient_accumulation_steps: 1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -35,7 +35,7 @@ relora_cpu_offload: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -4,19 +4,20 @@ model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
- path: mhenrichsen/context-aware-splits-english
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
val_set_size: 200
output_dir: ./tiny-llama
sequence_len: 4096
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
@@ -29,12 +30,12 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
@@ -53,13 +54,13 @@ logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
warmup_steps: 50
eval_steps: 0.05
eval_table_size:
save_steps:
save_steps: 0.50
debug:
deepspeed:
weight_decay: 0.0
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,61 +0,0 @@
base_model: state-spaces/mamba-2.8b
model_type: MambaLMHeadModel
tokenizer_type: AutoTokenizer
tokenizer_config: EleutherAI/gpt-neox-20b
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./out
sequence_len: 2048
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 10
eval_steps:
eval_table_size:
eval_table_max_new_tokens: 128
save_steps: 0.25
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
tokens:
save_safetensors: False

View File

@@ -21,7 +21,7 @@ pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -1,79 +0,0 @@
base_model: DiscoResearch/mixtral-7b-8expert
model_type: MixtralForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
#lora_target_modules:
# - gate
# - q_proj
# - k_proj
# - v_proj
# - o_proj
# - w1
# - w2
# - w3
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
eval_steps:
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed: deepspeed/zero2.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -38,7 +38,7 @@ lora_target_modules:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
@@ -62,9 +62,6 @@ logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
eval_steps: 0.05
eval_table_size:

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./mpt-alpaca-7b
gradient_accumulation_steps: 1

View File

@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./openllama-out
gradient_accumulation_steps: 1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-out
gradient_accumulation_steps: 1

View File

@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
gradient_accumulation_steps: 1

View File

@@ -1,5 +1,5 @@
base_model: microsoft/phi-1_5
model_type: PhiForCausalLM
model_type: MixFormerSequentialForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: false
trust_remote_code: true
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1

View File

@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./pythia-12b
gradient_accumulation_steps: 1

View File

@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1

View File

@@ -1,68 +0,0 @@
base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 2048 # supports up to 8192
sample_packing: false
pad_to_sequence_len:
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,68 +0,0 @@
base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 2048 # supports up to 8192
sample_packing: false
pad_to_sequence_len:
adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./redpajama-alpaca-3b
batch_size: 4

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project: lora-replit
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-replit
batch_size: 8

View File

@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out

View File

@@ -1,21 +1,22 @@
--extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
auto-gptq==0.5.1
torch==2.0.1
auto-gptq==0.4.2
packaging
peft==0.6.0
transformers @ git+https://github.com/huggingface/transformers.git@df5c5c62ae253055336f5bb0828ca8e3e15ab6bd
tokenizers==0.15.0
transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
bitsandbytes>=0.41.1
accelerate==0.24.1
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
deepspeed
addict
fire
PyYAML>=6.0
datasets>=2.15.0
flash-attn==2.3.3
datasets>=2.14.0
flash-attn>=2.3.0
sentencepiece
wandb
einops
xformers==0.0.22
xformers>=0.0.22
optimum==1.13.2
hf_transfer
colorama
@@ -30,7 +31,7 @@ scikit-learn==1.2.2
pynvml
art
fschat==0.2.29
gradio==3.50.2
gradio
tensorboard
# remote filesystems

View File

@@ -46,13 +46,10 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.3.3",
"flash-attn>=2.3.0",
],
"deepspeed": [
"deepspeed",
],
"mamba-ssm": [
"mamba-ssm==1.0.1",
],
},
)

View File

@@ -23,13 +23,12 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.config import add_defaults, normalize_config, validate_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -72,7 +71,7 @@ def do_merge_lora(
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=cfg.torch_dtype)
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
@@ -297,12 +296,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
validate_config(cfg)
prepare_optim_env(cfg)
normalize_config(cfg)
add_defaults(cfg)
setup_wandb_env_vars(cfg)
return cfg

View File

@@ -25,16 +25,12 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
MambaDataCollator,
)
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
@@ -52,9 +48,6 @@ class AxolotlTrainingArguments(TrainingArguments):
Extend the base TrainingArguments for axolotl helpers
"""
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
@@ -291,32 +284,6 @@ class AxolotlTrainer(Trainer):
return super().compute_loss(model, inputs, return_outputs=return_outputs)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""
def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return lm_loss
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
@@ -463,9 +430,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -494,19 +458,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
return AxolotlTrainer
def build(self, total_num_steps):
warmup_steps = None
if self.cfg.warmup_steps is not None:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio is not None:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_steps = min(int(0.03 * total_num_steps), 100)
warmup_steps = (
self.cfg.warmup_steps
if self.cfg.warmup_steps is not None
else min(int(0.03 * total_num_steps), 100)
)
logging_steps = (
self.cfg.logging_steps
if self.cfg.logging_steps is not None
@@ -563,7 +522,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors is not None:
if self.cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.sample_packing_eff_est:
@@ -681,7 +640,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_name if self.cfg.use_wandb else None
self.cfg.wandb_run_id if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
@@ -699,9 +658,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing
if self.cfg.eval_sample_packing is not False
else False
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs[
"sample_packing_seq_len_multiplier"
@@ -711,7 +668,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
@@ -766,7 +722,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=self.build_collator(**data_collator_kwargs),
data_collator=BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
@@ -786,13 +746,3 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.micro_batch_size
return trainer
def build_collator(self, **kwargs):
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)
return BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**kwargs,
)

View File

@@ -1,12 +0,0 @@
"""
Modeling module for Mamba models
"""
def fix_mamba_attn_for_loss():
from mamba_ssm.models import mixer_seq_simple
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name

View File

@@ -1,42 +0,0 @@
"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig
class MambaConfig(PretrainedConfig):
"""
modeling configuration for state space model/mamba
"""
model_type = "mamba"
def __init__(
self,
vocab_size=50280,
d_model=2560,
n_layer=64,
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
pad_vocab_size_multiple=8,
pad_token_id=50277,
bos_token_id=0,
eos_token_id=0,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layer = n_layer
self.rms_norm = rms_norm
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.pad_vocab_size_multiple = pad_vocab_size_multiple
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@@ -1,128 +0,0 @@
# pylint: skip-file
import os
from collections import namedtuple
from functools import partial
from typing import Optional, Union
import torch
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from torch import nn
from torch.nn import CrossEntropyLoss
from axolotl.models.mamba.configuration_mamba import MambaConfig
class MambaLMHeadModel(nn.Module, GenerationMixin):
def __init__(
self,
d_model: int,
n_layer: int,
vocab_size: int,
initializer_cfg=None,
pad_vocab_size_multiple: int = 1,
device=None,
dtype=None,
**backbone_kwargs,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if vocab_size % pad_vocab_size_multiple != 0:
vocab_size += pad_vocab_size_multiple - (
vocab_size % pad_vocab_size_multiple
)
self.config = MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layer,
pad_vocab_size_multiple=pad_vocab_size_multiple,
)
self.backbone = MixerModel(
d_model=d_model,
n_layer=n_layer,
vocab_size=vocab_size,
initializer_cfg=initializer_cfg,
**backbone_kwargs,
**factory_kwargs,
)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
# Initialize weights and apply final processing
self.apply(
partial(
_init_weights,
n_layer=n_layer,
**(initializer_cfg if initializer_cfg is not None else {}),
)
)
self.tie_weights()
def tie_weights(self):
self.lm_head.weight = self.backbone.embedding.weight
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(
batch_size, max_seqlen, dtype=dtype, **kwargs
)
def forward(
self,
input_ids,
position_ids=None,
inference_params=None,
num_last_tokens=0,
labels=None,
**kwargs,
):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.backbone(input_ids, inference_params=inference_params)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
loss = None
if labels is not None:
logits = lm_logits
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
print(loss)
return CausalLMOutput(logits=lm_logits, loss=loss)
else:
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
):
if state_dict is None:
state_dict = self.state_dict()
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
config = load_config_hf(pretrained_model_name)
model = cls(**config, device=device, dtype=dtype, **kwargs)
model.load_state_dict(
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
)
return model

View File

@@ -1,6 +0,0 @@
"""
Custom modeling code for mixtral
"""
from .configuration_moe_mistral import MixtralConfig # noqa
from .modeling_moe_mistral import MixtralForCausalLM # noqa

View File

@@ -1,154 +0,0 @@
# coding=utf-8
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mistral model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
"mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
}
class MixtralConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
[mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
[mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MistralModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 14336):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
allows sequence of up to 4096*32 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention window size. If not specified, will default to `4096`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import MistralModel, MistralConfig
>>> # Initializing a Mistral 7B style configuration
>>> configuration = MixtralConfig()
>>> # Initializing a model from the Mistral 7B style configuration
>>> model = MixtralModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mistral"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
attention_dropout=0.0,
num_experts_per_token=2,
num_experts=8,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.num_experts = num_experts
self.num_experts_per_token = num_experts_per_token
# pylint: disable=duplicate-code
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,4 @@ MixFormers model architecture used for phi models
"""
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
from .configuration_phi import PhiConfig # noqa
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
from .modeling_phi import PhiForCausalLM # noqa

View File

@@ -1,65 +0,0 @@
# pylint: skip-file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from typing import Optional
from transformers import PretrainedConfig
class PhiConfig(PretrainedConfig):
"""Phi configuration."""
model_type = "phi"
attribute_map = {
"max_position_embeddings": "n_positions",
"hidden_size": "n_embd",
"num_attention_heads": "n_head",
"num_hidden_layers": "n_layer",
}
def __init__(
self,
vocab_size: int = 50304,
n_positions: int = 2048,
n_embd: int = 1024,
n_layer: int = 20,
n_inner: Optional[int] = None,
n_head: int = 16,
n_head_kv: Optional[int] = None,
rotary_dim: Optional[int] = 32,
activation_function: Optional[str] = "gelu_new",
flash_attn: bool = False,
flash_rotary: bool = False,
fused_dense: bool = False,
attn_pdrop: float = 0.0,
embd_pdrop: float = 0.0,
resid_pdrop: float = 0.0,
layer_norm_epsilon: float = 1e-5,
initializer_range: float = 0.02,
tie_word_embeddings: bool = False,
pad_vocab_size_multiple: int = 64,
**kwargs
) -> None:
self.vocab_size = int(
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.n_head_kv = n_head_kv
self.rotary_dim = min(rotary_dim, n_embd // n_head)
self.activation_function = activation_function
self.flash_attn = flash_attn
self.flash_rotary = flash_rotary
self.fused_dense = fused_dense
self.attn_pdrop = attn_pdrop
self.embd_pdrop = embd_pdrop
self.resid_pdrop = resid_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@@ -13,7 +13,7 @@ register_conv_template(
system_message="You are a helpful assistant.",
roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
sep="<|im_end|>\n",
)
)

View File

@@ -82,8 +82,7 @@ def train(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
if hasattr(model, "config"):
model.config.use_cache = False
model.config.use_cache = False
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
@@ -93,8 +92,7 @@ def train(
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
model.config.save_pretrained(str(Path(cfg.output_dir)))
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:

View File

@@ -124,36 +124,6 @@ class GPUStatsCallback(
return control
class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3
def on_step_end(
self,
_args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
if self.violations >= self.patience:
LOG.warning(
"Loss is too high, stopping training (loss_watchdog_threshold)"
)
control.should_training_stop = True
else:
self.violations = 0
return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [

View File

@@ -2,16 +2,12 @@
DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, Dict, Optional, Sequence, Union
from typing import Any, Optional, Union
import numpy as np
import torch
import transformers
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
IGNORE_INDEX = -100
@dataclass
class DataCollatorForSeq2Seq:
@@ -150,31 +146,3 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
@dataclass
class MambaDataCollator:
"""
Collator for State Space Models (Mamba)
"""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[torch.LongTensor(instance[key]) for instance in instances]
for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return {
"input_ids": input_ids,
"labels": labels,
}

View File

@@ -27,7 +27,7 @@ def choose_device(cfg):
cfg.device = get_device()
if cfg.world_size == 1:
cfg.device_map = cfg.device_map or "auto"
cfg.device_map = "auto"
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()}
@@ -41,16 +41,6 @@ def choose_device(cfg):
cfg.device_map = None
def add_defaults(cfg):
# setup sane defaults if left unspecified
if cfg.dataloader_num_workers is None:
cfg.dataloader_num_workers = int(os.getenv("WORLD_SIZE", "1"))
if cfg.dataloader_prefetch_factor is None:
cfg.dataloader_prefetch_factor = cfg.batch_size * 2
if cfg.dataloader_pin_memory is None:
cfg.dataloader_pin_memory = True
def normalize_config(cfg):
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
@@ -132,19 +122,6 @@ def normalize_config(cfg):
or (cfg.model_type and "mistral" in cfg.model_type.lower())
)
cfg.is_qwen_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"qwen",
]
)
or cfg.is_qwen_derived_model
or "qwen" in cfg.base_model.lower()
or (cfg.model_type and "qwen" in cfg.model_type.lower())
)
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)
@@ -188,11 +165,7 @@ def validate_config(cfg):
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
if (
cfg.eval_batch_size
and cfg.micro_batch_size
and cfg.eval_batch_size != cfg.micro_batch_size
):
if cfg.eval_batch_size != cfg.micro_batch_size:
LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
)
@@ -399,21 +372,6 @@ def validate_config(cfg):
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
if cfg.warmup_steps and cfg.warmup_ratio:
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
if cfg.is_qwen_derived_model and cfg.gradient_checkpointing:
LOG.warning(
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
)
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -79,14 +79,6 @@ def prepare_dataset(cfg, tokenizer):
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset, tokenizer
)
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
raise ValueError(
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
)
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
@@ -242,14 +234,7 @@ def load_tokenized_prepared_datasets(
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
)
ds = load_from_disk(config_dataset.path)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)

View File

@@ -4,7 +4,6 @@ import math
import os
from typing import Optional, Tuple # noqa: F401
import addict
import bitsandbytes as bnb
import torch
import transformers
@@ -22,7 +21,6 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase,
)
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
@@ -30,56 +28,16 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl")
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
quant_config_exists = hasattr(model_config, "quantization_config")
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config
and model_config.quantization_config["quant_method"] == "gptq"
)
if cfg.gptq and not quant_config_method_is_gptq:
raise ValueError(
"model_config.quantization_config is not set or quant_method is not set to gptq. "
"Please make sure to point to a GPTQ model."
)
if not cfg.gptq and quant_config_exists:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
)
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True
model_type = cfg.model_type
if model_type == "MixtralForCausalLM":
from axolotl.models.mixtral.configuration_moe_mistral import MixtralConfig
model_config = MixtralConfig.from_pretrained(model_config_name)
else:
try:
model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
except ValueError as err:
if "mamba" in model_config_name:
return addict.Dict(
{
"model_type": "mamba",
}
)
raise err
model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)
check_model_config(cfg, model_config)
return model_config
@@ -111,7 +69,6 @@ def load_tokenizer(cfg):
"LlamaTokenizer",
"LlamaTokenizerFast",
"CodeLlamaTokenizer",
"CodeLlamaTokenizerFast",
]
and hasattr(tokenizer, "pad_token")
and not tokenizer.pad_token
@@ -127,40 +84,11 @@ def load_tokenizer(cfg):
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left"
# Qwen base only has single token, so we need to set the special tokens
if cfg.is_qwen_derived_model:
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
for attr_name in token_ids:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, tokenizer.eod_id)
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
for attr_name in token_names:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
if cfg.special_tokens:
for k, val in cfg.special_tokens.items():
tokenizer.add_special_tokens(
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
)
# If we add bos_token and eos_token, we need to update the post processor to
# handle them correctly.
# https://github.com/huggingface/transformers/pull/24132
bos_or_eos_in_special_tokens = (
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
)
if (
tokenizer.__class__.__name__
in (
"LlamaTokenizerFast",
"CodeLlamaTokenizerFast",
)
and bos_or_eos_in_special_tokens
):
tokenizer.update_post_processor()
if cfg.tokens:
tokenizer.add_tokens(
[
@@ -276,7 +204,6 @@ def load_model(
model_kwargs = {}
model_kwargs["device_map"] = cfg.device_map
model_kwargs["max_memory"] = cfg.max_memory
model_kwargs["torch_dtype"] = cfg.torch_dtype
if cfg.model_revision:
@@ -308,9 +235,7 @@ def load_model(
or cfg.is_falcon_derived_model
or cfg.is_mistral_derived_model
):
# TODO enable once properly supported in transformers
# model_kwargs["attn_implementation"] = "flash_attention_2"
model_kwargs["use_flash_attention_2"] = True # legacy, to be deprecated
model_kwargs["use_flash_attention_2"] = True
try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
@@ -363,38 +288,15 @@ def load_model(
# device=cfg.device,
# )
# model.train() # sets to train instead of eval mode
elif model_type == "PhiForCausalLM":
from axolotl.models.phi import PhiForCausalLM
elif model_type == "MixFormerSequentialForCausalLM":
from axolotl.models.phi import MixFormerSequentialForCausalLM
model = PhiForCausalLM.from_pretrained(
model = MixFormerSequentialForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)
elif model_type == "MixtralForCausalLM":
from axolotl.models.mixtral import MixtralForCausalLM
model = MixtralForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)
elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
model_kwargs["dtype"] = model_kwargs["torch_dtype"]
model_kwargs["device"] = torch.cuda.current_device()
del model_kwargs["torch_dtype"]
del model_kwargs["device_map"]
del model_kwargs["max_memory"]
model = MambaLMHeadModel.from_pretrained(
base_model,
**model_kwargs,
)
elif model_type and not cfg.trust_remote_code:
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
@@ -454,17 +356,13 @@ def load_model(
if cfg.resize_token_embeddings_to_32x
else len(tokenizer)
)
if (
hasattr(model, "get_input_embeddings")
and model.get_input_embeddings().num_embeddings < embeddings_len
):
if model.get_input_embeddings().num_embeddings < embeddings_len:
model.resize_token_embeddings(embeddings_len)
else:
model.tie_weights()
if (
hasattr(model, "config")
and hasattr(model.config, "max_position_embeddings")
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings
):
@@ -474,22 +372,20 @@ def load_model(
model.config.max_position_embeddings = cfg.sequence_len
if (
hasattr(model, "config")
and hasattr(model.config, "bos_token_id")
hasattr(model.config, "bos_token_id")
and model.config.bos_token_id
and model.config.bos_token_id != tokenizer.bos_token_id
):
model.config.bos_token_id = tokenizer.bos_token_id
if (
hasattr(model, "config")
and hasattr(model.config, "eos_token_id")
hasattr(model.config, "eos_token_id")
and model.config.eos_token_id
and model.config.eos_token_id != tokenizer.eos_token_id
):
model.config.eos_token_id = tokenizer.eos_token_id
if hasattr(model, "device") and model.device.type == "cuda":
if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device)
# make sure these are fp32 per Ramesh et al. (2021)
@@ -504,22 +400,15 @@ def load_model(
module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
skip_prepare_model_for_kbit_training = True
if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
if not skip_prepare_model_for_kbit_training:
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
@@ -548,8 +437,7 @@ def load_model(
requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0:
LOG.warning("there are no parameters that require gradient updates")
if hasattr(model, "config"):
model.config.use_cache = False
model.config.use_cache = False
if cfg.flash_optimum:
model = BetterTransformer.transform(model)

View File

@@ -182,7 +182,7 @@ class MultipackBatchSampler(BatchSampler):
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return max(
0,
1,
(
world_size
* math.floor(

View File

@@ -131,10 +131,8 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
)
# Phi doesn't want the attention_mask feature when training
if (
"CodeGenTokenizer" in tokenizer.__class__.__name__
or (cfg.is_mistral_derived_model and cfg.flash_attention)
or cfg.model_config_type == "mamba"
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
cfg.is_mistral_derived_model and cfg.flash_attention
):
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
@@ -143,7 +141,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
return train_dataset, eval_dataset
def calculate_total_num_steps(cfg, train_dataset, update=True):
def calculate_total_num_steps(cfg, train_dataset):
if not cfg.total_num_tokens:
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
@@ -152,12 +150,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
if update:
cfg.total_num_tokens = total_num_tokens
cfg.total_num_tokens = total_num_tokens
skip_estimates = cfg.model_config_type == "mamba"
if not skip_estimates and not cfg.total_supervised_tokens:
if not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
@@ -168,10 +163,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True,
)
if update:
cfg.total_supervised_tokens = total_supervised_tokens
cfg.total_supervised_tokens = total_supervised_tokens
if not skip_estimates and cfg.sample_packing:
if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails
@@ -238,8 +232,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
)
if update:
cfg.sample_packing_eff_est = sample_packing_eff_est
cfg.sample_packing_eff_est = sample_packing_eff_est
LOG.debug(
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
main_process_only=True,
@@ -271,14 +264,12 @@ def setup_fsdp_envs(cfg):
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset

View File

@@ -2,20 +2,20 @@
import os
from axolotl.utils.dict import DictDefault
def setup_wandb_env_vars(cfg: DictDefault):
for key in cfg.keys():
if key.startswith("wandb_"):
value = cfg.get(key, "")
if value and isinstance(value, str) and len(value) > 0:
os.environ[key.upper()] = value
# Enable wandb if project name is present
if cfg.wandb_project and len(cfg.wandb_project) > 0:
def setup_wandb_env_vars(cfg):
if cfg.wandb_mode and cfg.wandb_mode == "offline":
os.environ["WANDB_MODE"] = cfg.wandb_mode
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True
os.environ.pop("WANDB_DISABLED", None) # Remove if present
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
else:
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -1,65 +0,0 @@
"""
E2E tests for lora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMistral(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_fft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "state-spaces/mamba-130m",
"model_type": "MambaLMHeadModel",
"tokenizer_type": "AutoTokenizer",
"tokenizer_config": "EleutherAI/gpt-neox-20b",
"flash_attention": False,
"sequence_len": 1024,
"load_in_8bit": False,
"val_set_size": 0.0,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"gradient_checkpointing": False,
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": None,
"save_safetensors": False,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -31,7 +31,7 @@ class TestPhi(unittest.TestCase):
{
"base_model": "microsoft/phi-1_5",
"trust_remote_code": True,
"model_type": "PhiForCausalLM",
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 512,
"sample_packing": False,
@@ -76,7 +76,7 @@ class TestPhi(unittest.TestCase):
{
"base_model": "microsoft/phi-1_5",
"trust_remote_code": True,
"model_type": "PhiForCausalLM",
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 512,
"sample_packing": True,

View File

@@ -1,7 +1,6 @@
"""Module for testing the validation module"""
import logging
import os
import unittest
from typing import Optional
@@ -9,7 +8,6 @@ import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars
class ValidationTest(unittest.TestCase):
@@ -651,113 +649,3 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
def test_warmup_step_no_conflict(self):
cfg = DictDefault(
{
"warmup_steps": 10,
"warmup_ratio": 0.1,
}
)
with pytest.raises(
ValueError,
match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
):
validate_config(cfg)
cfg = DictDefault(
{
"warmup_steps": 10,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"warmup_ratio": 0.1,
}
)
validate_config(cfg)
class ValidationWandbTest(ValidationTest):
"""
Validation test for wandb
"""
def test_wandb_set_run_id_to_name(self):
cfg = DictDefault(
{
"wandb_run_id": "foo",
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
in record.message
for record in self._caplog.records
)
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
cfg = DictDefault(
{
"wandb_name": "foo",
}
)
validate_config(cfg)
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
def test_wandb_sets_env(self):
cfg = DictDefault(
{
"wandb_project": "foo",
"wandb_name": "bar",
"wandb_run_id": "bat",
"wandb_entity": "baz",
"wandb_mode": "online",
"wandb_watch": "false",
"wandb_log_model": "checkpoint",
}
)
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_PROJECT", "") == "foo"
assert os.environ.get("WANDB_NAME", "") == "bar"
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
assert os.environ.get("WANDB_ENTITY", "") == "baz"
assert os.environ.get("WANDB_MODE", "") == "online"
assert os.environ.get("WANDB_WATCH", "") == "false"
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
assert os.environ.get("WANDB_DISABLED", "") != "true"
def test_wandb_set_disabled(self):
cfg = DictDefault({})
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_DISABLED", "") == "true"
cfg = DictDefault(
{
"wandb_project": "foo",
}
)
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_DISABLED", "") != "true"