Comet integration (#1939)
* Add first version of a Comet integration * Remove debug prints * Add test for Comet Configuration transformation to env variables * Fix last lint warning * Update Readme for Comet logging documentation * Update Comet integration to be optional, update code and tests * Add documentation for Comet configuration * Add missing check
This commit is contained in:
@@ -31,6 +31,7 @@ from axolotl.integrations.base import PluginManager
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
@@ -421,6 +422,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
|
||||
setup_mlflow_env_vars(cfg)
|
||||
|
||||
setup_comet_env_vars(cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils import is_mlflow_available
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
@@ -1111,6 +1111,12 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_comet and is_comet_available():
|
||||
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -1179,6 +1185,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer, self.tokenizer, "mlflow"
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "comet_ml"
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||
@@ -1430,6 +1441,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
report_to.append("mlflow")
|
||||
if self.cfg.use_tensorboard:
|
||||
report_to.append("tensorboard")
|
||||
if self.cfg.use_comet:
|
||||
report_to.append("comet_ml")
|
||||
|
||||
training_arguments_kwargs["report_to"] = report_to
|
||||
training_arguments_kwargs["run_name"] = (
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
"""
|
||||
Basic utils for Axolotl
|
||||
"""
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
||||
|
||||
def is_mlflow_available():
|
||||
return importlib.util.find_spec("mlflow") is not None
|
||||
|
||||
|
||||
def is_comet_available():
|
||||
return importlib.util.find_spec("comet_ml") is not None
|
||||
|
||||
@@ -29,7 +29,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
from axolotl.utils import is_mlflow_available
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
@@ -747,6 +747,15 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
|
||||
artifact_file="PredictionsVsGroundTruth.json",
|
||||
tracking_uri=tracking_uri,
|
||||
)
|
||||
elif logger == "comet_ml" and is_comet_available():
|
||||
import comet_ml
|
||||
|
||||
experiment = comet_ml.get_running_experiment()
|
||||
if experiment:
|
||||
experiment.log_table(
|
||||
f"{name} - Predictions vs Ground Truth.csv",
|
||||
pd.DataFrame(table_data),
|
||||
)
|
||||
|
||||
if is_main_process():
|
||||
log_table_from_dataloader("Eval", eval_dataloader)
|
||||
|
||||
43
src/axolotl/utils/callbacks/comet_.py
Normal file
43
src/axolotl/utils/callbacks/comet_.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Comet module for trainer callbacks"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import comet_ml
|
||||
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to comet"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
comet_experiment = comet_ml.start(source="axolotl")
|
||||
comet_experiment.log_other("Created from", "axolotl")
|
||||
comet_experiment.log_asset(
|
||||
self.axolotl_config_path,
|
||||
file_name="axolotl-config",
|
||||
)
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the Comet Experiment under assets."
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
|
||||
return control
|
||||
93
src/axolotl/utils/comet_.py
Normal file
93
src/axolotl/utils/comet_.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Module for wandb utilities"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.comet_")
|
||||
|
||||
COMET_ENV_MAPPING_OVERRIDE = {
|
||||
"comet_mode": "COMET_START_MODE",
|
||||
"comet_online": "COMET_START_ONLINE",
|
||||
}
|
||||
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
|
||||
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
|
||||
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
|
||||
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
|
||||
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
|
||||
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
|
||||
"auto_log_co2": "COMET_AUTO_LOG_CO2",
|
||||
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
|
||||
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
|
||||
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
|
||||
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
|
||||
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
|
||||
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
|
||||
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
|
||||
"log_code": "COMET_AUTO_LOG_CODE",
|
||||
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
|
||||
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
|
||||
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
|
||||
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
|
||||
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
|
||||
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
|
||||
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
|
||||
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
|
||||
"log_graph": "COMET_AUTO_LOG_GRAPH",
|
||||
"name": "COMET_START_EXPERIMENT_NAME",
|
||||
"offline_directory": "COMET_OFFLINE_DIRECTORY",
|
||||
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
|
||||
"tags": "COMET_START_EXPERIMENT_TAGS",
|
||||
}
|
||||
|
||||
|
||||
def python_value_to_environ_value(python_value):
|
||||
if isinstance(python_value, bool):
|
||||
if python_value is True:
|
||||
return "true"
|
||||
|
||||
return "false"
|
||||
|
||||
if isinstance(python_value, int):
|
||||
return str(python_value)
|
||||
|
||||
if isinstance(python_value, list): # Comet only have one list of string parameter
|
||||
return ",".join(map(str, python_value))
|
||||
|
||||
return python_value
|
||||
|
||||
|
||||
def setup_comet_env_vars(cfg: DictDefault):
|
||||
# TODO, we need to convert Axolotl configuration to environment variables
|
||||
# as Transformers integration are call first and would create an
|
||||
# Experiment first
|
||||
|
||||
for key in cfg.keys():
|
||||
if key.startswith("comet_") and key != "comet_experiment_config":
|
||||
value = cfg.get(key, "")
|
||||
|
||||
if value is not None and value != "":
|
||||
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
|
||||
final_value = python_value_to_environ_value(value)
|
||||
os.environ[env_variable_name] = final_value
|
||||
|
||||
if cfg.comet_experiment_config:
|
||||
for key, value in cfg.comet_experiment_config.items():
|
||||
if value is not None and value != "":
|
||||
config_env_variable_name = (
|
||||
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
|
||||
)
|
||||
|
||||
if config_env_variable_name is None:
|
||||
LOG.warning(
|
||||
f"Unknown Comet Experiment Config name {key}, ignoring it"
|
||||
)
|
||||
continue
|
||||
|
||||
final_value = python_value_to_environ_value(value)
|
||||
os.environ[config_env_variable_name] = final_value
|
||||
|
||||
# Enable comet if project name is present
|
||||
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
|
||||
cfg.use_comet = True
|
||||
@@ -489,6 +489,19 @@ class WandbConfig(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class CometConfig(BaseModel):
|
||||
"""Comet configuration subset"""
|
||||
|
||||
use_comet: Optional[bool] = None
|
||||
comet_api_key: Optional[str] = None
|
||||
comet_workspace: Optional[str] = None
|
||||
comet_project_name: Optional[str] = None
|
||||
comet_experiment_key: Optional[str] = None
|
||||
comet_mode: Optional[str] = None
|
||||
comet_online: Optional[bool] = None
|
||||
comet_experiment_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class GradioConfig(BaseModel):
|
||||
"""Gradio configuration subset"""
|
||||
|
||||
@@ -509,6 +522,7 @@ class AxolotlInputConfig(
|
||||
HyperparametersConfig,
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
CometConfig,
|
||||
LISAConfig,
|
||||
GradioConfig,
|
||||
RemappedParameters,
|
||||
|
||||
Reference in New Issue
Block a user