Feat(wandb): Refactor to be more flexible (#767)
* Feat: Update to handle wandb env better * chore: rename wandb_run_id to wandb_name * feat: add new recommendation and update config * fix: indent and pop disabled env if project passed * feat: test env set for wandb and recommendation * feat: update to use wandb_name and allow id * chore: add info to readme
This commit is contained in:
@@ -659,7 +659,8 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
|
|||||||
wandb_project: # Your wandb project name
|
wandb_project: # Your wandb project name
|
||||||
wandb_entity: # A wandb Team name if using a Team
|
wandb_entity: # A wandb Team name if using a Team
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id: # Set the name of your wandb run
|
wandb_name: # Set the name of your wandb run
|
||||||
|
wandb_run_id: # Set the ID of your wandb run
|
||||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||||
|
|
||||||
# Where to save the full-finetuned model to
|
# Where to save the full-finetuned model to
|
||||||
@@ -955,7 +956,7 @@ wandb_mode:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
output_dir: btlm-out
|
output_dir: btlm-out
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./falcon-7b
|
output_dir: ./falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./falcon-7b
|
output_dir: ./falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./jeopardy-bot-7b
|
output_dir: ./jeopardy-bot-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ lora_target_linear:
|
|||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./model-out
|
output_dir: ./model-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ relora_cpu_offload: false
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ pad_to_sequence_len: true
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ lora_target_modules:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project: mpt-alpaca-7b
|
wandb_project: mpt-alpaca-7b
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./mpt-alpaca-7b
|
output_dir: ./mpt-alpaca-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./openllama-out
|
output_dir: ./openllama-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./pythia-12b
|
output_dir: ./pythia-12b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-alpaca-pythia
|
output_dir: ./lora-alpaca-pythia
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project: redpajama-alpaca-3b
|
wandb_project: redpajama-alpaca-3b
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./redpajama-alpaca-3b
|
output_dir: ./redpajama-alpaca-3b
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project: lora-replit
|
wandb_project: lora-replit
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-replit
|
output_dir: ./lora-replit
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
|||||||
@@ -647,7 +647,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
||||||
training_arguments_kwargs["run_name"] = (
|
training_arguments_kwargs["run_name"] = (
|
||||||
self.cfg.wandb_run_id if self.cfg.use_wandb else None
|
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["optim"] = (
|
training_arguments_kwargs["optim"] = (
|
||||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||||
|
|||||||
@@ -397,6 +397,13 @@ def validate_config(cfg):
|
|||||||
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||||
|
cfg.wandb_name = cfg.wandb_run_id
|
||||||
|
|
||||||
|
LOG.warning(
|
||||||
|
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -2,20 +2,20 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
def setup_wandb_env_vars(cfg):
|
|
||||||
if cfg.wandb_mode and cfg.wandb_mode == "offline":
|
def setup_wandb_env_vars(cfg: DictDefault):
|
||||||
os.environ["WANDB_MODE"] = cfg.wandb_mode
|
for key in cfg.keys():
|
||||||
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
if key.startswith("wandb_"):
|
||||||
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
value = cfg.get(key, "")
|
||||||
|
|
||||||
|
if value and isinstance(value, str) and len(value) > 0:
|
||||||
|
os.environ[key.upper()] = value
|
||||||
|
|
||||||
|
# Enable wandb if project name is present
|
||||||
|
if cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||||
cfg.use_wandb = True
|
cfg.use_wandb = True
|
||||||
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
|
os.environ.pop("WANDB_DISABLED", None) # Remove if present
|
||||||
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
|
|
||||||
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
|
||||||
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
|
||||||
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
|
||||||
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
|
||||||
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
|
||||||
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
|
||||||
else:
|
else:
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Module for testing the validation module"""
|
"""Module for testing the validation module"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -8,6 +9,7 @@ import pytest
|
|||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
|
|
||||||
class ValidationTest(unittest.TestCase):
|
class ValidationTest(unittest.TestCase):
|
||||||
@@ -679,3 +681,83 @@ class ValidationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationWandbTest(ValidationTest):
|
||||||
|
"""
|
||||||
|
Validation test for wandb
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_wandb_set_run_id_to_name(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"wandb_run_id": "foo",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert any(
|
||||||
|
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||||
|
in record.message
|
||||||
|
for record in self._caplog.records
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"wandb_name": "foo",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
|
||||||
|
|
||||||
|
def test_wandb_sets_env(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"wandb_project": "foo",
|
||||||
|
"wandb_name": "bar",
|
||||||
|
"wandb_run_id": "bat",
|
||||||
|
"wandb_entity": "baz",
|
||||||
|
"wandb_mode": "online",
|
||||||
|
"wandb_watch": "false",
|
||||||
|
"wandb_log_model": "checkpoint",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
setup_wandb_env_vars(cfg)
|
||||||
|
|
||||||
|
assert os.environ.get("WANDB_PROJECT", "") == "foo"
|
||||||
|
assert os.environ.get("WANDB_NAME", "") == "bar"
|
||||||
|
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
|
||||||
|
assert os.environ.get("WANDB_ENTITY", "") == "baz"
|
||||||
|
assert os.environ.get("WANDB_MODE", "") == "online"
|
||||||
|
assert os.environ.get("WANDB_WATCH", "") == "false"
|
||||||
|
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
|
||||||
|
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
||||||
|
|
||||||
|
def test_wandb_set_disabled(self):
|
||||||
|
cfg = DictDefault({})
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
setup_wandb_env_vars(cfg)
|
||||||
|
|
||||||
|
assert os.environ.get("WANDB_DISABLED", "") == "true"
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"wandb_project": "foo",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
setup_wandb_env_vars(cfg)
|
||||||
|
|
||||||
|
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
||||||
|
|||||||
Reference in New Issue
Block a user