Feat(wandb): Refactor to be more flexible (#767)

* Feat: Update to handle wandb env better

* chore: rename wandb_run_id to wandb_name

* feat: add new recommendation and update config

* fix: indent and pop disabled env if project passed

* feat: test env set for wandb and recommendation

* feat: update to use wandb_name and allow id

* chore: add info to readme
This commit is contained in:
NanoCode012
2023-12-04 22:17:25 +09:00
committed by GitHub
parent 58ec8b1113
commit a1da39cd48
39 changed files with 140 additions and 50 deletions

View File

@@ -647,7 +647,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_run_id if self.cfg.use_wandb else None
self.cfg.wandb_name if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"

View File

@@ -397,6 +397,13 @@ def validate_config(cfg):
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
)
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -2,20 +2,20 @@
import os
from axolotl.utils.dict import DictDefault
def setup_wandb_env_vars(cfg):
if cfg.wandb_mode and cfg.wandb_mode == "offline":
os.environ["WANDB_MODE"] = cfg.wandb_mode
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project
def setup_wandb_env_vars(cfg: DictDefault):
for key in cfg.keys():
if key.startswith("wandb_"):
value = cfg.get(key, "")
if value and isinstance(value, str) and len(value) > 0:
os.environ[key.upper()] = value
# Enable wandb if project name is present
if cfg.wandb_project and len(cfg.wandb_project) > 0:
cfg.use_wandb = True
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
os.environ.pop("WANDB_DISABLED", None) # Remove if present
else:
os.environ["WANDB_DISABLED"] = "true"