Compare commits
1 Commits
a581e9f8f6
...
completion
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da154e6d56 |
14
README.md
14
README.md
@@ -612,12 +612,6 @@ eval_sample_packing:
|
|||||||
sample_packing_eff_est:
|
sample_packing_eff_est:
|
||||||
total_num_tokens:
|
total_num_tokens:
|
||||||
|
|
||||||
# Passed through to transformers when loading the model when launched without accelerate
|
|
||||||
# Use `sequential` when training w/ model parallelism to limit memory
|
|
||||||
device_map:
|
|
||||||
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
|
|
||||||
max_memory:
|
|
||||||
|
|
||||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||||
adapter: lora
|
adapter: lora
|
||||||
# If you already have a lora model trained that you want to load, put that here.
|
# If you already have a lora model trained that you want to load, put that here.
|
||||||
@@ -665,8 +659,7 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
|
|||||||
wandb_project: # Your wandb project name
|
wandb_project: # Your wandb project name
|
||||||
wandb_entity: # A wandb Team name if using a Team
|
wandb_entity: # A wandb Team name if using a Team
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name: # Set the name of your wandb run
|
wandb_run_id: # Set the name of your wandb run
|
||||||
wandb_run_id: # Set the ID of your wandb run
|
|
||||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||||
|
|
||||||
# Where to save the full-finetuned model to
|
# Where to save the full-finetuned model to
|
||||||
@@ -701,9 +694,6 @@ max_steps:
|
|||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
|
|
||||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
|
||||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
|
||||||
|
|
||||||
# Save model as safetensors (require safetensors package)
|
# Save model as safetensors (require safetensors package)
|
||||||
save_safetensors:
|
save_safetensors:
|
||||||
|
|
||||||
@@ -962,7 +952,7 @@ wandb_mode:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,16 @@
|
|||||||
"weight_decay": "auto"
|
"weight_decay": "auto"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupDecayLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": "auto",
|
||||||
|
"warmup_max_lr": "auto",
|
||||||
|
"warmup_num_steps": "auto",
|
||||||
|
"warmup_type": "linear",
|
||||||
|
"total_num_steps": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -28,6 +28,16 @@
|
|||||||
"weight_decay": "auto"
|
"weight_decay": "auto"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupDecayLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": "auto",
|
||||||
|
"warmup_max_lr": "auto",
|
||||||
|
"warmup_num_steps": "auto",
|
||||||
|
"warmup_type": "linear",
|
||||||
|
"total_num_steps": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -32,6 +32,16 @@
|
|||||||
"weight_decay": "auto"
|
"weight_decay": "auto"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupDecayLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": "auto",
|
||||||
|
"warmup_max_lr": "auto",
|
||||||
|
"warmup_num_steps": "auto",
|
||||||
|
"warmup_type": "linear",
|
||||||
|
"total_num_steps": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
output_dir: btlm-out
|
output_dir: btlm-out
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./falcon-7b
|
output_dir: ./falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./falcon-7b
|
output_dir: ./falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./jeopardy-bot-7b
|
output_dir: ./jeopardy-bot-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ lora_target_linear:
|
|||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./model-out
|
output_dir: ./model-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ relora_cpu_offload: false
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ pad_to_sequence_len: true
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ lora_target_modules:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
@@ -62,9 +62,6 @@ logging_steps: 1
|
|||||||
xformers_attention:
|
xformers_attention:
|
||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
|
||||||
loss_watchdog_patience: 3
|
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 0.05
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project: mpt-alpaca-7b
|
wandb_project: mpt-alpaca-7b
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./mpt-alpaca-7b
|
output_dir: ./mpt-alpaca-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./openllama-out
|
output_dir: ./openllama-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./pythia-12b
|
output_dir: ./pythia-12b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-alpaca-pythia
|
output_dir: ./lora-alpaca-pythia
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
@@ -53,7 +53,7 @@ resume_from_checkpoint:
|
|||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention:
|
xformers_attention:
|
||||||
flash_attention:
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 0.05
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
@@ -53,7 +53,7 @@ resume_from_checkpoint:
|
|||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention:
|
xformers_attention:
|
||||||
flash_attention:
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
eval_steps: 0.05
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project: redpajama-alpaca-3b
|
wandb_project: redpajama-alpaca-3b
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./redpajama-alpaca-3b
|
output_dir: ./redpajama-alpaca-3b
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project: lora-replit
|
wandb_project: lora-replit
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-replit
|
output_dir: ./lora-replit
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
|||||||
@@ -2,15 +2,14 @@
|
|||||||
auto-gptq==0.5.1
|
auto-gptq==0.5.1
|
||||||
packaging
|
packaging
|
||||||
peft==0.6.0
|
peft==0.6.0
|
||||||
transformers==4.35.2
|
transformers==4.35.1
|
||||||
tokenizers==0.15.0
|
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate==0.24.1
|
accelerate==0.24.1
|
||||||
deepspeed
|
deepspeed
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
datasets>=2.15.0
|
datasets>=2.14.0
|
||||||
flash-attn==2.3.3
|
flash-attn==2.3.3
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
@@ -30,7 +29,7 @@ scikit-learn==1.2.2
|
|||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
fschat==0.2.29
|
fschat==0.2.29
|
||||||
gradio==3.50.2
|
gradio
|
||||||
tensorboard
|
tensorboard
|
||||||
|
|
||||||
# remote filesystems
|
# remote filesystems
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import prepare_optim_env
|
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -72,7 +71,7 @@ def do_merge_lora(
|
|||||||
|
|
||||||
LOG.info("running merge of LoRA with base model")
|
LOG.info("running merge of LoRA with base model")
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
model.to(dtype=cfg.torch_dtype)
|
model.to(dtype=torch.float16)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||||
@@ -297,8 +296,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
prepare_optim_env(cfg)
|
|
||||||
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
|||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
LossWatchDogCallback,
|
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
@@ -431,9 +430,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.loss_watchdog_threshold is not None:
|
|
||||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
@@ -647,7 +643,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
||||||
training_arguments_kwargs["run_name"] = (
|
training_arguments_kwargs["run_name"] = (
|
||||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
self.cfg.wandb_run_id if self.cfg.use_wandb else None
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["optim"] = (
|
training_arguments_kwargs["optim"] = (
|
||||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Basic completion text
|
Basic completion text
|
||||||
"""
|
"""
|
||||||
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, Generator, Optional, Tuple
|
from typing import Any, Dict, Generator, Optional, Tuple
|
||||||
|
|
||||||
@@ -64,6 +65,19 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|||||||
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionJSONPromptTokenizationStrategy(CompletionPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
Strategy to return the stringified JSON of the entire row as the training data
|
||||||
|
"""
|
||||||
|
|
||||||
|
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||||
|
return (
|
||||||
|
json.dumps(prompt),
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CompletionPrompter:
|
class CompletionPrompter:
|
||||||
"""
|
"""
|
||||||
Prompter for completion
|
Prompter for completion
|
||||||
@@ -82,7 +96,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
strat = CompletionPromptTokenizingStrategy(
|
strat = CompletionPromptTokenizingStrategy(
|
||||||
CompletionPrompter(),
|
CompletionPrompter(),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
True,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
max_length=cfg.sequence_len * 64,
|
max_length=cfg.sequence_len * 64,
|
||||||
)
|
)
|
||||||
@@ -90,3 +104,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
strat.field = ds_cfg["field"]
|
strat.field = ds_cfg["field"]
|
||||||
|
|
||||||
return strat
|
return strat
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(tokenizer, cfg):
|
||||||
|
strat = CompletionJSONPromptTokenizationStrategy(
|
||||||
|
CompletionPrompter(),
|
||||||
|
tokenizer,
|
||||||
|
True,
|
||||||
|
cfg.sequence_len,
|
||||||
|
max_length=cfg.sequence_len * 64,
|
||||||
|
)
|
||||||
|
|
||||||
|
return strat
|
||||||
|
|||||||
@@ -124,36 +124,6 @@ class GPUStatsCallback(
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class LossWatchDogCallback(TrainerCallback):
|
|
||||||
"""Callback to track loss and stop training if loss is too high"""
|
|
||||||
|
|
||||||
def __init__(self, cfg):
|
|
||||||
self.cfg = cfg
|
|
||||||
self.logged = False
|
|
||||||
self.violations = 0
|
|
||||||
self.threshold = cfg.loss_watchdog_threshold
|
|
||||||
self.patience = cfg.loss_watchdog_patience or 3
|
|
||||||
|
|
||||||
def on_step_end(
|
|
||||||
self,
|
|
||||||
_args: TrainingArguments,
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
**_kwargs,
|
|
||||||
):
|
|
||||||
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
|
|
||||||
if state.log_history[-1]["loss"] > self.threshold:
|
|
||||||
self.violations += 1
|
|
||||||
if self.violations >= self.patience:
|
|
||||||
LOG.warning(
|
|
||||||
"Loss is too high, stopping training (loss_watchdog_threshold)"
|
|
||||||
)
|
|
||||||
control.should_training_stop = True
|
|
||||||
else:
|
|
||||||
self.violations = 0
|
|
||||||
return control
|
|
||||||
|
|
||||||
|
|
||||||
def bench_eval_callback_factory(trainer, tokenizer):
|
def bench_eval_callback_factory(trainer, tokenizer):
|
||||||
accuracy = evaluate.load("accuracy")
|
accuracy = evaluate.load("accuracy")
|
||||||
abcd_idx = [
|
abcd_idx = [
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def choose_device(cfg):
|
|||||||
|
|
||||||
cfg.device = get_device()
|
cfg.device = get_device()
|
||||||
if cfg.world_size == 1:
|
if cfg.world_size == 1:
|
||||||
cfg.device_map = cfg.device_map or "auto"
|
cfg.device_map = "auto"
|
||||||
else:
|
else:
|
||||||
if cfg.device.startswith("cuda"):
|
if cfg.device.startswith("cuda"):
|
||||||
cfg.device_map = {"": torch.cuda.current_device()}
|
cfg.device_map = {"": torch.cuda.current_device()}
|
||||||
@@ -397,13 +397,6 @@ def validate_config(cfg):
|
|||||||
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
|
||||||
cfg.wandb_name = cfg.wandb_run_id
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -28,27 +28,6 @@ from axolotl.utils.dict import DictDefault
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
|
||||||
quant_config_exists = hasattr(model_config, "quantization_config")
|
|
||||||
quant_config_method_is_gptq = (
|
|
||||||
quant_config_exists
|
|
||||||
and "quant_method" in model_config.quantization_config
|
|
||||||
and model_config.quantization_config["quant_method"] == "gptq"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.gptq and not quant_config_method_is_gptq:
|
|
||||||
raise ValueError(
|
|
||||||
"model_config.quantization_config is not set or quant_method is not set to gptq. "
|
|
||||||
"Please make sure to point to a GPTQ model."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not cfg.gptq and quant_config_exists:
|
|
||||||
raise ValueError(
|
|
||||||
"model_config.quantization_config is set but `gptq` flag is not. "
|
|
||||||
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(cfg):
|
def load_model_config(cfg):
|
||||||
model_config_name = cfg.base_model_config or cfg.base_model
|
model_config_name = cfg.base_model_config or cfg.base_model
|
||||||
trust_remote_code = cfg.trust_remote_code is True
|
trust_remote_code = cfg.trust_remote_code is True
|
||||||
@@ -59,8 +38,6 @@ def load_model_config(cfg):
|
|||||||
for key, val in cfg.model_config.items():
|
for key, val in cfg.model_config.items():
|
||||||
setattr(model_config, key, val)
|
setattr(model_config, key, val)
|
||||||
|
|
||||||
check_model_config(cfg, model_config)
|
|
||||||
|
|
||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
|
||||||
@@ -239,7 +216,6 @@ def load_model(
|
|||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
|
|
||||||
model_kwargs["device_map"] = cfg.device_map
|
model_kwargs["device_map"] = cfg.device_map
|
||||||
model_kwargs["max_memory"] = cfg.max_memory
|
|
||||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||||
|
|
||||||
if cfg.model_revision:
|
if cfg.model_revision:
|
||||||
@@ -436,22 +412,15 @@ def load_model(
|
|||||||
module.to(torch.float32)
|
module.to(torch.float32)
|
||||||
|
|
||||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||||
skip_prepare_model_for_kbit_training = False
|
|
||||||
|
|
||||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
|
||||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
|
||||||
skip_prepare_model_for_kbit_training = True
|
|
||||||
|
|
||||||
if (cfg.adapter == "lora" and load_in_8bit) or (
|
if (cfg.adapter == "lora" and load_in_8bit) or (
|
||||||
cfg.adapter == "qlora" and cfg.load_in_4bit
|
cfg.adapter == "qlora" and cfg.load_in_4bit
|
||||||
):
|
):
|
||||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
if not skip_prepare_model_for_kbit_training:
|
model = prepare_model_for_kbit_training(
|
||||||
model = prepare_model_for_kbit_training(
|
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
)
|
||||||
)
|
|
||||||
needs_fa2_dtype = True
|
needs_fa2_dtype = True
|
||||||
|
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||||
|
|||||||
@@ -267,14 +267,12 @@ def setup_fsdp_envs(cfg):
|
|||||||
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||||
|
|
||||||
|
|
||||||
def prepare_optim_env(cfg):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
|
||||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
||||||
trainer_builder.train_dataset = train_dataset
|
trainer_builder.train_dataset = train_dataset
|
||||||
trainer_builder.eval_dataset = eval_dataset
|
trainer_builder.eval_dataset = eval_dataset
|
||||||
|
|||||||
@@ -2,20 +2,20 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
def setup_wandb_env_vars(cfg):
|
||||||
def setup_wandb_env_vars(cfg: DictDefault):
|
if cfg.wandb_mode and cfg.wandb_mode == "offline":
|
||||||
for key in cfg.keys():
|
os.environ["WANDB_MODE"] = cfg.wandb_mode
|
||||||
if key.startswith("wandb_"):
|
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||||
value = cfg.get(key, "")
|
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
||||||
|
|
||||||
if value and isinstance(value, str) and len(value) > 0:
|
|
||||||
os.environ[key.upper()] = value
|
|
||||||
|
|
||||||
# Enable wandb if project name is present
|
|
||||||
if cfg.wandb_project and len(cfg.wandb_project) > 0:
|
|
||||||
cfg.use_wandb = True
|
cfg.use_wandb = True
|
||||||
os.environ.pop("WANDB_DISABLED", None) # Remove if present
|
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
|
||||||
|
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
|
||||||
|
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
||||||
|
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
||||||
|
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
||||||
|
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
||||||
|
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
||||||
|
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
||||||
else:
|
else:
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Module for testing the validation module"""
|
"""Module for testing the validation module"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -9,7 +8,6 @@ import pytest
|
|||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationTest(unittest.TestCase):
|
class ValidationTest(unittest.TestCase):
|
||||||
@@ -681,83 +679,3 @@ class ValidationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
|
||||||
class ValidationWandbTest(ValidationTest):
|
|
||||||
"""
|
|
||||||
Validation test for wandb
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_wandb_set_run_id_to_name(self):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"wandb_run_id": "foo",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
validate_config(cfg)
|
|
||||||
assert any(
|
|
||||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
|
||||||
in record.message
|
|
||||||
for record in self._caplog.records
|
|
||||||
)
|
|
||||||
|
|
||||||
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"wandb_name": "foo",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
|
|
||||||
|
|
||||||
def test_wandb_sets_env(self):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"wandb_project": "foo",
|
|
||||||
"wandb_name": "bar",
|
|
||||||
"wandb_run_id": "bat",
|
|
||||||
"wandb_entity": "baz",
|
|
||||||
"wandb_mode": "online",
|
|
||||||
"wandb_watch": "false",
|
|
||||||
"wandb_log_model": "checkpoint",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
|
||||||
|
|
||||||
assert os.environ.get("WANDB_PROJECT", "") == "foo"
|
|
||||||
assert os.environ.get("WANDB_NAME", "") == "bar"
|
|
||||||
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
|
|
||||||
assert os.environ.get("WANDB_ENTITY", "") == "baz"
|
|
||||||
assert os.environ.get("WANDB_MODE", "") == "online"
|
|
||||||
assert os.environ.get("WANDB_WATCH", "") == "false"
|
|
||||||
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
|
|
||||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
|
||||||
|
|
||||||
def test_wandb_set_disabled(self):
|
|
||||||
cfg = DictDefault({})
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
|
||||||
|
|
||||||
assert os.environ.get("WANDB_DISABLED", "") == "true"
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"wandb_project": "foo",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
|
||||||
|
|
||||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
|
||||||
|
|||||||
Reference in New Issue
Block a user