Merge branch 'main' into cj_tokenizer_default_prompt_template
This commit is contained in:
8
.github/workflows/base.yml
vendored
8
.github/workflows/base.yml
vendored
@@ -28,7 +28,13 @@ jobs:
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
4
.github/workflows/main.yml
vendored
4
.github/workflows/main.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -84,7 +84,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
|
||||
4
.github/workflows/nightlies.yml
vendored
4
.github/workflows/nightlies.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -83,7 +83,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
|
||||
4
.github/workflows/tests-nightly.yml
vendored
4
.github/workflows/tests-nightly.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11"]
|
||||
pytorch_version: ["2.3.1", "2.4.0"]
|
||||
pytorch_version: ["2.3.1", "2.4.1"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -91,7 +91,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
|
||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11"]
|
||||
pytorch_version: ["2.3.1", "2.4.0"]
|
||||
pytorch_version: ["2.3.1", "2.4.1"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.0
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
[settings]
|
||||
profile=black
|
||||
known_third_party=wandb
|
||||
known_third_party=wandb,comet_ml
|
||||
|
||||
18
README.md
18
README.md
@@ -14,7 +14,7 @@ Features:
|
||||
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
|
||||
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
||||
- Easily run with Docker locally or on the cloud
|
||||
- Log results and optionally checkpoints to wandb or mlflow
|
||||
- Log results and optionally checkpoints to wandb, mlflow or Comet
|
||||
- And more!
|
||||
|
||||
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
||||
@@ -515,6 +515,22 @@ wandb_name:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
##### Comet Logging
|
||||
|
||||
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
|
||||
|
||||
- wandb options
|
||||
```yaml
|
||||
use_comet:
|
||||
comet_api_key:
|
||||
comet_workspace:
|
||||
comet_project_name:
|
||||
comet_experiment_key:
|
||||
comet_mode:
|
||||
comet_online:
|
||||
comet_experiment_config:
|
||||
```
|
||||
|
||||
##### Special Tokens
|
||||
|
||||
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
|
||||
|
||||
@@ -316,6 +316,18 @@ mlflow_tracking_uri: # URI to mlflow
|
||||
mlflow_experiment_name: # Your experiment name
|
||||
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
||||
|
||||
# Comet configuration if you're using it
|
||||
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
|
||||
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
|
||||
use_comet: # Enable or disable Comet integration.
|
||||
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
|
||||
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
|
||||
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
|
||||
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
|
||||
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
|
||||
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
|
||||
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ flash-attn==2.6.3
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers==0.0.27
|
||||
xformers==0.0.28.post1
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
colorama
|
||||
|
||||
7
setup.py
7
setup.py
@@ -49,10 +49,17 @@ def parse_requirements():
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 4):
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
if (major, minor) >= (2, 3):
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.26.post1")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
elif (major, minor) >= (2, 2):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.25.post1")
|
||||
|
||||
@@ -31,6 +31,7 @@ from axolotl.integrations.base import PluginManager
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
@@ -421,6 +422,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
|
||||
setup_mlflow_env_vars(cfg)
|
||||
|
||||
setup_comet_env_vars(cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils import is_mlflow_available
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
@@ -1111,6 +1111,12 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_comet and is_comet_available():
|
||||
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -1179,6 +1185,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer, self.tokenizer, "mlflow"
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "comet_ml"
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||
@@ -1430,6 +1441,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
report_to.append("mlflow")
|
||||
if self.cfg.use_tensorboard:
|
||||
report_to.append("tensorboard")
|
||||
if self.cfg.use_comet:
|
||||
report_to.append("comet_ml")
|
||||
|
||||
training_arguments_kwargs["report_to"] = report_to
|
||||
training_arguments_kwargs["run_name"] = (
|
||||
|
||||
@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
|
||||
def reset_optimizer(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
*,
|
||||
reset_params: list[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: list[str],
|
||||
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: List[str],
|
||||
prune_ratio: float = 0.9,
|
||||
):
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
"""
|
||||
Basic utils for Axolotl
|
||||
"""
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
||||
|
||||
def is_mlflow_available():
|
||||
return importlib.util.find_spec("mlflow") is not None
|
||||
|
||||
|
||||
def is_comet_available():
|
||||
return importlib.util.find_spec("comet_ml") is not None
|
||||
|
||||
@@ -29,7 +29,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
from axolotl.utils import is_mlflow_available
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
@@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
references=[[r] for r in references],
|
||||
predictions=predictions,
|
||||
)
|
||||
scores[metric_name] = score
|
||||
scores["eval_" + metric_name] = score
|
||||
return scores
|
||||
|
||||
def predict_with_generate():
|
||||
@@ -747,6 +747,15 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
|
||||
artifact_file="PredictionsVsGroundTruth.json",
|
||||
tracking_uri=tracking_uri,
|
||||
)
|
||||
elif logger == "comet_ml" and is_comet_available():
|
||||
import comet_ml
|
||||
|
||||
experiment = comet_ml.get_running_experiment()
|
||||
if experiment:
|
||||
experiment.log_table(
|
||||
f"{name} - Predictions vs Ground Truth.csv",
|
||||
pd.DataFrame(table_data),
|
||||
)
|
||||
|
||||
if is_main_process():
|
||||
log_table_from_dataloader("Eval", eval_dataloader)
|
||||
|
||||
43
src/axolotl/utils/callbacks/comet_.py
Normal file
43
src/axolotl/utils/callbacks/comet_.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Comet module for trainer callbacks"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import comet_ml
|
||||
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to comet"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
comet_experiment = comet_ml.start(source="axolotl")
|
||||
comet_experiment.log_other("Created from", "axolotl")
|
||||
comet_experiment.log_asset(
|
||||
self.axolotl_config_path,
|
||||
file_name="axolotl-config",
|
||||
)
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the Comet Experiment under assets."
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
|
||||
return control
|
||||
File diff suppressed because one or more lines are too long
93
src/axolotl/utils/comet_.py
Normal file
93
src/axolotl/utils/comet_.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Module for wandb utilities"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.comet_")
|
||||
|
||||
COMET_ENV_MAPPING_OVERRIDE = {
|
||||
"comet_mode": "COMET_START_MODE",
|
||||
"comet_online": "COMET_START_ONLINE",
|
||||
}
|
||||
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
|
||||
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
|
||||
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
|
||||
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
|
||||
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
|
||||
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
|
||||
"auto_log_co2": "COMET_AUTO_LOG_CO2",
|
||||
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
|
||||
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
|
||||
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
|
||||
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
|
||||
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
|
||||
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
|
||||
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
|
||||
"log_code": "COMET_AUTO_LOG_CODE",
|
||||
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
|
||||
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
|
||||
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
|
||||
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
|
||||
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
|
||||
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
|
||||
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
|
||||
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
|
||||
"log_graph": "COMET_AUTO_LOG_GRAPH",
|
||||
"name": "COMET_START_EXPERIMENT_NAME",
|
||||
"offline_directory": "COMET_OFFLINE_DIRECTORY",
|
||||
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
|
||||
"tags": "COMET_START_EXPERIMENT_TAGS",
|
||||
}
|
||||
|
||||
|
||||
def python_value_to_environ_value(python_value):
|
||||
if isinstance(python_value, bool):
|
||||
if python_value is True:
|
||||
return "true"
|
||||
|
||||
return "false"
|
||||
|
||||
if isinstance(python_value, int):
|
||||
return str(python_value)
|
||||
|
||||
if isinstance(python_value, list): # Comet only have one list of string parameter
|
||||
return ",".join(map(str, python_value))
|
||||
|
||||
return python_value
|
||||
|
||||
|
||||
def setup_comet_env_vars(cfg: DictDefault):
|
||||
# TODO, we need to convert Axolotl configuration to environment variables
|
||||
# as Transformers integration are call first and would create an
|
||||
# Experiment first
|
||||
|
||||
for key in cfg.keys():
|
||||
if key.startswith("comet_") and key != "comet_experiment_config":
|
||||
value = cfg.get(key, "")
|
||||
|
||||
if value is not None and value != "":
|
||||
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
|
||||
final_value = python_value_to_environ_value(value)
|
||||
os.environ[env_variable_name] = final_value
|
||||
|
||||
if cfg.comet_experiment_config:
|
||||
for key, value in cfg.comet_experiment_config.items():
|
||||
if value is not None and value != "":
|
||||
config_env_variable_name = (
|
||||
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
|
||||
)
|
||||
|
||||
if config_env_variable_name is None:
|
||||
LOG.warning(
|
||||
f"Unknown Comet Experiment Config name {key}, ignoring it"
|
||||
)
|
||||
continue
|
||||
|
||||
final_value = python_value_to_environ_value(value)
|
||||
os.environ[config_env_variable_name] = final_value
|
||||
|
||||
# Enable comet if project name is present
|
||||
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
|
||||
cfg.use_comet = True
|
||||
@@ -516,6 +516,19 @@ class WandbConfig(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class CometConfig(BaseModel):
|
||||
"""Comet configuration subset"""
|
||||
|
||||
use_comet: Optional[bool] = None
|
||||
comet_api_key: Optional[str] = None
|
||||
comet_workspace: Optional[str] = None
|
||||
comet_project_name: Optional[str] = None
|
||||
comet_experiment_key: Optional[str] = None
|
||||
comet_mode: Optional[str] = None
|
||||
comet_online: Optional[bool] = None
|
||||
comet_experiment_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class GradioConfig(BaseModel):
|
||||
"""Gradio configuration subset"""
|
||||
|
||||
@@ -536,6 +549,7 @@ class AxolotlInputConfig(
|
||||
HyperparametersConfig,
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
CometConfig,
|
||||
LISAConfig,
|
||||
GradioConfig,
|
||||
RemappedParameters,
|
||||
|
||||
@@ -73,7 +73,7 @@ class TestAssistantChatTemplateLlama3:
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=get_chat_template("llama3"),
|
||||
chat_templates("llama3"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
roles={
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Optional
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from axolotl.utils import is_comet_available
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -1329,3 +1330,105 @@ class TestValidationWandb(BaseValidation):
|
||||
|
||||
os.environ.pop("WANDB_PROJECT", None)
|
||||
os.environ.pop("WANDB_DISABLED", None)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
|
||||
class TestValidationComet(BaseValidation):
|
||||
"""
|
||||
Validation test for comet
|
||||
"""
|
||||
|
||||
def test_comet_sets_env(self, minimal_cfg):
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
|
||||
comet_config = {
|
||||
"comet_api_key": "foo",
|
||||
"comet_workspace": "some_workspace",
|
||||
"comet_project_name": "some_project",
|
||||
"comet_experiment_key": "some_experiment_key",
|
||||
"comet_mode": "get_or_create",
|
||||
"comet_online": False,
|
||||
"comet_experiment_config": {
|
||||
"auto_histogram_activation_logging": False,
|
||||
"auto_histogram_epoch_rate": 2,
|
||||
"auto_histogram_gradient_logging": True,
|
||||
"auto_histogram_tensorboard_logging": False,
|
||||
"auto_histogram_weight_logging": True,
|
||||
"auto_log_co2": False,
|
||||
"auto_metric_logging": True,
|
||||
"auto_metric_step_rate": 15,
|
||||
"auto_output_logging": False,
|
||||
"auto_param_logging": True,
|
||||
"comet_disabled": False,
|
||||
"display_summary_level": 2,
|
||||
"distributed_node_identifier": "some_distributed_node_identifier",
|
||||
"log_code": True,
|
||||
"log_env_cpu": False,
|
||||
"log_env_details": True,
|
||||
"log_env_disk": False,
|
||||
"log_env_gpu": True,
|
||||
"log_env_host": False,
|
||||
"log_env_network": True,
|
||||
"log_git_metadata": False,
|
||||
"log_git_patch": True,
|
||||
"log_graph": False,
|
||||
"name": "some_name",
|
||||
"offline_directory": "some_offline_directory",
|
||||
"parse_args": True,
|
||||
"tags": ["tag1", "tag2"],
|
||||
},
|
||||
}
|
||||
|
||||
cfg = DictDefault(comet_config) | minimal_cfg
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
|
||||
setup_comet_env_vars(new_cfg)
|
||||
|
||||
comet_env = {
|
||||
key: value for key, value in os.environ.items() if key.startswith("COMET_")
|
||||
}
|
||||
|
||||
assert (
|
||||
len(comet_env)
|
||||
== len(comet_config) + len(comet_config["comet_experiment_config"]) - 1
|
||||
)
|
||||
|
||||
assert comet_env == {
|
||||
"COMET_API_KEY": "foo",
|
||||
"COMET_AUTO_LOG_CLI_ARGUMENTS": "true",
|
||||
"COMET_AUTO_LOG_CO2": "false",
|
||||
"COMET_AUTO_LOG_CODE": "true",
|
||||
"COMET_AUTO_LOG_DISABLE": "false",
|
||||
"COMET_AUTO_LOG_ENV_CPU": "false",
|
||||
"COMET_AUTO_LOG_ENV_DETAILS": "true",
|
||||
"COMET_AUTO_LOG_ENV_DISK": "false",
|
||||
"COMET_AUTO_LOG_ENV_GPU": "true",
|
||||
"COMET_AUTO_LOG_ENV_HOST": "false",
|
||||
"COMET_AUTO_LOG_ENV_NETWORK": "true",
|
||||
"COMET_AUTO_LOG_GIT_METADATA": "false",
|
||||
"COMET_AUTO_LOG_GIT_PATCH": "true",
|
||||
"COMET_AUTO_LOG_GRAPH": "false",
|
||||
"COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false",
|
||||
"COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2",
|
||||
"COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true",
|
||||
"COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false",
|
||||
"COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true",
|
||||
"COMET_AUTO_LOG_METRIC_STEP_RATE": "15",
|
||||
"COMET_AUTO_LOG_METRICS": "true",
|
||||
"COMET_AUTO_LOG_OUTPUT_LOGGER": "false",
|
||||
"COMET_AUTO_LOG_PARAMETERS": "true",
|
||||
"COMET_DISPLAY_SUMMARY_LEVEL": "2",
|
||||
"COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier",
|
||||
"COMET_EXPERIMENT_KEY": "some_experiment_key",
|
||||
"COMET_OFFLINE_DIRECTORY": "some_offline_directory",
|
||||
"COMET_PROJECT_NAME": "some_project",
|
||||
"COMET_START_EXPERIMENT_NAME": "some_name",
|
||||
"COMET_START_EXPERIMENT_TAGS": "tag1,tag2",
|
||||
"COMET_START_MODE": "get_or_create",
|
||||
"COMET_START_ONLINE": "false",
|
||||
"COMET_WORKSPACE": "some_workspace",
|
||||
}
|
||||
|
||||
for key in comet_env.keys():
|
||||
os.environ.pop(key, None)
|
||||
|
||||
Reference in New Issue
Block a user