Compare commits

..

2 Commits

Author SHA1 Message Date
sunny
cdd8be7097 wip on multimodal packing support 2024-10-04 15:08:36 -04:00
sunny
08143c7b0d wip on multimodal sample packing support 2024-10-04 14:59:35 -04:00
28 changed files with 128 additions and 583 deletions

View File

@@ -28,13 +28,7 @@ jobs:
cuda_version: 12.4.1 cuda_version: 12.4.1
cudnn_version: "" cudnn_version: ""
python_version: "3.11" python_version: "3.11"
pytorch: 2.4.1 pytorch: 2.4.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -27,7 +27,7 @@ jobs:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.4.1 pytorch: 2.4.0
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
@@ -84,7 +84,7 @@ jobs:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.4.1 pytorch: 2.4.0
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:

View File

@@ -26,7 +26,7 @@ jobs:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.4.1 pytorch: 2.4.0
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
@@ -83,7 +83,7 @@ jobs:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.4.1 pytorch: 2.4.0
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:

View File

@@ -25,7 +25,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.10", "3.11"] python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1"] pytorch_version: ["2.3.1", "2.4.0"]
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -91,7 +91,7 @@ jobs:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.4.1 pytorch: 2.4.0
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras:
nightly_build: "true" nightly_build: "true"

View File

@@ -36,7 +36,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.10", "3.11"] python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1"] pytorch_version: ["2.3.1", "2.4.0"]
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -94,7 +94,7 @@ jobs:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.4.1 pytorch: 2.4.0
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras:
steps: steps:

View File

@@ -1,3 +1,3 @@
[settings] [settings]
profile=black profile=black
known_third_party=wandb,comet_ml known_third_party=wandb

View File

@@ -14,7 +14,7 @@ Features:
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking - Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed - Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud - Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet - Log results and optionally checkpoints to wandb or mlflow
- And more! - And more!
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25"> <a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
@@ -515,22 +515,6 @@ wandb_name:
wandb_log_model: wandb_log_model:
``` ```
##### Comet Logging
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
- wandb options
```yaml
use_comet:
comet_api_key:
comet_workspace:
comet_project_name:
comet_experiment_key:
comet_mode:
comet_online:
comet_experiment_config:
```
##### Special Tokens ##### Special Tokens
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this: It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:

View File

@@ -90,7 +90,6 @@ datasets:
shards: # Optional[int] number of shards to split data into shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
# Optional[str] fastchat conversation type, only used with type: sharegpt # Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -268,18 +267,6 @@ mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name mlflow_experiment_name: # Your experiment name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Comet configuration if you're using it
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
use_comet: # Enable or disable Comet integration.
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
# Where to save the full-finetuned model to # Where to save the full-finetuned model to
output_dir: ./completed-model output_dir: ./completed-model

View File

@@ -16,7 +16,7 @@ flash-attn==2.6.3
sentencepiece sentencepiece
wandb wandb
einops einops
xformers==0.0.28.post1 xformers==0.0.27
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
colorama colorama
@@ -46,9 +46,3 @@ gcsfs>=2024.5.0
trl==0.9.6 trl==0.9.6
zstandard==0.22.0 zstandard==0.22.0
fastcore fastcore
# lm eval harness
lm_eval==0.4.4
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2

View File

@@ -49,17 +49,10 @@ def parse_requirements():
else: else:
raise ValueError("Invalid version format") raise ValueError("Invalid version format")
if (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
if (major, minor) >= (2, 3): if (major, minor) >= (2, 3):
if patch == 0: if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1") _install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2): elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1") _install_requires.append("xformers>=0.0.25.post1")

View File

@@ -31,7 +31,6 @@ from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import chat_templates from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import ( from axolotl.utils.config import (
normalize_cfg_datasets, normalize_cfg_datasets,
normalize_config, normalize_config,
@@ -55,22 +54,8 @@ LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_axolotl_text_art(suffix=None):
def print_legacy_axolotl_text_art(suffix=None):
font = "nancyj" font = "nancyj"
ascii_text = " axolotl" ascii_text = " axolotl"
if suffix: if suffix:
@@ -83,13 +68,6 @@ def print_legacy_axolotl_text_art(suffix=None):
print_dep_versions() print_dep_versions()
def print_axolotl_text_art(
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
print(AXOLOTL_LOGO)
def print_dep_versions(): def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"] packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages) max_len = max(len(pkg) for pkg in packages)
@@ -443,8 +421,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
setup_mlflow_env_vars(cfg) setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg return cfg

View File

@@ -3,11 +3,13 @@ CLI to run training on a model
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Union from typing import Tuple, Union
import fire import fire
from dotenv import load_dotenv from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser from transformers.hf_argparser import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.cli import ( from axolotl.cli import (
check_accelerate_default_config, check_accelerate_default_config,
@@ -18,7 +20,6 @@ from axolotl.cli import (
print_axolotl_text_art, print_axolotl_text_art,
) )
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.integrations.base import PluginManager
from axolotl.prompt_strategies.sharegpt import ( from axolotl.prompt_strategies.sharegpt import (
register_chatml_template, register_chatml_template,
register_llama3_template, register_llama3_template,
@@ -38,7 +39,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return do_train(parsed_cfg, parsed_cli_args) return do_train(parsed_cfg, parsed_cli_args)
def do_train(cfg, cli_args) -> None: def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
print_axolotl_text_art() print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() check_user_token()
@@ -63,13 +64,7 @@ def do_train(cfg, cli_args) -> None:
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
plugin_manager.post_train_unload(cfg)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -48,7 +48,7 @@ from trl.trainer.utils import pad_to_length
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils import is_mlflow_available
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GPUStatsCallback, GPUStatsCallback,
@@ -1111,12 +1111,6 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append( callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
) )
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
return callbacks return callbacks
@@ -1185,11 +1179,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer, self.tokenizer, "mlflow" trainer, self.tokenizer, "mlflow"
) )
callbacks.append(LogPredictionCallback(self.cfg)) callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval: if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer)) callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
@@ -1441,8 +1430,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to.append("mlflow") report_to.append("mlflow")
if self.cfg.use_tensorboard: if self.cfg.use_tensorboard:
report_to.append("tensorboard") report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_arguments_kwargs["report_to"] = report_to training_arguments_kwargs["report_to"] = report_to
training_arguments_kwargs["run_name"] = ( training_arguments_kwargs["run_name"] = (

View File

@@ -159,29 +159,6 @@ class BasePlugin:
List[callable]: A list of callback functions to be added to the TrainingArgs List[callable]: A list of callback functions to be added to the TrainingArgs
""" """
def post_train(self, cfg, model):
"""
Performs actions after training is complete.
Parameters:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
"""
def post_train_unload(self, cfg):
"""
Performs actions after training is complete and the model is unloaded.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def load_plugin(plugin_name: str) -> BasePlugin: def load_plugin(plugin_name: str) -> BasePlugin:
""" """
@@ -404,17 +381,3 @@ class PluginManager:
for plugin in self.plugins: for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer)) callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks return callbacks
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.post_train_unload(cfg)

View File

@@ -1,13 +0,0 @@
# LM Eval Harness
### Usage
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
```

View File

@@ -1,42 +0,0 @@
"""
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
class LMEvalPlugin(BasePlugin):
"""
Plugin for LM Evaluation Harness integraton with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
tasks = ",".join(cfg.lm_eval_tasks)
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
output_path = cfg.output_dir
output_path += "" if cfg.output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
subprocess.run( # nosec
[
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={cfg.output_dir}{fa2}{dtype}",
"--tasks",
tasks,
"--batch_size",
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)

View File

@@ -1,15 +0,0 @@
"""
Module for handling lm eval harness input arguments.
"""
from typing import List, Optional
from pydantic import BaseModel
class LMEvalArgs(BaseModel):
"""
Input args for lm eval harness
"""
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8

View File

@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
def reset_optimizer( def reset_optimizer(
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
*, *,
reset_params: List[str], # where str is the key to a torch.nn.Parameter reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: List[str], optimizer_state_keys: list[str],
prune_ratio: float = 0.9, prune_ratio: float = 0.9,
): ):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio) pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)

View File

@@ -1,12 +1,8 @@
""" """
Basic utils for Axolotl Basic utils for Axolotl
""" """
import importlib.util import importlib
def is_mlflow_available(): def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None return importlib.util.find_spec("mlflow") is not None
def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None

View File

@@ -29,7 +29,7 @@ from transformers import (
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
references=[[r] for r in references], references=[[r] for r in references],
predictions=predictions, predictions=predictions,
) )
scores["eval_" + metric_name] = score scores[metric_name] = score
return scores return scores
def predict_with_generate(): def predict_with_generate():
@@ -747,15 +747,6 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
artifact_file="PredictionsVsGroundTruth.json", artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri, tracking_uri=tracking_uri,
) )
elif logger == "comet_ml" and is_comet_available():
import comet_ml
experiment = comet_ml.get_running_experiment()
if experiment:
experiment.log_table(
f"{name} - Predictions vs Ground Truth.csv",
pd.DataFrame(table_data),
)
if is_main_process(): if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader) log_table_from_dataloader("Eval", eval_dataloader)

View File

@@ -1,43 +0,0 @@
"""Comet module for trainer callbacks"""
import logging
from typing import TYPE_CHECKING
import comet_ml
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
"""Callback to save axolotl config to comet"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
comet_experiment = comet_ml.start(source="axolotl")
comet_experiment.log_other("Created from", "axolotl")
comet_experiment.log_asset(
self.axolotl_config_path,
file_name="axolotl-config",
)
LOG.info(
"The Axolotl config has been saved to the Comet Experiment under assets."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
return control

View File

@@ -5,9 +5,7 @@ These templates are used for formatting messages in a conversation.
CHAT_TEMPLATES = { CHAT_TEMPLATES = {
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
"mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1... "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
"mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large...
"mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral...
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",

View File

@@ -20,6 +20,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
return_tensors: str = "pt" return_tensors: str = "pt"
chat_template: Optional[str] = None chat_template: Optional[str] = None
packing: bool = False packing: bool = False
sequence_length: Optional[int] = None
max_images: int = -1 max_images: int = -1
padding: Union[bool, str, PaddingStrategy] = True padding: Union[bool, str, PaddingStrategy] = True
pad_to_multiple_of: Optional[int] = None pad_to_multiple_of: Optional[int] = None
@@ -32,11 +33,112 @@ class MultiModalChatDataCollator(DataCollatorMixin):
self, examples: List[Union[List[int], Any, Dict[str, Any]]] self, examples: List[Union[List[int], Any, Dict[str, Any]]]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor. # Handle dict or lists with proper padding and conversion to tensor.
if self.packing:
return self.__class__.process_rows_packing(
examples,
self.processor,
self.chat_template,
self.max_images,
self.sequence_length,
)
return self.__class__.process_rows( return self.__class__.process_rows(
examples, self.processor, self.chat_template, self.max_images examples, self.processor, self.chat_template, self.max_images
) )
@staticmethod
def process_rows_packing(
examples,
processor,
chat_template,
max_images,
sequence_length,
length_only=False,
):
import torch
# Perform sample packing within a batch
if processor.tokenizer.sep_token is None:
sep_token = "[SEP]"
processor.tokenizer.add_tokens([sep_token])
processor.tokenizer.sep_token = sep_token
sep_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.tokenizer.sep_token
)
pad_token_id = processor.tokenizer.pad_token_id
texts = [
processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False
)
for example in examples
]
images = [example["images"] for example in examples]
if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]
batch = processor(text=texts, images=images, padding=False)
n_sequence = len(examples)
n_seq_in_batch = 0
pack_len = 0
features_pack = {}
packed = {}
features = list[batch.keys()]
for feature in features:
features_pack[feature] = []
packed[feature] = []
features.remove("input_ids")
for seq_in_batch_id in range(n_sequence):
next_seq_len = len(batch["input_ids"][seq_in_batch_id])
if not pack_len + next_seq_len + 1 < sequence_length:
n_seq_in_batch += 1
pack_len += next_seq_len + 1
features_pack["input_ids"] += batch["input_ids"][seq_in_batch_id] + [
sep_token_id
]
"""
Do something with attention mask and cross-attention
"""
for feature in features:
features_pack[feature] += batch[feature][seq_in_batch_id]
else:
for _ in range(sequence_length - pack_len):
features_pack["input_ids"] += [pad_token_id]
packed["input_ids"].append(
torch.tensor(features_pack["input_ids"].copy())
)
for feature in features:
packed[feature].append(torch.tensor(features_pack[feature].copy()))
features_pack[feature] = []
pack_len = 0
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
labels = [pack.clone() for pack in packed["input_ids"]]
for label_id, label in enumerate(labels):
labels[label_id][label == processor.tokenizer.pad_token_id] = -100 #
# Ignore the image token index in the loss computation (model specific)
labels[label_id][label == image_token_id] = -100
packed["labels"] = labels
if length_only:
return {
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
}
return packed
@staticmethod @staticmethod
def process_rows(examples, processor, chat_template, max_images, length_only=False): def process_rows(examples, processor, chat_template, max_images, length_only=False):
# HINT: use `_torch_collate_batch` to stack and pad tensors # HINT: use `_torch_collate_batch` to stack and pad tensors

View File

@@ -1,93 +0,0 @@
"""Module for wandb utilities"""
import logging
import os
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.utils.comet_")
COMET_ENV_MAPPING_OVERRIDE = {
"comet_mode": "COMET_START_MODE",
"comet_online": "COMET_START_ONLINE",
}
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
"auto_log_co2": "COMET_AUTO_LOG_CO2",
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
"log_code": "COMET_AUTO_LOG_CODE",
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
"log_graph": "COMET_AUTO_LOG_GRAPH",
"name": "COMET_START_EXPERIMENT_NAME",
"offline_directory": "COMET_OFFLINE_DIRECTORY",
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
"tags": "COMET_START_EXPERIMENT_TAGS",
}
def python_value_to_environ_value(python_value):
if isinstance(python_value, bool):
if python_value is True:
return "true"
return "false"
if isinstance(python_value, int):
return str(python_value)
if isinstance(python_value, list): # Comet only have one list of string parameter
return ",".join(map(str, python_value))
return python_value
def setup_comet_env_vars(cfg: DictDefault):
# TODO, we need to convert Axolotl configuration to environment variables
# as Transformers integration are call first and would create an
# Experiment first
for key in cfg.keys():
if key.startswith("comet_") and key != "comet_experiment_config":
value = cfg.get(key, "")
if value is not None and value != "":
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
final_value = python_value_to_environ_value(value)
os.environ[env_variable_name] = final_value
if cfg.comet_experiment_config:
for key, value in cfg.comet_experiment_config.items():
if value is not None and value != "":
config_env_variable_name = (
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
)
if config_env_variable_name is None:
LOG.warning(
f"Unknown Comet Experiment Config name {key}, ignoring it"
)
continue
final_value = python_value_to_environ_value(value)
os.environ[config_env_variable_name] = final_value
# Enable comet if project name is present
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
cfg.use_comet = True

View File

@@ -489,19 +489,6 @@ class WandbConfig(BaseModel):
return data return data
class CometConfig(BaseModel):
"""Comet configuration subset"""
use_comet: Optional[bool] = None
comet_api_key: Optional[str] = None
comet_workspace: Optional[str] = None
comet_project_name: Optional[str] = None
comet_experiment_key: Optional[str] = None
comet_mode: Optional[str] = None
comet_online: Optional[bool] = None
comet_experiment_config: Optional[Dict[str, Any]] = None
class GradioConfig(BaseModel): class GradioConfig(BaseModel):
"""Gradio configuration subset""" """Gradio configuration subset"""
@@ -522,7 +509,6 @@ class AxolotlInputConfig(
HyperparametersConfig, HyperparametersConfig,
WandbConfig, WandbConfig,
MLFlowConfig, MLFlowConfig,
CometConfig,
LISAConfig, LISAConfig,
GradioConfig, GradioConfig,
RemappedParameters, RemappedParameters,
@@ -980,26 +966,6 @@ class AxolotlInputConfig(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch." "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
) )
if data.get("do_bench_eval") and not (
data.get("evals_per_epoch") or data.get("eval_steps")
):
raise ValueError(
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
)
return data
@model_validator(mode="before")
@classmethod
def check_test_datasets_bench(cls, data):
if (
data.get("do_bench_eval")
and not data.get("test_datasets")
and not data.get("val_set_size")
):
LOG.warning(
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
)
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
return data return data
@model_validator(mode="before") @model_validator(mode="before")

View File

@@ -242,7 +242,6 @@ def load_tokenized_prepared_datasets(
name=config_dataset.name, name=config_dataset.name,
streaming=True, streaming=True,
token=use_auth_token, token=use_auth_token,
revision=config_dataset.revision,
) )
ds_from_hub = True ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -347,7 +346,6 @@ def load_tokenized_prepared_datasets(
streaming=False, streaming=False,
data_files=config_dataset.data_files, data_files=config_dataset.data_files,
token=use_auth_token, token=use_auth_token,
revision=config_dataset.revision,
**load_ds_kwargs, **load_ds_kwargs,
) )
elif ds_from_cloud and remote_file_system: elif ds_from_cloud and remote_file_system:
@@ -382,7 +380,6 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path, repo_id=config_dataset.path,
repo_type="dataset", repo_type="dataset",
filename=config_dataset.data_files, filename=config_dataset.data_files,
revision=config_dataset.revision,
) )
elif isinstance(config_dataset.data_files, list): elif isinstance(config_dataset.data_files, list):
fp = [] fp = []
@@ -392,7 +389,6 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path, repo_id=config_dataset.path,
repo_type="dataset", repo_type="dataset",
filename=file, filename=file,
revision=config_dataset.revision,
) )
) )
else: else:
@@ -437,8 +433,8 @@ def load_tokenized_prepared_datasets(
config_dataset=config_dataset, config_dataset=config_dataset,
tokenizer=tokenizer, tokenizer=tokenizer,
cfg=cfg, cfg=cfg,
d_base_type=d_base_type,
dataset=ds, dataset=ds,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style, d_prompt_style=d_prompt_style,
processor=processor, processor=processor,
) )

View File

@@ -267,74 +267,6 @@ class TestDatasetPreparation(unittest.TestCase):
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_hub_with_revision(self):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -9,7 +9,6 @@ from typing import Optional
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -1330,105 +1329,3 @@ class TestValidationWandb(BaseValidation):
os.environ.pop("WANDB_PROJECT", None) os.environ.pop("WANDB_PROJECT", None)
os.environ.pop("WANDB_DISABLED", None) os.environ.pop("WANDB_DISABLED", None)
@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
class TestValidationComet(BaseValidation):
"""
Validation test for comet
"""
def test_comet_sets_env(self, minimal_cfg):
from axolotl.utils.comet_ import setup_comet_env_vars
comet_config = {
"comet_api_key": "foo",
"comet_workspace": "some_workspace",
"comet_project_name": "some_project",
"comet_experiment_key": "some_experiment_key",
"comet_mode": "get_or_create",
"comet_online": False,
"comet_experiment_config": {
"auto_histogram_activation_logging": False,
"auto_histogram_epoch_rate": 2,
"auto_histogram_gradient_logging": True,
"auto_histogram_tensorboard_logging": False,
"auto_histogram_weight_logging": True,
"auto_log_co2": False,
"auto_metric_logging": True,
"auto_metric_step_rate": 15,
"auto_output_logging": False,
"auto_param_logging": True,
"comet_disabled": False,
"display_summary_level": 2,
"distributed_node_identifier": "some_distributed_node_identifier",
"log_code": True,
"log_env_cpu": False,
"log_env_details": True,
"log_env_disk": False,
"log_env_gpu": True,
"log_env_host": False,
"log_env_network": True,
"log_git_metadata": False,
"log_git_patch": True,
"log_graph": False,
"name": "some_name",
"offline_directory": "some_offline_directory",
"parse_args": True,
"tags": ["tag1", "tag2"],
},
}
cfg = DictDefault(comet_config) | minimal_cfg
new_cfg = validate_config(cfg)
setup_comet_env_vars(new_cfg)
comet_env = {
key: value for key, value in os.environ.items() if key.startswith("COMET_")
}
assert (
len(comet_env)
== len(comet_config) + len(comet_config["comet_experiment_config"]) - 1
)
assert comet_env == {
"COMET_API_KEY": "foo",
"COMET_AUTO_LOG_CLI_ARGUMENTS": "true",
"COMET_AUTO_LOG_CO2": "false",
"COMET_AUTO_LOG_CODE": "true",
"COMET_AUTO_LOG_DISABLE": "false",
"COMET_AUTO_LOG_ENV_CPU": "false",
"COMET_AUTO_LOG_ENV_DETAILS": "true",
"COMET_AUTO_LOG_ENV_DISK": "false",
"COMET_AUTO_LOG_ENV_GPU": "true",
"COMET_AUTO_LOG_ENV_HOST": "false",
"COMET_AUTO_LOG_ENV_NETWORK": "true",
"COMET_AUTO_LOG_GIT_METADATA": "false",
"COMET_AUTO_LOG_GIT_PATCH": "true",
"COMET_AUTO_LOG_GRAPH": "false",
"COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false",
"COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2",
"COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true",
"COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false",
"COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true",
"COMET_AUTO_LOG_METRIC_STEP_RATE": "15",
"COMET_AUTO_LOG_METRICS": "true",
"COMET_AUTO_LOG_OUTPUT_LOGGER": "false",
"COMET_AUTO_LOG_PARAMETERS": "true",
"COMET_DISPLAY_SUMMARY_LEVEL": "2",
"COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier",
"COMET_EXPERIMENT_KEY": "some_experiment_key",
"COMET_OFFLINE_DIRECTORY": "some_offline_directory",
"COMET_PROJECT_NAME": "some_project",
"COMET_START_EXPERIMENT_NAME": "some_name",
"COMET_START_EXPERIMENT_TAGS": "tag1,tag2",
"COMET_START_MODE": "get_or_create",
"COMET_START_ONLINE": "false",
"COMET_WORKSPACE": "some_workspace",
}
for key in comet_env.keys():
os.environ.pop(key, None)