diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 9101fc2be..1b24f2c97 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -28,7 +28,13 @@ jobs: cuda_version: 12.4.1 cudnn_version: "" python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + - cuda: "124" + cuda_version: 12.4.1 + cudnn_version: "" + python_version: "3.11" + pytorch: 2.4.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" steps: - name: Checkout diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 5a972f5f0..c27dbedef 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,7 +27,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -84,7 +84,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 1d95a0983..17c76c24e 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -26,7 +26,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -83,7 +83,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 30ed397ce..8c9e1f49e 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -25,7 +25,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] - pytorch_version: ["2.3.1", "2.4.0"] + pytorch_version: ["2.3.1", "2.4.1"] timeout-minutes: 20 steps: @@ -91,7 +91,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 num_gpus: 1 axolotl_extras: nightly_build: "true" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c104e92c2..a798bdd5c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,7 +36,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] - pytorch_version: ["2.3.1", "2.4.0"] + pytorch_version: ["2.3.1", "2.4.1"] timeout-minutes: 20 steps: @@ -94,7 +94,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 num_gpus: 1 axolotl_extras: steps: diff --git a/.isort.cfg b/.isort.cfg index 79067a7c9..e48779732 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,3 +1,3 @@ [settings] profile=black -known_third_party=wandb +known_third_party=wandb,comet_ml diff --git a/README.md b/README.md index c84f1cb8c..f6f4e4e80 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Features: - Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking - Works with single GPU or multiple GPUs via FSDP or Deepspeed - Easily run with Docker locally or on the cloud -- Log results and optionally checkpoints to wandb or mlflow +- Log results and optionally checkpoints to wandb, mlflow or Comet - And more! @@ -515,6 +515,22 @@ wandb_name: wandb_log_model: ``` +##### Comet Logging + +Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`. + +- wandb options +```yaml +use_comet: +comet_api_key: +comet_workspace: +comet_project_name: +comet_experiment_key: +comet_mode: +comet_online: +comet_experiment_config: +``` + ##### Special Tokens It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this: diff --git a/docs/config.qmd b/docs/config.qmd index 13d8b9e66..cc0c2dad2 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -316,6 +316,18 @@ mlflow_tracking_uri: # URI to mlflow mlflow_experiment_name: # Your experiment name hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry +# Comet configuration if you're using it +# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`. +# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start +use_comet: # Enable or disable Comet integration. +comet_api_key: # API key for Comet. Recommended to set via `comet login`. +comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace. +comet_project_name: # Project name in Comet. Defaults to Uncategorized. +comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key. +comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration. +comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True. +comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details. + # Where to save the full-finetuned model to output_dir: ./completed-model diff --git a/requirements.txt b/requirements.txt index 123a4ee54..41bfdfbeb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ flash-attn==2.6.3 sentencepiece wandb einops -xformers==0.0.27 +xformers==0.0.28.post1 optimum==1.16.2 hf_transfer colorama diff --git a/setup.py b/setup.py index 1b64fadae..e939bc37e 100644 --- a/setup.py +++ b/setup.py @@ -49,10 +49,17 @@ def parse_requirements(): else: raise ValueError("Invalid version format") + if (major, minor) >= (2, 4): + if patch == 0: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.27") if (major, minor) >= (2, 3): if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.26.post1") + else: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.27") elif (major, minor) >= (2, 2): _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.25.post1") diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 4171f28b2..c757eca42 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -31,6 +31,7 @@ from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta from axolotl.utils.chat_templates import get_chat_template +from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.config import ( normalize_cfg_datasets, normalize_config, @@ -421,6 +422,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): setup_mlflow_env_vars(cfg) + setup_comet_env_vars(cfg) + return cfg diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b73849d7a..010fcb72e 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -48,7 +48,7 @@ from trl.trainer.utils import pad_to_length from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler -from axolotl.utils import is_mlflow_available +from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( EvalFirstStepCallback, GPUStatsCallback, @@ -1111,6 +1111,12 @@ class TrainerBuilderBase(abc.ABC): callbacks.append( SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_comet and is_comet_available(): + from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback + + callbacks.append( + SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) + ) return callbacks @@ -1179,6 +1185,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): trainer, self.tokenizer, "mlflow" ) callbacks.append(LogPredictionCallback(self.cfg)) + if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0: + LogPredictionCallback = log_prediction_callback_factory( + trainer, self.tokenizer, "comet_ml" + ) + callbacks.append(LogPredictionCallback(self.cfg)) if self.cfg.do_bench_eval: callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer)) @@ -1430,6 +1441,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): report_to.append("mlflow") if self.cfg.use_tensorboard: report_to.append("tensorboard") + if self.cfg.use_comet: + report_to.append("comet_ml") training_arguments_kwargs["report_to"] = report_to training_arguments_kwargs["run_name"] = ( diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index e4352cbe3..9d246cb17 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio): def reset_optimizer( optimizer: torch.optim.Optimizer, *, - reset_params: list[str], # where str is the key to a torch.nn.Parameter - optimizer_state_keys: list[str], + reset_params: List[str], # where str is the key to a torch.nn.Parameter + optimizer_state_keys: List[str], prune_ratio: float = 0.9, ): pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio) diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 99dec79f1..91545009a 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -1,8 +1,12 @@ """ Basic utils for Axolotl """ -import importlib +import importlib.util def is_mlflow_available(): return importlib.util.find_spec("mlflow") is not None + + +def is_comet_available(): + return importlib.util.find_spec("comet_ml") is not None diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 73715b06a..0bc781fcb 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -29,7 +29,7 @@ from transformers import ( ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy -from axolotl.utils import is_mlflow_available +from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig @@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer): references=[[r] for r in references], predictions=predictions, ) - scores[metric_name] = score + scores["eval_" + metric_name] = score return scores def predict_with_generate(): @@ -747,6 +747,15 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str): artifact_file="PredictionsVsGroundTruth.json", tracking_uri=tracking_uri, ) + elif logger == "comet_ml" and is_comet_available(): + import comet_ml + + experiment = comet_ml.get_running_experiment() + if experiment: + experiment.log_table( + f"{name} - Predictions vs Ground Truth.csv", + pd.DataFrame(table_data), + ) if is_main_process(): log_table_from_dataloader("Eval", eval_dataloader) diff --git a/src/axolotl/utils/callbacks/comet_.py b/src/axolotl/utils/callbacks/comet_.py new file mode 100644 index 000000000..b29f997a8 --- /dev/null +++ b/src/axolotl/utils/callbacks/comet_.py @@ -0,0 +1,43 @@ +"""Comet module for trainer callbacks""" + +import logging +from typing import TYPE_CHECKING + +import comet_ml +from transformers import TrainerCallback, TrainerControl, TrainerState + +from axolotl.utils.distributed import is_main_process + +if TYPE_CHECKING: + from axolotl.core.trainer_builder import AxolotlTrainingArguments + +LOG = logging.getLogger("axolotl.callbacks") + + +class SaveAxolotlConfigtoCometCallback(TrainerCallback): + """Callback to save axolotl config to comet""" + + def __init__(self, axolotl_config_path): + self.axolotl_config_path = axolotl_config_path + + def on_train_begin( + self, + args: "AxolotlTrainingArguments", # pylint: disable=unused-argument + state: TrainerState, # pylint: disable=unused-argument + control: TrainerControl, + **kwargs, # pylint: disable=unused-argument + ): + if is_main_process(): + try: + comet_experiment = comet_ml.start(source="axolotl") + comet_experiment.log_other("Created from", "axolotl") + comet_experiment.log_asset( + self.axolotl_config_path, + file_name="axolotl-config", + ) + LOG.info( + "The Axolotl config has been saved to the Comet Experiment under assets." + ) + except (FileNotFoundError, ConnectionError) as err: + LOG.warning(f"Error while saving Axolotl config to Comet: {err}") + return control diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index f3dc8f5fc..620098ae0 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -28,6 +28,20 @@ _CHAT_TEMPLATES = { "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', } +CHAT_TEMPLATES = { + "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", + "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. + "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", + "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", + "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "llama3_2_vision": '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n', + "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", + "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", + "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', +} + def get_chat_template( user_choice: str, @@ -125,4 +139,4 @@ def register_chat_template(template_name: str, chat_template: str): if template_name in _CHAT_TEMPLATES: raise ValueError(f"Template '{template_name}' already exists.") - _CHAT_TEMPLATES[template_name] = chat_template \ No newline at end of file + _CHAT_TEMPLATES[template_name] = chat_template diff --git a/src/axolotl/utils/comet_.py b/src/axolotl/utils/comet_.py new file mode 100644 index 000000000..b4ecc80ad --- /dev/null +++ b/src/axolotl/utils/comet_.py @@ -0,0 +1,93 @@ +"""Module for wandb utilities""" + +import logging +import os + +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger("axolotl.utils.comet_") + +COMET_ENV_MAPPING_OVERRIDE = { + "comet_mode": "COMET_START_MODE", + "comet_online": "COMET_START_ONLINE", +} +COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = { + "auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS", + "auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE", + "auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS", + "auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD", + "auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS", + "auto_log_co2": "COMET_AUTO_LOG_CO2", + "auto_metric_logging": "COMET_AUTO_LOG_METRICS", + "auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE", + "auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER", + "auto_param_logging": "COMET_AUTO_LOG_PARAMETERS", + "comet_disabled": "COMET_AUTO_LOG_DISABLE", + "display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL", + "distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER", + "log_code": "COMET_AUTO_LOG_CODE", + "log_env_cpu": "COMET_AUTO_LOG_ENV_CPU", + "log_env_details": "COMET_AUTO_LOG_ENV_DETAILS", + "log_env_disk": "COMET_AUTO_LOG_ENV_DISK", + "log_env_gpu": "COMET_AUTO_LOG_ENV_GPU", + "log_env_host": "COMET_AUTO_LOG_ENV_HOST", + "log_env_network": "COMET_AUTO_LOG_ENV_NETWORK", + "log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA", + "log_git_patch": "COMET_AUTO_LOG_GIT_PATCH", + "log_graph": "COMET_AUTO_LOG_GRAPH", + "name": "COMET_START_EXPERIMENT_NAME", + "offline_directory": "COMET_OFFLINE_DIRECTORY", + "parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS", + "tags": "COMET_START_EXPERIMENT_TAGS", +} + + +def python_value_to_environ_value(python_value): + if isinstance(python_value, bool): + if python_value is True: + return "true" + + return "false" + + if isinstance(python_value, int): + return str(python_value) + + if isinstance(python_value, list): # Comet only have one list of string parameter + return ",".join(map(str, python_value)) + + return python_value + + +def setup_comet_env_vars(cfg: DictDefault): + # TODO, we need to convert Axolotl configuration to environment variables + # as Transformers integration are call first and would create an + # Experiment first + + for key in cfg.keys(): + if key.startswith("comet_") and key != "comet_experiment_config": + value = cfg.get(key, "") + + if value is not None and value != "": + env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper()) + final_value = python_value_to_environ_value(value) + os.environ[env_variable_name] = final_value + + if cfg.comet_experiment_config: + for key, value in cfg.comet_experiment_config.items(): + if value is not None and value != "": + config_env_variable_name = ( + COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key) + ) + + if config_env_variable_name is None: + LOG.warning( + f"Unknown Comet Experiment Config name {key}, ignoring it" + ) + continue + + final_value = python_value_to_environ_value(value) + os.environ[config_env_variable_name] = final_value + + # Enable comet if project name is present + if cfg.comet_project_name and len(cfg.comet_project_name) > 0: + cfg.use_comet = True diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 8c8218d08..81f3021e8 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -516,6 +516,19 @@ class WandbConfig(BaseModel): return data +class CometConfig(BaseModel): + """Comet configuration subset""" + + use_comet: Optional[bool] = None + comet_api_key: Optional[str] = None + comet_workspace: Optional[str] = None + comet_project_name: Optional[str] = None + comet_experiment_key: Optional[str] = None + comet_mode: Optional[str] = None + comet_online: Optional[bool] = None + comet_experiment_config: Optional[Dict[str, Any]] = None + + class GradioConfig(BaseModel): """Gradio configuration subset""" @@ -536,6 +549,7 @@ class AxolotlInputConfig( HyperparametersConfig, WandbConfig, MLFlowConfig, + CometConfig, LISAConfig, GradioConfig, RemappedParameters, diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 7a5582ddc..920f31ef4 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -73,7 +73,7 @@ class TestAssistantChatTemplateLlama3: strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=get_chat_template("llama3"), + chat_templates("llama3"), message_field_role="role", message_field_content="content", roles={ diff --git a/tests/test_validation.py b/tests/test_validation.py index 35d0e265e..6e0d0ad2a 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -9,6 +9,7 @@ from typing import Optional import pytest from pydantic import ValidationError +from axolotl.utils import is_comet_available from axolotl.utils.config import validate_config from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities from axolotl.utils.dict import DictDefault @@ -1329,3 +1330,105 @@ class TestValidationWandb(BaseValidation): os.environ.pop("WANDB_PROJECT", None) os.environ.pop("WANDB_DISABLED", None) + + +@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed") +class TestValidationComet(BaseValidation): + """ + Validation test for comet + """ + + def test_comet_sets_env(self, minimal_cfg): + from axolotl.utils.comet_ import setup_comet_env_vars + + comet_config = { + "comet_api_key": "foo", + "comet_workspace": "some_workspace", + "comet_project_name": "some_project", + "comet_experiment_key": "some_experiment_key", + "comet_mode": "get_or_create", + "comet_online": False, + "comet_experiment_config": { + "auto_histogram_activation_logging": False, + "auto_histogram_epoch_rate": 2, + "auto_histogram_gradient_logging": True, + "auto_histogram_tensorboard_logging": False, + "auto_histogram_weight_logging": True, + "auto_log_co2": False, + "auto_metric_logging": True, + "auto_metric_step_rate": 15, + "auto_output_logging": False, + "auto_param_logging": True, + "comet_disabled": False, + "display_summary_level": 2, + "distributed_node_identifier": "some_distributed_node_identifier", + "log_code": True, + "log_env_cpu": False, + "log_env_details": True, + "log_env_disk": False, + "log_env_gpu": True, + "log_env_host": False, + "log_env_network": True, + "log_git_metadata": False, + "log_git_patch": True, + "log_graph": False, + "name": "some_name", + "offline_directory": "some_offline_directory", + "parse_args": True, + "tags": ["tag1", "tag2"], + }, + } + + cfg = DictDefault(comet_config) | minimal_cfg + + new_cfg = validate_config(cfg) + + setup_comet_env_vars(new_cfg) + + comet_env = { + key: value for key, value in os.environ.items() if key.startswith("COMET_") + } + + assert ( + len(comet_env) + == len(comet_config) + len(comet_config["comet_experiment_config"]) - 1 + ) + + assert comet_env == { + "COMET_API_KEY": "foo", + "COMET_AUTO_LOG_CLI_ARGUMENTS": "true", + "COMET_AUTO_LOG_CO2": "false", + "COMET_AUTO_LOG_CODE": "true", + "COMET_AUTO_LOG_DISABLE": "false", + "COMET_AUTO_LOG_ENV_CPU": "false", + "COMET_AUTO_LOG_ENV_DETAILS": "true", + "COMET_AUTO_LOG_ENV_DISK": "false", + "COMET_AUTO_LOG_ENV_GPU": "true", + "COMET_AUTO_LOG_ENV_HOST": "false", + "COMET_AUTO_LOG_ENV_NETWORK": "true", + "COMET_AUTO_LOG_GIT_METADATA": "false", + "COMET_AUTO_LOG_GIT_PATCH": "true", + "COMET_AUTO_LOG_GRAPH": "false", + "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false", + "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2", + "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true", + "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false", + "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true", + "COMET_AUTO_LOG_METRIC_STEP_RATE": "15", + "COMET_AUTO_LOG_METRICS": "true", + "COMET_AUTO_LOG_OUTPUT_LOGGER": "false", + "COMET_AUTO_LOG_PARAMETERS": "true", + "COMET_DISPLAY_SUMMARY_LEVEL": "2", + "COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier", + "COMET_EXPERIMENT_KEY": "some_experiment_key", + "COMET_OFFLINE_DIRECTORY": "some_offline_directory", + "COMET_PROJECT_NAME": "some_project", + "COMET_START_EXPERIMENT_NAME": "some_name", + "COMET_START_EXPERIMENT_TAGS": "tag1,tag2", + "COMET_START_MODE": "get_or_create", + "COMET_START_ONLINE": "false", + "COMET_WORKSPACE": "some_workspace", + } + + for key in comet_env.keys(): + os.environ.pop(key, None)