From 5c42f114115cc4e2dd49c1da2437ef1c08aecf69 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Sep 2024 22:19:54 -0400 Subject: [PATCH 01/12] remove dynamic module loader monkeypatch as this was fixed upstream (#1914) --- examples/deepseek-v2/qlora-fsdp-2_5.yaml | 83 +++++++++++++++++++ requirements.txt | 4 +- .../transformers_dynamic_module_utils.py | 51 ------------ src/axolotl/utils/models.py | 5 -- 4 files changed, 85 insertions(+), 58 deletions(-) create mode 100644 examples/deepseek-v2/qlora-fsdp-2_5.yaml delete mode 100644 src/axolotl/monkeypatch/transformers_dynamic_module_utils.py diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml new file mode 100644 index 000000000..6e82062d6 --- /dev/null +++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml @@ -0,0 +1,83 @@ +base_model: axolotl-quants/DeepSeek-V2.5-bnb-nf4-bf16 +trust_remote_code: true + +load_in_8bit: false +load_in_4bit: true +strict: false + + +plugins: + - axolotl.integrations.liger.LigerPlugin +liger_rms_norm: true +liger_swiglu: true +liger_fused_linear_cross_entropy: true + +chat_template: deepseek_v2 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train + +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +adapter: qlora +lora_r: 256 +lora_alpha: 256 +lora_target_linear: true +peft_use_rslora: true + +gradient_accumulation_steps: 1 +micro_batch_size: 8 +num_epochs: 1 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 2e-5 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +early_stopping_patience: +resume_from_checkpoint: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 100 +evals_per_epoch: 2 +eval_table_size: +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +special_tokens: +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: true + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD diff --git a/requirements.txt b/requirements.txt index 83116af60..32a9e0e01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.12.0 -transformers==4.44.2 +transformers @ git+https://github.com/huggingface/transformers.git@0963229e287501bed52ae1dabc17922524de6992 tokenizers>=0.19.1 bitsandbytes==0.43.3 accelerate==0.34.2 -datasets==2.20.0 +datasets==2.21.0 deepspeed==0.14.4 pydantic==2.6.3 addict diff --git a/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py b/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py deleted file mode 100644 index dfc3e29c5..000000000 --- a/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Patch transformers.dynamic_module_utils.get_class_in_module to avoid reloading models from disk""" - -import importlib -import os -import sys -import typing -from pathlib import Path - -from transformers.file_utils import HF_MODULES_CACHE - - 
-def _patched_get_class_in_module( - class_name: str, module_path: typing.Union[str, os.PathLike] -) -> typing.Type: - """ - Import a module on the cache directory for modules and extract a class from it. - - Args: - class_name (`str`): The name of the class to import. - module_path (`str` or `os.PathLike`): The path to the module to import. - - Returns: - `typing.Type`: The class looked for. - """ - name = os.path.normpath(module_path) - if name.endswith(".py"): - name = name[:-3] - name = name.replace(os.path.sep, ".") - module_spec = importlib.util.spec_from_file_location( - name, location=Path(HF_MODULES_CACHE) / module_path - ) - module = sys.modules.get(name) - if module is None: - module = importlib.util.module_from_spec(module_spec) - # insert it into sys.modules before any loading begins - sys.modules[name] = module - # load in initial case only - module_spec.loader.exec_module(module) - return getattr(module, class_name) - - -def patch_transformers_dynamic_module_utils(): - """ - Recently, transformers started reloading modeling code from disk for models marked trust_remote_code=True. - This causes monkey-patches for multipack and liger to be removed. - We replace the original function with a version that does not reload the module from disk. - See https://github.com/huggingface/transformers/pull/30370#pullrequestreview-2264361581 - """ - import transformers - - transformers.dynamic_module_utils.get_class_in_module = _patched_get_class_in_module diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index e0526fb04..e18330199 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -43,9 +43,6 @@ from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, patch_for_multipack, ) -from axolotl.monkeypatch.transformers_dynamic_module_utils import ( - patch_transformers_dynamic_module_utils, -) from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.chat_templates import chat_templates @@ -57,8 +54,6 @@ from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_mod LOG = logging.getLogger("axolotl") -patch_transformers_dynamic_module_utils() - # copied from accelerator.FullyShardedDataParallelPlugin def get_module_class_from_name(module, name): From 7b9f669a3ab18aecb00b17e7f2885aeb458440c8 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Sat, 14 Sep 2024 05:22:54 -0700 Subject: [PATCH 02/12] Trigger the original tokenization behavior when no advanced turn settings are provided (#1915) --- examples/phi/lora-3.5.yaml | 76 ++ .../prompt_strategies/chat_template.py | 54 +- src/axolotl/utils/chat_templates.py | 1 + .../config/models/input/v0_4_1/__init__.py | 1 + tests/prompt_strategies/conftest.py | 71 ++ .../prompt_strategies/test_chat_templates.py | 714 ++---------------- .../test_chat_templates_advanced.py | 615 +++++++++++++++ 7 files changed, 866 insertions(+), 666 deletions(-) create mode 100644 examples/phi/lora-3.5.yaml create mode 100644 tests/prompt_strategies/conftest.py create mode 100644 tests/prompt_strategies/test_chat_templates_advanced.py diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml new file mode 100644 index 000000000..59d667b8d --- /dev/null +++ b/examples/phi/lora-3.5.yaml @@ -0,0 +1,76 @@ +base_model: microsoft/Phi-3.5-mini-instruct +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer + +load_in_8bit: true +load_in_4bit: false +strict: false + +chat_template: phi_3 +datasets: + - 
path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + chat_template: phi_3 + field_messages: messages + message_field_role: role + message_field_content: content + roles: + user: + - user + assistant: + - assistant + +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./outputs/lora-out + +sequence_len: 4096 +sample_packing: false +pad_to_sequence_len: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 4 +num_epochs: 2 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bfloat16: true +bf16: true +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +s2_attention: + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 4 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 19e36531a..717367eef 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -24,8 +24,8 @@ class ChatTemplatePrompter(Prompter): max_length=2048, message_field_role: str = "from", message_field_content: str = "value", - message_field_training: str = "train", - message_field_training_detail: str = "train_detail", + message_field_training: Optional[str] = None, + message_field_training_detail: Optional[str] = None, roles: Optional[Dict[str, List[str]]] = None, drop_system_message: bool = False, ): @@ -186,7 +186,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): train_on_inputs, sequence_len, roles_to_train=None, - train_on_eos="last", + train_on_eos=None, ): super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) self.roles_to_train = roles_to_train if roles_to_train is not None else [] @@ -201,6 +201,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): self._messages = messages def tokenize_prompt(self, prompt): + # Old simple legacy behavior that works reliably. 
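+        # With none of roles_to_train, train_on_eos, message_field_training,
+        # or message_field_training_detail set, fall back to masking the whole
+        # prompt prefix: when train_on_inputs is False, labels become
+        # [-100] * len(prompt_ids) + input_ids[len(prompt_ids):], so only the
+        # final turn contributes to the loss.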
+ if ( + not self.roles_to_train + and not self.train_on_eos + and not self.prompter.message_field_training + and not self.prompter.message_field_training_detail + ): + turns = self.get_conversation_thread(prompt) + prompt_ids = self.prompter.build_prompt( + turns[:-1], add_generation_prompt=True + ) + input_ids = self.prompter.build_prompt(turns) + + if not self.train_on_inputs: + user_prompt_len = len(prompt_ids) + labels = [-100] * user_prompt_len + input_ids[user_prompt_len:] + else: + labels = input_ids + + tokenized_prompt = { + "input_ids": input_ids, + "labels": labels, + "attention_mask": [1] * len(input_ids), + } + + return tokenized_prompt + LOG.info(self.roles_to_train) + LOG.info(self.train_on_eos) + LOG.info(self.prompter.message_field_training) + LOG.info(self.prompter.message_field_training_detail) + turns = prompt[self.messages] input_ids = self.prompter.build_prompt(turns) labels = [IGNORE_TOKEN_ID] * len(input_ids) @@ -219,9 +250,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): should_train = ( train_turn if train_turn is not None - else bool(train_detail is not None) - if train_detail is not None - else self.train_on_inputs or role in self.roles_to_train + else ( + bool(train_detail is not None) + if train_detail is not None + else self.train_on_inputs or role in self.roles_to_train + ) ) LOG.debug(f"Should train: {should_train}") @@ -344,9 +377,10 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), "message_field_role": ds_cfg.get("message_field_role", "from"), "message_field_content": ds_cfg.get("message_field_content", "value"), - "message_field_training": ds_cfg.get("message_field_training", "training"), + "message_field_training": ds_cfg.get("message_field_training", None), "message_field_training_detail": ds_cfg.get( - "message_field_training_detail", "train_detail" + "message_field_training_detail", + None, ), "roles": ds_cfg.get("roles"), "drop_system_message": ds_cfg.get("drop_system_message", False), @@ -357,8 +391,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): strategy_params = { "train_on_inputs": cfg.train_on_inputs, "sequence_len": cfg.sequence_len, - "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]), - "train_on_eos": ds_cfg.get("train_on_eos", "turn"), + "roles_to_train": ds_cfg.get("roles_to_train", []), + "train_on_eos": ds_cfg.get("train_on_eos", None), } strategy = ChatTemplateStrategy( diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 51f88b1bd..7a96f5c1e 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -26,6 +26,7 @@ def chat_templates(user_choice: str): "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = 
messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." 
-}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n 
{{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', } diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 9044047cc..458bacdb1 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -189,6 +189,7 @@ class ChatTemplate(str, Enum): cohere = "cohere" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name phi_3 = "phi_3" # pylint: disable=invalid-name + phi_35 = "phi_35" # pylint: disable=invalid-name deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name jamba = "jamba" # pylint: disable=invalid-name diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py new file mode 100644 index 000000000..43423f725 --- /dev/null +++ b/tests/prompt_strategies/conftest.py @@ -0,0 +1,71 @@ +""" +shared fixtures for prompt strategies tests +""" + +import pytest +from datasets import Dataset +from transformers import AutoTokenizer + + +@pytest.fixture(name="assistant_dataset") +def fixture_assistant_dataset(): + return Dataset.from_list( + [ + { + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hello"}, + {"role": "user", "content": "goodbye"}, + {"role": "assistant", "content": "goodbye"}, + ] + } + ] + ) + + +@pytest.fixture(name="sharegpt_dataset") +def fixture_sharegpt_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "conversations": [ + {"from": "human", "value": "hello"}, + {"from": "gpt", "value": "hello"}, + {"from": "human", "value": "goodbye"}, + {"from": "gpt", "value": "goodbye"}, + ] + } + ] + ) + + +@pytest.fixture(name="basic_dataset") +def fixture_basic_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "conversations": [ + {"from": "system", "value": "You are an AI assistant."}, + {"from": "human", "value": "Hello"}, + {"from": "assistant", "value": "Hi there!"}, + {"from": "human", "value": "How are you?"}, + {"from": "assistant", "value": "I'm doing well, thank you!"}, + ] + } + ] + ) + + +@pytest.fixture(name="llama3_tokenizer") +def fixture_llama3_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") + + return tokenizer + + 
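+# Note: both tokenizer fixtures fetch from the Hugging Face Hub, so the first
+# run of these tests needs network access (or a pre-populated local cache).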
+@pytest.fixture(name="phi35_tokenizer") +def fixture_phi35_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct") + return tokenizer diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index e2fc0f6a5..28210b7ae 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -5,10 +5,6 @@ tests for chat_template prompt strategy import logging import unittest -import pytest -from datasets import Dataset -from transformers import AutoTokenizer - from axolotl.prompt_strategies.chat_template import ( ChatTemplatePrompter, ChatTemplateStrategy, @@ -22,657 +18,6 @@ logging.basicConfig(level=logging.DEBUG) LOG = logging.getLogger("axolotl") -@pytest.fixture(name="assistant_dataset") -def fixture_assistant_dataset(): - return Dataset.from_list( - [ - { - "messages": [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "hello"}, - {"role": "user", "content": "goodbye"}, - {"role": "assistant", "content": "goodbye"}, - ] - } - ] - ) - - -@pytest.fixture(name="sharegpt_dataset") -def fixture_sharegpt_dataset(): - # pylint: disable=duplicate-code - return Dataset.from_list( - [ - { - "conversations": [ - {"from": "human", "value": "hello"}, - {"from": "gpt", "value": "hello"}, - {"from": "human", "value": "goodbye"}, - {"from": "gpt", "value": "goodbye"}, - ] - } - ] - ) - - -@pytest.fixture(name="basic_dataset") -def fixture_basic_dataset(): - # pylint: disable=duplicate-code - return Dataset.from_list( - [ - { - "conversations": [ - {"from": "system", "value": "You are an AI assistant."}, - {"from": "human", "value": "Hello"}, - {"from": "assistant", "value": "Hi there!"}, - {"from": "human", "value": "How are you?"}, - {"from": "assistant", "value": "I'm doing well, thank you!"}, - ] - } - ] - ) - - -@pytest.fixture(name="llama3_tokenizer") -def fixture_llama3_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") - - return tokenizer - - -class TestChatTemplateConfigurations: - """ - Test class for various configurations of ChatTemplateStrategy. 
- """ - - @staticmethod - def find_sublist(full_list, sub_list): - token_count = len(sub_list) - for index in range(len(full_list) - token_count + 1): - if full_list[index : index + token_count] == sub_list: - return index - return -1 - - def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_inputs=True") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=True, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that assistant responses are labeled - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert start_idx != -1, f"Could not find '{response}' in input_ids" - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - # Check the behavior of human inputs - human_inputs = ["Hello", "How are you?"] - for input_text in human_inputs: - input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, input_ids) - labeled = all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(input_ids)] - ) - LOG.debug( - f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}" - ) - - LOG.debug("Full labels: %s", labels) - LOG.debug("Full input_ids: %s", input_ids) - - def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_inputs=False") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that only assistant responses are labeled - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert start_idx != -1, f"Could not find '{response}' in input_ids" - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - # Verify that human inputs are not labeled - human_inputs = ["Hello", "How are you?"] - for input_text in human_inputs: - input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, input_ids) - LOG.debug( - f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}" - ) - assert start_idx != -1, f"Could not find '{input_text}' in input_ids" - assert all( - 
label == IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(input_ids)] - ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}" - - def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing roles_to_train with assistant only") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that only assistant responses are labeled - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing roles_to_train with all roles") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=True, - sequence_len=512, - roles_to_train=["human", "assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that all responses are labeled (except for special tokens) - all_responses = [ - "Hello", - "Hi there!", - "How are you?", - "I'm doing well, thank you!", - ] - for response in all_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with empty roles_to_train") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=[], - train_on_eos="none", # Add this line - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - - # Verify that no labels are set when roles_to_train is empty - LOG.debug("Full labels: %s", labels) - assert all( - label == IGNORE_TOKEN_ID for label in labels - ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty" - - def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='all'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="all", - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = 
res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - eos_indices = [ - i for i, token_id in enumerate(input_ids) if token_id == eos_token_id - ] - - assert len(eos_indices) > 0, "Expected at least one EOS token in the input" - for eos_idx in eos_indices: - assert ( - labels[eos_idx] != IGNORE_TOKEN_ID - ), f"Expected EOS token at index {eos_idx} to be labeled" - - def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='turn'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="turn", - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - assert start_idx != -1, f"Could not find '{response}' in input_ids" - - eos_idx = start_idx + len(response_ids) - while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: - eos_idx += 1 - - assert eos_idx < len( - input_ids - ), f"Could not find EOS token after '{response}'" - assert ( - labels[eos_idx] != IGNORE_TOKEN_ID - ), f"Expected EOS token after assistant response '{response}' to be labeled" - - # Check that EOS tokens after human inputs are not labeled - human_inputs = ["Hello", "How are you?"] - for input_text in human_inputs: - input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, input_ids) - assert start_idx != -1, f"Could not find '{input_text}' in input_ids" - - eos_idx = start_idx + len(input_ids) - while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: - eos_idx += 1 - - assert ( - labels[eos_idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token after human input '{input_text}' to not be labeled" - - def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='last'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="last", - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - eos_indices = [ - i for i, token_id in enumerate(input_ids) if token_id == eos_token_id - ] - - assert len(eos_indices) > 0, "Expected at least one EOS token in the input" - last_eos_idx = eos_indices[-1] - - # Check that only the last EOS token is labeled - for idx in eos_indices[:-1]: - assert ( - labels[idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token at index {idx} to not be labeled" - assert ( - labels[last_eos_idx] != IGNORE_TOKEN_ID - ), f"Expected last EOS token at index {last_eos_idx} to be labeled" - - def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='none'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="none", - ) - res = 
strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - eos_indices = [ - i for i, token_id in enumerate(input_ids) if token_id == eos_token_id - ] - - assert len(eos_indices) > 0, "Expected at least one EOS token in the input" - for eos_idx in eos_indices: - assert ( - labels[eos_idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token at index {eos_idx} to not be labeled" - - def test_drop_system_message(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with drop_system_message=True") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), drop_system_message=True - ), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - input_ids = res["input_ids"] - - # Check if system message is not present in input_ids - system_message = "You are an AI assistant." - system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False) - assert ( - self.find_sublist(input_ids, system_ids) == -1 - ), "Expected system message to be dropped" - - def test_custom_roles(self, llama3_tokenizer): - LOG.info("Testing with custom roles mapping") - custom_roles = { - "user": ["human", "user"], - "assistant": ["ai", "assistant"], - "system": ["context"], - } - strategy = ChatTemplateStrategy( - ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), roles=custom_roles - ), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["ai"], - ) - - # Create a new dataset with modified role names - modified_conversations = [ - {"from": "context", "value": "You are an AI assistant."}, - {"from": "human", "value": "Hello"}, - {"from": "ai", "value": "Hi there!"}, - {"from": "human", "value": "How are you?"}, - {"from": "ai", "value": "I'm doing well, thank you!"}, - ] - - modified_dataset = Dataset.from_dict( - {"conversations": [modified_conversations]} - ) - - res = strategy.tokenize_prompt(modified_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Check if AI responses are labeled correctly - ai_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in ai_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - assert start_idx != -1, f"Could not find response '{response}' in input_ids" - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for AI response '{response}' to be set" - - # Check if human messages are not labeled - human_messages = ["Hello", "How are you?"] - for message in human_messages: - message_ids = llama3_tokenizer.encode(message, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, message_ids) - assert start_idx != -1, f"Could not find message '{message}' in input_ids" - assert all( - label == IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(message_ids)] - ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID" - - def test_message_field_training(self, llama3_tokenizer): - LOG.info("Testing with message_field_training") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter( - llama3_tokenizer, - chat_templates("llama3"), - message_field_training="train", - message_field_training_detail="train_detail", - ), - 
tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=[], - ) - - # Create a new dataset with the train and train_detail fields - modified_conversation = [ - {"from": "system", "value": "You are an AI assistant.", "train": False}, - {"from": "human", "value": "Hello", "train": False}, - {"from": "assistant", "value": "Hello", "train": True}, - {"from": "human", "value": "How are you?", "train": True}, - { - "from": "assistant", - "value": "I'm doing very well, thank you!", - "train_detail": [ - {"begin_offset": 0, "end_offset": 8, "train": False}, - {"begin_offset": 9, "end_offset": 18, "train": True}, - {"begin_offset": 19, "end_offset": 30, "train": False}, - ], - }, - { - "from": "human", - "value": "I'm doing very well, thank you!", - "train": False, - }, - {"from": "assistant", "value": "Hi there!", "train": True}, - ] - - modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]}) - - res = strategy.tokenize_prompt(modified_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Function to find all occurrences of a sublist - def find_all_sublists(full_list, sub_list): - indices = [] - for index in range(len(full_list) - len(sub_list) + 1): - if full_list[index : index + len(sub_list)] == sub_list: - indices.append(index) - return indices - - # Keep track of which occurrences we've processed - processed_occurrences = {} - # Check if messages are labeled correctly based on train or train_detail - for i, turn in enumerate(modified_conversation): - turn_tokens = llama3_tokenizer.encode( - turn["value"], add_special_tokens=False - ) - occurrences = find_all_sublists(input_ids, turn_tokens) - turn_key = turn["value"] - if turn_key not in processed_occurrences: - processed_occurrences[turn_key] = 0 - current_occurrence = processed_occurrences[turn_key] - - if current_occurrence >= len(occurrences): - assert ( - False - ), f"Not enough occurrences found for message: {turn['value']}" - - start_idx = occurrences[current_occurrence] - processed_occurrences[turn_key] += 1 - end_idx = start_idx + len(turn_tokens) - - LOG.debug( - f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}" - ) - - if "train_detail" in turn: - # Get token offsets - tokenized_output = llama3_tokenizer( - turn["value"], return_offsets_mapping=True, add_special_tokens=False - ) - token_offsets = tokenized_output["offset_mapping"] - - # Adjust token offsets as done in the implementation - for i in range(len(token_offsets) - 1): - token_offsets[i] = ( - token_offsets[i][0], - token_offsets[i + 1][0] - 1, - ) - token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1) - - # Adjust train_details - adjusted_train_details = strategy.prompter.adjust_train_details( - turn["train_detail"], token_offsets - ) - - LOG.debug(f"Original train_details: {turn['train_detail']}") - LOG.debug(f"Adjusted train_details: {adjusted_train_details}") - - # Handle train_detail - token_offsets = strategy.prompter.get_offsets_for_train_detail( - text=turn["value"], - train_details=adjusted_train_details, - mask_untrainable=False, - ) - token_offsets_masked = strategy.prompter.get_offsets_for_train_detail( - text=turn["value"], - train_details=adjusted_train_details, - mask_untrainable=True, - ) - LOG.debug(f"Token offsets: {token_offsets_masked}") - - expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens) - for i, offset in enumerate(token_offsets_masked): - if offset != IGNORE_TOKEN_ID: - expected_labels[i] = 
turn_tokens[i] - actual_labels = labels[ - start_idx : start_idx + len(token_offsets_masked) - ] - assert ( - actual_labels == expected_labels - ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}" - - for detail in adjusted_train_details: - # Find the token indices that correspond to the character offsets - detail_start = start_idx + next( - i - for i, offset in enumerate(token_offsets) - if offset >= detail["begin_offset"] - ) - detail_end = start_idx + next( - ( - i - for i, offset in enumerate(token_offsets) - if offset > detail["end_offset"] - ), - len(token_offsets), - ) - - detail_text = turn["value"][ - detail["begin_offset"] : detail["end_offset"] + 1 - ] - detail_labels = labels[detail_start:detail_end] - detail_input_ids = input_ids[detail_start:detail_end] - - LOG.debug( - f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}" - ) - LOG.debug(f"Detail input_ids: {detail_input_ids}") - LOG.debug(f"Detail labels: {detail_labels}") - LOG.debug( - f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}" - ) - LOG.debug( - f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}" - ) - - if detail["train"]: - assert all( - label != IGNORE_TOKEN_ID for label in detail_labels - ), ( - f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. " - f"Labels({detail_start}:{detail_end}): {detail_labels}, " - f"InputIDs: {detail_input_ids}, " - f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" - ) - else: - assert all( - label == IGNORE_TOKEN_ID for label in detail_labels - ), ( - f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. " - f"Labels({detail_start}:{detail_end}): {detail_labels}, " - f"InputIDs: {detail_input_ids}, " - f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" - ) - else: - should_train = turn.get("train", False) - turn_labels = labels[start_idx:end_idx] - - LOG.debug(f"Should train: {should_train}") - LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}") - LOG.debug(f"Turn labels: {turn_labels}") - LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}") - LOG.debug( - f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}" - ) - - if should_train: - assert all(label != IGNORE_TOKEN_ID for label in turn_labels), ( - f"Expected all labels for '{turn['value']}' to be set\n" - f"Labels({start_idx}:{end_idx}): {turn_labels}, " - f"InputIDs: {input_ids[start_idx:end_idx]}, " - f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" - ) - else: - assert all(label == IGNORE_TOKEN_ID for label in turn_labels), ( - f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n" - f"Labels({start_idx}:{end_idx}): {turn_labels}, " - f"InputIDs: {input_ids[start_idx:end_idx]}, " - f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" - ) - - LOG.debug( - f"Processed turn: {turn['from']}, content: '{turn['value']}', " - f"start_idx: {start_idx}, end_idx: {end_idx}, " - f"labels: {labels[start_idx:end_idx]}" - ) - - LOG.debug(f"Final labels: {labels}") - LOG.debug(f"Final input_ids: {input_ids}") - - class TestAssistantChatTemplateLlama3: """ Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. 
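The test_phi35 case in the next hunk pins the exact token stream the phi_35 template produces for the two-turn assistant_dataset. As a sketch (assuming <|endoftext|>, id 32000, is the Phi-3.5 eos token, consistent with the expected ids in the test), the rendered prompt is:

<|user|>\nhello<|end|>\n<|assistant|>\nhello<|end|>\n<|user|>\ngoodbye<|end|>\n<|assistant|>\ngoodbye<|end|>\n<|endoftext|>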
@@ -740,7 +85,6 @@ class TestAssistantChatTemplateLlama3:
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
             sequence_len=512,
-            roles_to_train=["assistant"],
         )
         strategy.messages = "messages"
         res = strategy.tokenize_prompt(assistant_dataset[0])
@@ -764,6 +108,64 @@ class TestAssistantChatTemplateLlama3:
             input_ids == expected_input_ids
         ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
 
+    def test_phi35(self, phi35_tokenizer, assistant_dataset):
+        LOG.info("Testing phi-3.5 with assistant dataset")
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(
+                phi35_tokenizer,
+                chat_templates("phi_35"),
+                message_field_role="role",
+                message_field_content="content",
+                roles={
+                    "user": ["user"],
+                    "assistant": ["assistant"],
+                    "system": ["system"],
+                },
+            ),
+            tokenizer=phi35_tokenizer,
+            train_on_inputs=False,
+            sequence_len=512,
+        )
+        strategy.messages = "messages"
+        res = strategy.tokenize_prompt(assistant_dataset[0])
+        input_ids = res["input_ids"]
+        labels = res["labels"]
+        # fmt: off
+        expected_input_ids = [
+            32010,  # user
+            22172, 32007,  # user eot
+            32001,  # assistant
+            22172, 32007,  # assistant eot
+            32010,  # user
+            1781, 26966, 32007,  # user eot
+            32001,  # assistant
+            1781, 26966, 32007,  # assistant eot
+            32000,  # eos
+        ]
+        expected_labels = [
+            -100,  # user
+            -100, -100,  # user eot
+            -100,  # assistant
+            -100, -100,  # assistant eot
+            -100,  # user
+            -100, -100, -100,  # user eot
+            -100,  # assistant
+            1781, 26966, 32007,  # assistant eot
+            32000,  # eos
+        ]
+        # fmt: on
+        LOG.debug(f"Expected input_ids: {expected_input_ids}")
+        LOG.debug(f"Actual input_ids: {input_ids}")
+        assert (
+            input_ids == expected_input_ids
+        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+
+        LOG.debug(f"Expected labels : {expected_labels}")
+        LOG.debug(f"Actual labels : {labels}")
+        assert (
+            labels == expected_labels
+        ), f"Labels mismatch: {labels} != {expected_labels}"
+
     def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
         LOG.info("Testing llama-3 with assistant dataset including training data")
         strategy = ChatTemplateStrategy(
diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py
new file mode 100644
index 000000000..f18fb3942
--- /dev/null
+++ b/tests/prompt_strategies/test_chat_templates_advanced.py
@@ -0,0 +1,615 @@
+"""
+tests for chat_template prompt strategy
+"""
+
+import logging
+import unittest
+
+from datasets import Dataset
+
+from axolotl.prompt_strategies.chat_template import (
+    ChatTemplatePrompter,
+    ChatTemplateStrategy,
+)
+from axolotl.prompters import IGNORE_TOKEN_ID
+from axolotl.utils.chat_templates import chat_templates
+
+logging.basicConfig(level=logging.DEBUG)
+LOG = logging.getLogger("axolotl")
+
+
+class TestChatTemplateConfigurations:
+    """
+    Test class for various configurations of ChatTemplateStrategy.
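+
+    Covers the advanced per-turn controls: roles_to_train, train_on_eos, and
+    message_field_training / message_field_training_detail.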
+ """ + + @staticmethod + def find_sublist(full_list, sub_list): + token_count = len(sub_list) + for index in range(len(full_list) - token_count + 1): + if full_list[index : index + token_count] == sub_list: + return index + return -1 + + def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_inputs=True") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=True, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + # Check the behavior of human inputs + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + labeled = all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(input_ids)] + ) + LOG.debug( + f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}" + ) + + LOG.debug("Full labels: %s", labels) + LOG.debug("Full input_ids: %s", input_ids) + + def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_inputs=False") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that only assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + # Verify that human inputs are not labeled + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + LOG.debug( + f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{input_text}' in input_ids" + assert all( + 
label == IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(input_ids)] + ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}" + + def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing roles_to_train with assistant only") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that only assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing roles_to_train with all roles") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=True, + sequence_len=512, + roles_to_train=["human", "assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that all responses are labeled (except for special tokens) + all_responses = [ + "Hello", + "Hi there!", + "How are you?", + "I'm doing well, thank you!", + ] + for response in all_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with empty roles_to_train") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=[], + train_on_eos="none", # Add this line + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + + # Verify that no labels are set when roles_to_train is empty + LOG.debug("Full labels: %s", labels) + assert all( + label == IGNORE_TOKEN_ID for label in labels + ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty" + + def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='all'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="all", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = 
res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + for eos_idx in eos_indices: + assert ( + labels[eos_idx] != IGNORE_TOKEN_ID + ), f"Expected EOS token at index {eos_idx} to be labeled" + + def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='turn'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="turn", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + + eos_idx = start_idx + len(response_ids) + while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: + eos_idx += 1 + + assert eos_idx < len( + input_ids + ), f"Could not find EOS token after '{response}'" + assert ( + labels[eos_idx] != IGNORE_TOKEN_ID + ), f"Expected EOS token after assistant response '{response}' to be labeled" + + # Check that EOS tokens after human inputs are not labeled + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + assert start_idx != -1, f"Could not find '{input_text}' in input_ids" + + eos_idx = start_idx + len(input_ids) + while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: + eos_idx += 1 + + assert ( + labels[eos_idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token after human input '{input_text}' to not be labeled" + + def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='last'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="last", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + last_eos_idx = eos_indices[-1] + + # Check that only the last EOS token is labeled + for idx in eos_indices[:-1]: + assert ( + labels[idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token at index {idx} to not be labeled" + assert ( + labels[last_eos_idx] != IGNORE_TOKEN_ID + ), f"Expected last EOS token at index {last_eos_idx} to be labeled" + + def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='none'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="none", + ) + res = 
strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + for eos_idx in eos_indices: + assert ( + labels[eos_idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token at index {eos_idx} to not be labeled" + + def test_drop_system_message(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with drop_system_message=True") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, chat_templates("llama3"), drop_system_message=True + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + input_ids = res["input_ids"] + + # Check if system message is not present in input_ids + system_message = "You are an AI assistant." + system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False) + assert ( + self.find_sublist(input_ids, system_ids) == -1 + ), "Expected system message to be dropped" + + def test_custom_roles(self, llama3_tokenizer): + LOG.info("Testing with custom roles mapping") + custom_roles = { + "user": ["human", "user"], + "assistant": ["ai", "assistant"], + "system": ["context"], + } + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, chat_templates("llama3"), roles=custom_roles + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["ai"], + ) + + # Create a new dataset with modified role names + modified_conversations = [ + {"from": "context", "value": "You are an AI assistant."}, + {"from": "human", "value": "Hello"}, + {"from": "ai", "value": "Hi there!"}, + {"from": "human", "value": "How are you?"}, + {"from": "ai", "value": "I'm doing well, thank you!"}, + ] + + modified_dataset = Dataset.from_dict( + {"conversations": [modified_conversations]} + ) + + res = strategy.tokenize_prompt(modified_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Check if AI responses are labeled correctly + ai_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in ai_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + assert start_idx != -1, f"Could not find response '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for AI response '{response}' to be set" + + # Check if human messages are not labeled + human_messages = ["Hello", "How are you?"] + for message in human_messages: + message_ids = llama3_tokenizer.encode(message, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, message_ids) + assert start_idx != -1, f"Could not find message '{message}' in input_ids" + assert all( + label == IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(message_ids)] + ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID" + + def test_message_field_training(self, llama3_tokenizer): + LOG.info("Testing with message_field_training") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, + chat_templates("llama3"), + message_field_training="train", + message_field_training_detail="train_detail", + ), + 
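+            # a boolean "train" field flags whole turns as trainable, while
+            # "train_detail" flags character-offset spans within a turn (see
+            # the conversation constructed below)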
tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=[], + ) + + # Create a new dataset with the train and train_detail fields + modified_conversation = [ + {"from": "system", "value": "You are an AI assistant.", "train": False}, + {"from": "human", "value": "Hello", "train": False}, + {"from": "assistant", "value": "Hello", "train": True}, + {"from": "human", "value": "How are you?", "train": True}, + { + "from": "assistant", + "value": "I'm doing very well, thank you!", + "train_detail": [ + {"begin_offset": 0, "end_offset": 8, "train": False}, + {"begin_offset": 9, "end_offset": 18, "train": True}, + {"begin_offset": 19, "end_offset": 30, "train": False}, + ], + }, + { + "from": "human", + "value": "I'm doing very well, thank you!", + "train": False, + }, + {"from": "assistant", "value": "Hi there!", "train": True}, + ] + + modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]}) + + res = strategy.tokenize_prompt(modified_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Function to find all occurrences of a sublist + def find_all_sublists(full_list, sub_list): + indices = [] + for index in range(len(full_list) - len(sub_list) + 1): + if full_list[index : index + len(sub_list)] == sub_list: + indices.append(index) + return indices + + # Keep track of which occurrences we've processed + processed_occurrences = {} + # Check if messages are labeled correctly based on train or train_detail + for i, turn in enumerate(modified_conversation): + turn_tokens = llama3_tokenizer.encode( + turn["value"], add_special_tokens=False + ) + occurrences = find_all_sublists(input_ids, turn_tokens) + turn_key = turn["value"] + if turn_key not in processed_occurrences: + processed_occurrences[turn_key] = 0 + current_occurrence = processed_occurrences[turn_key] + + if current_occurrence >= len(occurrences): + assert ( + False + ), f"Not enough occurrences found for message: {turn['value']}" + + start_idx = occurrences[current_occurrence] + processed_occurrences[turn_key] += 1 + end_idx = start_idx + len(turn_tokens) + + LOG.debug( + f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}" + ) + + if "train_detail" in turn: + # Get token offsets + tokenized_output = llama3_tokenizer( + turn["value"], return_offsets_mapping=True, add_special_tokens=False + ) + token_offsets = tokenized_output["offset_mapping"] + + # Adjust token offsets as done in the implementation + for i in range(len(token_offsets) - 1): + token_offsets[i] = ( + token_offsets[i][0], + token_offsets[i + 1][0] - 1, + ) + token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1) + + # Adjust train_details + adjusted_train_details = strategy.prompter.adjust_train_details( + turn["train_detail"], token_offsets + ) + + LOG.debug(f"Original train_details: {turn['train_detail']}") + LOG.debug(f"Adjusted train_details: {adjusted_train_details}") + + # Handle train_detail + token_offsets = strategy.prompter.get_offsets_for_train_detail( + text=turn["value"], + train_details=adjusted_train_details, + mask_untrainable=False, + ) + token_offsets_masked = strategy.prompter.get_offsets_for_train_detail( + text=turn["value"], + train_details=adjusted_train_details, + mask_untrainable=True, + ) + LOG.debug(f"Token offsets: {token_offsets_masked}") + + expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens) + for i, offset in enumerate(token_offsets_masked): + if offset != IGNORE_TOKEN_ID: + expected_labels[i] = 
turn_tokens[i] + actual_labels = labels[ + start_idx : start_idx + len(token_offsets_masked) + ] + assert ( + actual_labels == expected_labels + ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}" + + for detail in adjusted_train_details: + # Find the token indices that correspond to the character offsets + detail_start = start_idx + next( + i + for i, offset in enumerate(token_offsets) + if offset >= detail["begin_offset"] + ) + detail_end = start_idx + next( + ( + i + for i, offset in enumerate(token_offsets) + if offset > detail["end_offset"] + ), + len(token_offsets), + ) + + detail_text = turn["value"][ + detail["begin_offset"] : detail["end_offset"] + 1 + ] + detail_labels = labels[detail_start:detail_end] + detail_input_ids = input_ids[detail_start:detail_end] + + LOG.debug( + f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}" + ) + LOG.debug(f"Detail input_ids: {detail_input_ids}") + LOG.debug(f"Detail labels: {detail_labels}") + LOG.debug( + f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}" + ) + LOG.debug( + f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}" + ) + + if detail["train"]: + assert all( + label != IGNORE_TOKEN_ID for label in detail_labels + ), ( + f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. " + f"Labels({detail_start}:{detail_end}): {detail_labels}, " + f"InputIDs: {detail_input_ids}, " + f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" + ) + else: + assert all( + label == IGNORE_TOKEN_ID for label in detail_labels + ), ( + f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. " + f"Labels({detail_start}:{detail_end}): {detail_labels}, " + f"InputIDs: {detail_input_ids}, " + f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" + ) + else: + should_train = turn.get("train", False) + turn_labels = labels[start_idx:end_idx] + + LOG.debug(f"Should train: {should_train}") + LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}") + LOG.debug(f"Turn labels: {turn_labels}") + LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}") + LOG.debug( + f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}" + ) + + if should_train: + assert all(label != IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected all labels for '{turn['value']}' to be set\n" + f"Labels({start_idx}:{end_idx}): {turn_labels}, " + f"InputIDs: {input_ids[start_idx:end_idx]}, " + f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" + ) + else: + assert all(label == IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n" + f"Labels({start_idx}:{end_idx}): {turn_labels}, " + f"InputIDs: {input_ids[start_idx:end_idx]}, " + f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" + ) + + LOG.debug( + f"Processed turn: {turn['from']}, content: '{turn['value']}', " + f"start_idx: {start_idx}, end_idx: {end_idx}, " + f"labels: {labels[start_idx:end_idx]}" + ) + + LOG.debug(f"Final labels: {labels}") + LOG.debug(f"Final input_ids: {input_ids}") + + +if __name__ == "__main__": + unittest.main() From d7eea2ff343e9f0653ce82fc1826b41efa9bc2f6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 24 Sep 2024 14:05:58 -0400 Subject: [PATCH 03/12] validation fixes 20240923 (#1925) * validation fixes 20240923 * fix run name for wandb and defaults for chat template 
fields

* fix gradio inference with llama chat template

---
 src/axolotl/cli/__init__.py                   | 27 +++++++++++++++++--
 src/axolotl/core/trainer_builder.py           |  8 ++++++
 .../prompt_strategies/chat_template.py        |  4 +--
 .../config/models/input/v0_4_1/__init__.py    | 10 ++++++-
 4 files changed, 44 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index aaa62423c..13c5b4ab5 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -30,6 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 from axolotl.integrations.base import PluginManager
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
+from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.config import (
     normalize_cfg_datasets,
     normalize_config,
@@ -234,7 +235,8 @@ def do_inference_gradio(
     model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
     prompter = cli_args.prompter
 
-    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+    # default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+    default_tokens: Dict[str, str] = {}
 
     for token, symbol in default_tokens.items():
         # If the token isn't already specified in the config, add it
@@ -242,10 +244,13 @@ def do_inference_gradio(
             tokenizer.add_special_tokens({token: symbol})
 
     prompter_module = None
+    chat_template_str = None
     if prompter:
         prompter_module = getattr(
             importlib.import_module("axolotl.prompters"), prompter
         )
+    elif cfg.chat_template:
+        chat_template_str = chat_templates(cfg.chat_template)
 
     model = model.to(cfg.device, dtype=cfg.torch_dtype)
 
@@ -259,7 +264,24 @@
             )
         else:
             prompt = instruction.strip()
-        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+        if chat_template_str:
+            batch = tokenizer.apply_chat_template(
+                [
+                    {
+                        "role": "user",
+                        "content": prompt,
+                    }
+                ],
+                return_tensors="pt",
+                add_special_tokens=True,
+                add_generation_prompt=True,
+                chat_template=chat_template_str,
+                tokenize=True,
+                return_dict=True,
+            )
+        else:
+            batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
         model.eval()
         with torch.no_grad():
@@ -282,6 +304,7 @@ def do_inference_gradio(
             streamer = TextIteratorStreamer(tokenizer)
             generation_kwargs = {
                 "inputs": batch["input_ids"].to(cfg.device),
+                "attention_mask": batch["attention_mask"].to(cfg.device),
                 "generation_config": generation_config,
                 "streamer": streamer,
             }
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index f4cd25783..7c3e437f8 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1417,6 +1417,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         report_to = []
         if self.cfg.use_wandb:
             report_to.append("wandb")
+            if self.cfg.wandb_name:
+                training_arguments_kwargs["run_name"] = self.cfg.wandb_name
         if self.cfg.use_mlflow:
             report_to.append("mlflow")
         if self.cfg.use_tensorboard:
@@ -1574,6 +1576,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )
         training_args = self.hook_post_create_training_args(training_args)
 
+        # unset run_name so wandb sets up experiment names
+        if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
+            training_args.run_name = (  # pylint: disable=attribute-defined-outside-init
+                None
+            )
+
         data_collator_kwargs = {
             "padding": True,  # True/"longest" is the default
         }
diff --git a/src/axolotl/prompt_strategies/chat_template.py 
b/src/axolotl/prompt_strategies/chat_template.py index 717367eef..88e748895 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -375,8 +375,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): prompter_params = { "tokenizer": tokenizer, "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), - "message_field_role": ds_cfg.get("message_field_role", "from"), - "message_field_content": ds_cfg.get("message_field_content", "value"), + "message_field_role": ds_cfg.get("message_field_role", "role"), + "message_field_content": ds_cfg.get("message_field_content", "content"), "message_field_training": ds_cfg.get("message_field_training", None), "message_field_training_detail": ds_cfg.get( "message_field_training_detail", diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 458bacdb1..221785508 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1017,12 +1017,20 @@ class AxolotlInputConfig( return neftune_noise_alpha @model_validator(mode="after") - def check(self): + def check_rl_beta(self): if self.dpo_beta and not self.rl_beta: self.rl_beta = self.dpo_beta del self.dpo_beta return self + @model_validator(mode="after") + def check_simpo_warmup(self): + if self.rl == "simpo" and self.warmup_ratio: + raise ValueError( + "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead" + ) + return self + @model_validator(mode="before") @classmethod def check_frozen(cls, data): From b98d7d7098f5d64a07c5a96855c4e08dca7afd91 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 26 Sep 2024 11:33:41 -0400 Subject: [PATCH 04/12] update upstream deps versions and replace lora+ (#1928) * update upstream deps versions and replace lora+ * typo transformers version --- requirements.txt | 8 +- src/axolotl/core/trainer_builder.py | 14 +-- src/axolotl/loraplus.py | 133 ---------------------------- 3 files changed, 11 insertions(+), 144 deletions(-) delete mode 100644 src/axolotl/loraplus.py diff --git a/requirements.txt b/requirements.txt index 32a9e0e01..3f17e5d32 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 -peft==0.12.0 -transformers @ git+https://github.com/huggingface/transformers.git@0963229e287501bed52ae1dabc17922524de6992 +peft==0.13.0 +transformers==4.45.0 tokenizers>=0.19.1 -bitsandbytes==0.43.3 +bitsandbytes==0.44.0 accelerate==0.34.2 datasets==2.21.0 deepspeed==0.14.4 @@ -34,7 +34,7 @@ tensorboard python-dotenv==1.0.1 autoawq>=0.2.5 triton>=2.3.0 -liger-kernel==0.2.1 +liger-kernel==0.3.0 mamba-ssm==1.2.0.post1 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 7c3e437f8..23ac0952e 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -21,6 +21,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union import torch import transformers from datasets import Dataset +from peft.optimizers import create_loraplus_optimizer from torch import nn from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler @@ -45,7 +46,6 @@ from trl import ( ) from trl.trainer.utils import pad_to_length -from axolotl.loraplus import create_loraplus_optimizer from 
axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.utils import is_mlflow_available @@ -461,9 +461,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer): self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init opt_model, optimizer_cls, - optimizer_kwargs, - loraplus_lr_ratio, - loraplus_lr_embedding, + loraplus_lr_ratio=loraplus_lr_ratio, + loraplus_lr_embedding=loraplus_lr_embedding, + **optimizer_kwargs, ) elif self.args.alternate_optimizer == "optimi_adamw": from optimi import AdamW @@ -969,9 +969,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init opt_model, optimizer_cls, - optimizer_kwargs, - loraplus_lr_ratio, - loraplus_lr_embedding, + loraplus_lr_ratio=loraplus_lr_ratio, + loraplus_lr_embedding=loraplus_lr_embedding, + **optimizer_kwargs, ) if is_sagemaker_mp_enabled(): diff --git a/src/axolotl/loraplus.py b/src/axolotl/loraplus.py deleted file mode 100644 index b4abec55a..000000000 --- a/src/axolotl/loraplus.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Module for LoRA+""" - -# MIT License -# -# Copyright (c) 2024 nikhil-ghosh-berkeley -# https://github.com/nikhil-ghosh-berkeley/loraplus - -import logging -from functools import reduce - -from peft.tuners import lora -from torch import nn -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.trainer_pt_utils import get_parameter_names - -LOG = logging.getLogger("axolotl.loraplus") - - -def get_module(name, opt_model): - """ - Retrieve a module from a model using its parameter name. - Args: - name (str): Full name of the parameter, typically including module path. - opt_model (torch.nn.Module): The model from which to retrieve the module. - - Returns: - Module corresponding to the given name. - """ - parent_idx = 2 if "lora" in name else 1 - module_names = name.split(sep=".")[:-parent_idx] - module = reduce(getattr, module_names, opt_model) - return module - - -def create_loraplus_optimizer( - opt_model, - optimizer_cls, - optimizer_kwargs, - loraplus_lr_ratio, - loraplus_lr_embedding=None, -): - """ - Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups. - - Args: - opt_model (torch.nn.Module): The model for which the optimizer is being created. - optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam). - optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization. - loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters. - loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided. - - Returns: - An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates. - """ - - assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided." 
- - if loraplus_lr_embedding is None: - loraplus_lr_embedding = 1e-6 - - decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) - decay_parameters = [name for name in decay_parameters if "bias" not in name] - param_groups = { - "groupA": {}, - "groupB": {}, - "groupB_no_decay": {}, - "embedding": {}, - } - - for name, param in opt_model.named_parameters(): - if not param.requires_grad: - continue - - module = get_module(name, opt_model) - if isinstance(module, lora.Embedding): - param_groups["embedding"][name] = param - elif "lora_B" in name or param.ndim == 1: - if name in decay_parameters: - param_groups["groupB"][name] = param - else: - param_groups["groupB_no_decay"][name] = param - else: - param_groups["groupA"][name] = param - - assigned_param_groups = "" - for group, group_params in param_groups.items(): - assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n" - LOG.info(assigned_param_groups) - - lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name - weight_decay = optimizer_kwargs.get("weight_decay", 0.0) - - optimizer_grouped_parameters = [ - { - "params": list(param_groups["groupA"].values()), - "weight_decay": weight_decay, - "lr": lr, - }, - { - "params": list(param_groups["embedding"].values()), - "weight_decay": weight_decay, - "lr": loraplus_lr_embedding, - }, - { - "params": list(param_groups["groupB"].values()), - "weight_decay": weight_decay, - "lr": lr * loraplus_lr_ratio, - }, - { - "params": list(param_groups["groupB_no_decay"].values()), - "weight_decay": 0.0, - "lr": lr * loraplus_lr_ratio, - }, - ] - - optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) - if optimizer_cls.__name__ == "Adam8bit": - import bitsandbytes - - manager = bitsandbytes.optim.GlobalOptimManager.get_instance() - - skipped = 0 - for module in opt_model.modules(): - if isinstance(module, nn.Embedding): - skipped += sum( - {p.data_ptr(): p.numel() for p in module.parameters()}.values() - ) - LOG.info(f"skipped {module}: {skipped/2**20}M params") - manager.register_module_override(module, "weight", {"optim_bits": 32}) - LOG.debug(f"bitsandbytes: will optimize {module} in fp32") - LOG.info(f"skipped: {skipped/2**20}M params") - - return optimizer From 61aa291119e90dbebf5612be42cd5cad3729bc6e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 27 Sep 2024 15:58:35 -0400 Subject: [PATCH 05/12] fix for empty lora+ lr embedding (#1932) --- src/axolotl/core/trainer_builder.py | 2 +- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 23ac0952e..249398f85 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -456,7 +456,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer): if self.args.loraplus_lr_ratio is not None: loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) loraplus_lr_embedding = getattr( - self.args, "loraplus_lr_embedding", None + self.args, "loraplus_lr_embedding", 1e-6 ) self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init opt_model, diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 221785508..4e07c9260 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -298,6 +298,13 @@ class LoraConfig(BaseModel): raise 
ValueError("Require cfg.load_in_4bit to be True for qlora") return self + @field_validator("loraplus_lr_embedding") + @classmethod + def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding): + if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str): + loraplus_lr_embedding = float(loraplus_lr_embedding) + return loraplus_lr_embedding + class ReLoRAConfig(BaseModel): """ReLoRA configuration subset""" From 844331005c1ef45430ff26b9f42f757dce6ee66a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 30 Sep 2024 13:56:12 -0400 Subject: [PATCH 06/12] bump transformers to 4.45.1 (#1936) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3f17e5d32..123a4ee54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.13.0 -transformers==4.45.0 +transformers==4.45.1 tokenizers>=0.19.1 bitsandbytes==0.44.0 accelerate==0.34.2 From e1915f5625b2330555c3f61816dd003fb939ae13 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 2 Oct 2024 21:02:48 -0400 Subject: [PATCH 07/12] Multimodal Vision Llama - rudimentary support (#1940) --------- Co-authored-by: Sunny Co-authored-by: sunny --- docs/input_output.qmd | 2 +- docs/multimodal.qmd | 28 +++ examples/llama-3-vision/lora-11b.yaml | 63 +++++ src/axolotl/cli/__init__.py | 7 +- src/axolotl/core/trainer_builder.py | 20 +- src/axolotl/monkeypatch/attention/mllama.py | 229 ++++++++++++++++++ src/axolotl/monkeypatch/multipack.py | 1 + .../monkeypatch/stablelm_attn_hijack_flash.py | 1 + src/axolotl/prompt_strategies/__init__.py | 4 +- .../prompt_strategies/chat_template.py | 60 ++++- src/axolotl/train.py | 10 +- src/axolotl/utils/chat_templates.py | 46 ++-- src/axolotl/utils/collators/__init__.py | 10 + .../{collators.py => collators/batching.py} | 35 +-- src/axolotl/utils/collators/core.py | 4 + src/axolotl/utils/collators/mamba.py | 38 +++ src/axolotl/utils/collators/mm_chat.py | 77 ++++++ src/axolotl/utils/config/__init__.py | 25 +- .../config/models/input/v0_4_1/__init__.py | 29 ++- src/axolotl/utils/data/sft.py | 48 +++- src/axolotl/utils/models.py | 98 ++++++-- src/axolotl/utils/trainer.py | 16 +- .../prompt_strategies/test_chat_templates.py | 21 +- .../test_chat_templates_advanced.py | 46 +++- 24 files changed, 799 insertions(+), 119 deletions(-) create mode 100644 docs/multimodal.qmd create mode 100644 examples/llama-3-vision/lora-11b.yaml create mode 100644 src/axolotl/monkeypatch/attention/mllama.py create mode 100644 src/axolotl/utils/collators/__init__.py rename src/axolotl/utils/{collators.py => collators/batching.py} (90%) create mode 100644 src/axolotl/utils/collators/core.py create mode 100644 src/axolotl/utils/collators/mamba.py create mode 100644 src/axolotl/utils/collators/mm_chat.py diff --git a/docs/input_output.qmd b/docs/input_output.qmd index 7715dd250..6559578d1 100644 --- a/docs/input_output.qmd +++ b/docs/input_output.qmd @@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/') hi there!. goodbye farewell ``` -We can check that the right tokens are ingored by comparing the labels +We can check that the right tokens are ignored by comparing the labels to each token: ```python diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd new file mode 100644 index 000000000..2381566ad --- /dev/null +++ b/docs/multimodal.qmd @@ -0,0 +1,28 @@ +# MultiModal / Vision Language Models (BETA) + +### Supported Models + +- Mllama, i.e. 
llama with vision models + +### Usage + +Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA, +you'll need to use the following in YAML in combination with the rest of the required hyperparams. + +```yaml +base_model: alpindale/Llama-3.2-11B-Vision-Instruct +processor_type: AutoProcessor +skip_prepare_dataset: true + +chat_template: llama3_2_vision +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +remove_unused_columns: false +sample_packing: false + +# only finetune the Language model, leave the vision model and vision tower frozen +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' +``` diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml new file mode 100644 index 000000000..b2e494641 --- /dev/null +++ b/examples/llama-3-vision/lora-11b.yaml @@ -0,0 +1,63 @@ +base_model: alpindale/Llama-3.2-11B-Vision-Instruct +processor_type: AutoProcessor +strict: false + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: llama3_2_vision +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +local_rank: +logging_steps: 1 +flash_attention: true +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 13c5b4ab5..a1d84b6a1 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -40,7 +40,7 @@ from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process from axolotl.utils.mlflow_ import setup_mlflow_env_vars -from axolotl.utils.models import load_tokenizer +from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars @@ -430,9 +430,12 @@ def load_datasets( cli_args: TrainerCliArgs, ) -> TrainDatasetMeta: tokenizer = load_tokenizer(cfg) + processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( - cfg, tokenizer + cfg, + tokenizer, + processor=processor, ) if cli_args.debug or cfg.debug: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 249398f85..4893e63dc 100755 --- 
a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -61,12 +61,14 @@ from axolotl.utils.callbacks import ( log_prediction_callback_factory, ) from axolotl.utils.callbacks.lisa import lisa_callback_factory +from axolotl.utils.chat_templates import chat_templates from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, MambaDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, ) +from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.models import ensure_dtype from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.schedulers import ( @@ -250,6 +252,10 @@ class AxolotlTrainingMixins: "help": "workaround to pass an alternate lr scheduler to the HF trainer" }, ) + chat_template: Optional[str] = field( + default=None, + metadata={"help": "Chat template converting chat messages to text"}, + ) @dataclass @@ -1043,10 +1049,11 @@ class TrainerBuilderBase(abc.ABC): _model_ref = None _peft_config = None - def __init__(self, cfg, model, tokenizer): + def __init__(self, cfg, model, tokenizer, processor=None): self.cfg = cfg self.model = model self.tokenizer = tokenizer + self.processor = processor # in case the model supports tagging, add the axolotl tag. # This makes sure the tag is correctly pushed even if a user calls @@ -1515,6 +1522,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ) training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) + if self.cfg.chat_template: + training_arguments_kwargs["chat_template"] = chat_templates( + self.cfg.chat_template + ) if self.cfg.rl == "orpo": training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha @@ -1661,7 +1672,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): else: collator = BatchSamplerDataCollatorForSeq2Seq else: - collator = DataCollatorForSeq2Seq + if self.cfg.processor_type and self.processor: + collator = MultiModalChatDataCollator + kwargs["processor"] = self.processor + kwargs["chat_template"] = training_args.chat_template + else: + collator = DataCollatorForSeq2Seq return collator( self.tokenizer, diff --git a/src/axolotl/monkeypatch/attention/mllama.py b/src/axolotl/monkeypatch/attention/mllama.py new file mode 100644 index 000000000..0b18b716d --- /dev/null +++ b/src/axolotl/monkeypatch/attention/mllama.py @@ -0,0 +1,229 @@ +""" +Monkeypatch for Vision Llama for FA2 support +""" +# pylint: disable=duplicate-code + +from typing import Optional, Tuple + +import torch +from flash_attn.flash_attn_interface import flash_attn_func +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.models.mllama.configuration_mllama import MllamaTextConfig +from transformers.models.mllama.modeling_mllama import ( + MllamaTextCrossAttention, + MllamaTextSelfAttention, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import is_flash_attn_greater_or_equal_2_10 + + +class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention): + """ + Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and + implements the forward pass using Flash Attention for improved performance. 
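+    Keys and values are taken from `cross_attention_states` (or from the KV cache
+    on later steps), and attention is computed via `flash_attn_func` with
+    `causal=False`.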
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Check if flash attention version is greater or equal to 2.1 + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[ # pylint: disable=unused-argument + torch.Tensor + ] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + query_states = self.q_norm(query_states) + + if cross_attention_states is not None: + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = self.k_norm(key_states) + if past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, + value_states, + self.layer_idx, + {"cache_position": cache_position}, + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + + # Transpose to get the expected layout for flash attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # Apply Flash Attention + dropout_rate = self.dropout if self.training else 0.0 + output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=dropout_rate, + softmax_scale=None, + causal=False, + return_attn_probs=output_attentions, + ) + + attn_output = output.contiguous().view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention): + """ + Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and + implements the forward pass using Flash Attention for improved performance. 
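+    Rotary position embeddings are applied before the KV cache update, and the
+    attention itself runs through `_flash_attention_forward` with `is_causal=True`.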
+ """ + + def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs): + super().__init__(config, layer_idx, *args, **kwargs) + + # Check if flash attention version is greater or equal to 2.1 + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + past_key_value=None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, # pylint: disable=unused-argument + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x num_heads x head_dim + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Transpose to get the expected layout for flash attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + # Handle potential silent casting to float32 + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = ( + self.config._pre_quantization_dtype # pylint: disable=protected-access + ) + else: + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=True, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def patch_mllama(): + from transformers.models.mllama.modeling_mllama import ( + MLLAMA_TEXT_ATTENTION_CLASSES, + MLLAMA_TEXT_CROSS_ATTENTION_CLASSES, + MLLAMA_VISION_ATTENTION_CLASSES, + MllamaPreTrainedModel, + ) + + MllamaPreTrainedModel._supports_flash_attn_2 = ( # pylint: disable=protected-access + True + ) + 
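+    # route the "flash_attention_2" implementation for the text self- and
+    # cross-attention layers to the FA2 classes defined above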
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2 + MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[ + "flash_attention_2" + ] = MllamaTextCrossFlashAttention2 + # fallback to SDPA + MLLAMA_VISION_ATTENTION_CLASSES[ + "flash_attention_2" + ] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"] diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 44fc4cb47..85101cd3c 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -10,6 +10,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 from axolotl.monkeypatch.utils import get_unpad_data SUPPORTED_MULTIPACK_MODEL_TYPES = [ + "mllama_text_model", "llama", "mistral", "mixtral", diff --git a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py index 0269f9015..67e9337e3 100644 --- a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py @@ -16,6 +16,7 @@ # This code is based off the following work: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py +# pylint: disable=duplicate-code """ PyTorch StableLM Epoch model. """ import importlib import math diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index f5699a087..66cd5deeb 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -9,7 +9,7 @@ from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig LOG = logging.getLogger("axolotl.prompt_strategies") -def load(strategy, tokenizer, cfg, ds_cfg): +def load(strategy, tokenizer, cfg, ds_cfg, processor=None): try: load_fn = "load" if strategy.split(".")[-1].startswith("load_"): @@ -24,6 +24,8 @@ def load(strategy, tokenizer, cfg, ds_cfg): sig = inspect.signature(func) if "ds_cfg" in sig.parameters: load_kwargs["ds_cfg"] = ds_cfg + if "processor" in sig.parameters: + load_kwargs["processor"] = processor return func(tokenizer, cfg, **load_kwargs) except ModuleNotFoundError: return None diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 88e748895..48d52dae1 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -5,6 +5,8 @@ HF Chat Templates prompt strategy import logging from typing import Any, Dict, List, Optional +from transformers import ProcessorMixin + from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, Prompter from axolotl.utils.chat_templates import chat_templates @@ -20,6 +22,7 @@ class ChatTemplatePrompter(Prompter): def __init__( self, tokenizer, + processor=None, chat_template=None, max_length=2048, message_field_role: str = "from", @@ -44,11 +47,12 @@ class ChatTemplatePrompter(Prompter): self.message_field_training = message_field_training self.message_field_training_detail = message_field_training_detail self.tokenizer = tokenizer + self.processor: ProcessorMixin = processor self.chat_template = chat_template self.max_length = max_length self.drop_system_message = drop_system_message - def build_prompt(self, conversation, add_generation_prompt=False): + def build_prompt(self, conversation, add_generation_prompt=False, images=None): turns = [ { "role": 
self.roles[t[self.message_field_role]], @@ -61,6 +65,28 @@ class ChatTemplatePrompter(Prompter): if self.drop_system_message and turns[0]["role"] == "system": turns = turns[1:] + if self.processor: + text = self.processor.apply_chat_template( + turns, + chat_template=self.chat_template, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + batch = self.processor( + text=text, + images=images, + return_tensors="pt", + truncation=True, + max_length=self.max_length, + ) + # workaround since processor works in batches instead of single examples + for k, val in batch.items(): + if k in ["pixel_values"]: + batch[k] = val.tolist() + else: + batch[k] = val.squeeze().tolist() + return batch + return self.tokenizer.apply_chat_template( turns, truncation=True, @@ -191,6 +217,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) self.roles_to_train = roles_to_train if roles_to_train is not None else [] self.train_on_eos = train_on_eos + self.images = "images" @property def messages(self): @@ -209,10 +236,21 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): and not self.prompter.message_field_training_detail ): turns = self.get_conversation_thread(prompt) + images = self.get_images(prompt) prompt_ids = self.prompter.build_prompt( - turns[:-1], add_generation_prompt=True + turns[:-1], + add_generation_prompt=True, + images=images, ) - input_ids = self.prompter.build_prompt(turns) + tokenized_res = self.prompter.build_prompt(turns, images=images) + tokenized_prompt = {} + if isinstance(tokenized_res, list): + input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] + tokenized_prompt["input_ids"] = input_ids + tokenized_prompt["attention_mask"] = [1] * len(input_ids) + else: + input_ids = tokenized_res["input_ids"] + tokenized_prompt = tokenized_res if not self.train_on_inputs: user_prompt_len = len(prompt_ids) @@ -220,17 +258,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): else: labels = input_ids - tokenized_prompt = { - "input_ids": input_ids, - "labels": labels, - "attention_mask": [1] * len(input_ids), - } + tokenized_prompt["labels"] = labels return tokenized_prompt - LOG.info(self.roles_to_train) - LOG.info(self.train_on_eos) - LOG.info(self.prompter.message_field_training) - LOG.info(self.prompter.message_field_training_detail) turns = prompt[self.messages] input_ids = self.prompter.build_prompt(turns) @@ -368,8 +398,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): def get_conversation_thread(self, prompt): return prompt[self.messages] + def get_images(self, prompt): + return prompt.get(self.images, None) -def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + +def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None): ds_cfg = ds_cfg or {} prompter_params = { @@ -386,6 +419,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): "drop_system_message": ds_cfg.get("drop_system_message", False), # we need to add one for detecting sequences with exceeding the `sequence_len` limit. 
"max_length": cfg.sequence_len + 1, + "processor": processor, } strategy_params = { diff --git a/src/axolotl/train.py b/src/axolotl/train.py index b21b0b269..855dbc2d3 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -24,7 +24,7 @@ from axolotl.core.tokenizer_utils import fix_untrained_tokens from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except -from axolotl.utils.models import load_model, load_tokenizer +from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.trainer import setup_trainer try: @@ -69,6 +69,9 @@ def train( main_process_only=True, ) tokenizer = load_tokenizer(cfg) + processor = None + if cfg.is_multimodal: + processor = load_processor(cfg, tokenizer) train_dataset = dataset_meta.train_dataset eval_dataset = dataset_meta.eval_dataset @@ -96,7 +99,9 @@ def train( LOG.debug(msg) # we wait unitl the last possible moment to setup Accelerator Accelerator() - model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) + model, peft_config = load_model( + cfg, tokenizer, processor=processor, inference=cli_args.inference + ) model.generation_config.do_sample = True model_ref = None @@ -122,6 +127,7 @@ def train( eval_dataset, (model, model_ref, peft_config), tokenizer, + processor, total_num_steps, ) diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 7a96f5c1e..7468ae8b1 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -3,6 +3,20 @@ This module provides functionality for selecting chat templates based on user ch These templates are used for formatting messages in a conversation. """ +CHAT_TEMPLATES = { + "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", + "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. 
+ "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", + "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", + "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "llama3_2_vision": '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. 
#}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- 
"<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n', + "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", + "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", + "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + 
handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." 
-}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- 
handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', +} + def chat_templates(user_choice: str): """ @@ -18,20 +32,22 @@ def chat_templates(user_choice: str): ValueError: If the user_choice is not found in the templates. """ - templates = { - "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", - "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. - "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", - "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", - "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", - "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", - "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", - "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", - "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages 
= messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." 
-}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n 
{{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', - } - - if user_choice in templates: - return templates[user_choice] + if user_choice in CHAT_TEMPLATES: + return CHAT_TEMPLATES[user_choice] raise ValueError(f"Template '{user_choice}' not found.") + + +def register_chat_template(template_name: str, chat_template: str): + """ + Registers chat templates. + + Args: + template_name (str): The name of the template. + chat_template (str): The template string. + """ + + if template_name in CHAT_TEMPLATES: + raise ValueError(f"Template '{template_name}' already exists.") + + CHAT_TEMPLATES[template_name] = chat_template diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py new file mode 100644 index 000000000..93502b67d --- /dev/null +++ b/src/axolotl/utils/collators/__init__.py @@ -0,0 +1,10 @@ +""" +shared axolotl collators for multipack, mamba, multimodal +""" +from .batching import ( # noqa: F401 + BatchSamplerDataCollatorForSeq2Seq, + DataCollatorForSeq2Seq, + PretrainingBatchSamplerDataCollatorForSeq2Seq, + V2BatchSamplerDataCollatorForSeq2Seq, +) +from .mamba import MambaDataCollator # noqa: F401 diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators/batching.py similarity index 90% rename from src/axolotl/utils/collators.py rename to src/axolotl/utils/collators/batching.py index 26c7fa9f3..7cf771421 100644 --- a/src/axolotl/utils/collators.py +++ b/src/axolotl/utils/collators/batching.py @@ -1,17 +1,14 @@ """ DataCollator for axolotl to pad labels and position_ids for packed sequences """ + from dataclasses import dataclass -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Optional, Union import numpy as np -import torch -import transformers from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy -IGNORE_INDEX = -100 - @dataclass class DataCollatorForSeq2Seq: @@ -183,34 +180,6 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): return super().__call__(out_features, return_tensors=return_tensors) -@dataclass -class MambaDataCollator: - """ - Collator for State Space Models (Mamba) - """ - - tokenizer: transformers.PreTrainedTokenizer - - def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: - input_ids, labels = tuple( - [torch.LongTensor(instance[key]) 
for instance in instances] - for key in ("input_ids", "labels") - ) - input_ids = torch.nn.utils.rnn.pad_sequence( - input_ids, - batch_first=True, - padding_value=self.tokenizer.pad_token_id, - ) - labels = torch.nn.utils.rnn.pad_sequence( - labels, batch_first=True, padding_value=IGNORE_INDEX - ) - - return { - "input_ids": input_ids, - "labels": labels, - } - - @dataclass class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): """ diff --git a/src/axolotl/utils/collators/core.py b/src/axolotl/utils/collators/core.py new file mode 100644 index 000000000..0eae0c3bd --- /dev/null +++ b/src/axolotl/utils/collators/core.py @@ -0,0 +1,4 @@ +""" +basic shared collator constants +""" +IGNORE_INDEX = -100 diff --git a/src/axolotl/utils/collators/mamba.py b/src/axolotl/utils/collators/mamba.py new file mode 100644 index 000000000..0c4a22fcc --- /dev/null +++ b/src/axolotl/utils/collators/mamba.py @@ -0,0 +1,38 @@ +""" +collators for Mamba +""" +from dataclasses import dataclass +from typing import Dict, Sequence + +import torch +import transformers + +from axolotl.utils.collators.core import IGNORE_INDEX + + +@dataclass +class MambaDataCollator: + """ + Collator for State Space Models (Mamba) + """ + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple( + [torch.LongTensor(instance[key]) for instance in instances] + for key in ("input_ids", "labels") + ) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) + labels = torch.nn.utils.rnn.pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX + ) + + return { + "input_ids": input_ids, + "labels": labels, + } diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py new file mode 100644 index 000000000..f49e97f37 --- /dev/null +++ b/src/axolotl/utils/collators/mm_chat.py @@ -0,0 +1,77 @@ +""" +Collators for multi-modal chat messages and packing +""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +from transformers import PreTrainedTokenizerBase, ProcessorMixin +from transformers.data.data_collator import DataCollatorMixin +from transformers.utils import PaddingStrategy + + +@dataclass +class MultiModalChatDataCollator(DataCollatorMixin): + """ + Collator for multi-modal chat messages + """ + + tokenizer: PreTrainedTokenizerBase + processor: ProcessorMixin + return_tensors: str = "pt" + chat_template: Optional[str] = None + packing: bool = False + max_images: int = -1 + padding: Union[bool, str, PaddingStrategy] = True + pad_to_multiple_of: Optional[int] = None + + def __post_init__(self): + if self.packing: + raise ValueError("Packing is currently not supported.") + + def torch_call( + self, examples: List[Union[List[int], Any, Dict[str, Any]]] + ) -> Dict[str, Any]: + # Handle dict or lists with proper padding and conversion to tensor. 
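+        # torch_call simply delegates to the staticmethod below; keeping the
+        # logic in process_rows lets callers reuse it without a collator
+        # instance, e.g. for length estimation via length_only=True.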
+
+        return self.__class__.process_rows(
+            examples, self.processor, self.chat_template, self.max_images
+        )
+
+    @staticmethod
+    def process_rows(examples, processor, chat_template, max_images, length_only=False):
+        # HINT: use `_torch_collate_batch` to stack and pad tensors
+        # see also DataCollatorWithFlattening and DefaultDataCollator
+
+        # *** This is COPIED from the trl example sft_vlm.py code ***
+        # use this as a starting point
+
+        # Get the texts and images, and apply the chat template
+        texts = [
+            processor.apply_chat_template(
+                example["messages"], chat_template=chat_template, tokenize=False
+            )
+            for example in examples
+        ]
+        images = [example["images"] for example in examples]
+
+        if max_images > 0:
+            images = [img_batch[:max_images] for img_batch in images]
+
+        # Tokenize the texts and process the images
+        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+        # The labels are the input_ids, and we mask the padding tokens in the loss computation
+        labels = batch["input_ids"].clone()
+        labels[labels == processor.tokenizer.pad_token_id] = -100
+        # Ignore the image token index in the loss computation (model specific)
+        image_token_id = processor.tokenizer.convert_tokens_to_ids(
+            processor.image_token
+        )
+        labels[labels == image_token_id] = -100
+        batch["labels"] = labels
+
+        if length_only:
+            return {
+                "length": [len(sample) for sample in batch["input_ids"]]
+            }
+        return batch
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 82436e8d7..f732db06f 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -121,15 +121,36 @@ def normalize_config(cfg):
     cfg.base_model_config = cfg.base_model

     model_config = load_model_config(cfg)
-    cfg.model_config_type = model_config.model_type

     cfg.tokenizer_config = (
         cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
     )

+    cfg.is_multimodal = (
+        hasattr(model_config, "model_type")
+        and model_config.model_type in ["llava", "mllama"]
+        or any(
+            multimodal_name in cfg.base_model.lower()
+            for multimodal_name in [
+                "pixtral",
+            ]
+        )
+        or cfg.is_multimodal
+    )
+    if cfg.is_multimodal:
+        cfg.processor_config = (
+            cfg.processor_config or cfg.base_model_config or cfg.base_model
+        )
+        model_config = model_config.text_config
+
+    cfg.model_config_type = model_config.model_type
+
     # figure out if the model is llama
     cfg.is_llama_derived_model = (
-        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
+        (
+            hasattr(model_config, "model_type")
+            and model_config.model_type in ["llama", "mllama_text_model"]
+        )
         or cfg.is_llama_derived_model
         or "llama" in cfg.base_model.lower()
         or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 4e07c9260..fced5e639 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -188,6 +188,7 @@ class ChatTemplate(str, Enum):
     gemma = "gemma"  # pylint: disable=invalid-name
     cohere = "cohere"  # pylint: disable=invalid-name
     llama3 = "llama3"  # pylint: disable=invalid-name
+    llama3_2_vision = "llama3_2_vision"  # pylint: disable=invalid-name
     phi_3 = "phi_3"  # pylint: disable=invalid-name
     phi_35 = "phi_35"  # pylint: disable=invalid-name
     deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name
@@ -228,11 +229,12 @@ class LoraConfig(BaseModel):
     lora_r:
Optional[int] = None lora_alpha: Optional[int] = None lora_fan_in_fan_out: Optional[bool] = None - lora_target_modules: Optional[List[str]] = None + lora_target_modules: Optional[Union[str, List[str]]] = None lora_target_linear: Optional[bool] = None lora_modules_to_save: Optional[List[str]] = None lora_dropout: Optional[float] = 0.0 peft_layers_to_transform: Optional[List[int]] = None + peft_layers_pattern: Optional[List[str]] = None peft: Optional[PeftConfig] = None peft_use_dora: Optional[bool] = None peft_use_rslora: Optional[bool] = None @@ -328,6 +330,9 @@ class ModelInputConfig(BaseModel): tokenizer_type: Optional[str] = Field( default=None, metadata={"help": "transformers tokenizer class"} ) + processor_type: Optional[str] = Field( + default=None, metadata={"help": "transformers processor class"} + ) trust_remote_code: Optional[bool] = None model_kwargs: Optional[Dict[str, Any]] = None @@ -530,6 +535,7 @@ class AxolotlInputConfig( dataset_prepared_path: Optional[str] = None dataset_shard_num: Optional[int] = None dataset_shard_idx: Optional[int] = None + skip_prepare_dataset: Optional[bool] = False pretraining_dataset: Optional[ # type: ignore conlist(Union[PretrainingDataset, SFTDataset], min_length=1) @@ -997,6 +1003,18 @@ class AxolotlInputConfig( return data + @model_validator(mode="before") + @classmethod + def check_mm_prepare(cls, data): + if data.get("skip_prepare_dataset"): + if data.get("remove_unused_columns") is None: + LOG.info( + "setting `remove_unused_columns: false` for skip_prepare_dataset" + ) + data["remove_unused_columns"] = False + + return data + @model_validator(mode="before") @classmethod def check_warmup(cls, data): @@ -1052,6 +1070,15 @@ class AxolotlInputConfig( return data + @model_validator(mode="before") + @classmethod + def check_peft_layers_pattern(cls, data): + if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"): + raise ValueError( + "peft_layers_pattern requires peft_layers_to_transform to be set" + ) + return data + @model_validator(mode="after") def check_fft_possible_bad_config(self): if ( diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 1b6df1cde..7d6922cbf 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -51,20 +51,31 @@ from axolotl.utils.trainer import ( LOG = logging.getLogger("axolotl") -def prepare_dataset(cfg, tokenizer): +def prepare_dataset(cfg, tokenizer, processor=None): prompters = [] if not cfg.pretraining_dataset: with zero_first(is_local_main_process()): if cfg.test_datasets: train_dataset, _, prompters = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train" + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + split="train", + processor=processor, ) _, eval_dataset, _ = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test" + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + split="test", + processor=processor, ) else: train_dataset, eval_dataset, prompters = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + processor=processor, ) else: path = cfg.pretraining_dataset @@ -123,6 +134,7 @@ def load_tokenized_prepared_datasets( cfg, default_dataset_prepared_path, split="train", + processor=None, ) -> Tuple[DatasetDict, List[Prompter]]: cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets tokenizer_name = cfg.tokenizer_config @@ -180,6 +192,7 @@ def load_tokenized_prepared_datasets( 
cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")) and not cfg.is_preprocess + and not cfg.skip_prepare_dataset ): LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") dataset = load_from_disk(str(prepared_ds_path)) @@ -423,12 +436,16 @@ def load_tokenized_prepared_datasets( dataset=ds, d_base_type=d_base_type, d_prompt_style=d_prompt_style, + processor=processor, ) datasets.append(dataset_wrapper) prompters.append(dataset_prompter) - LOG.info("merging datasets") - dataset = concatenate_datasets(datasets) + if len(datasets) == 1: + dataset = datasets[0] + else: + LOG.info("merging datasets") + dataset = concatenate_datasets(datasets) if len(datasets) > 1: if cfg.shuffle_merged_datasets: @@ -437,9 +454,10 @@ def load_tokenized_prepared_datasets( else: LOG.debug("NOT shuffling merged datasets") - dataset, _ = process_datasets_for_packing(cfg, dataset, None) + if not cfg.skip_prepare_dataset: + dataset, _ = process_datasets_for_packing(cfg, dataset, None) - if cfg.local_rank == 0: + if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") dataset.save_to_disk(str(prepared_ds_path)) if cfg.push_dataset_to_hub: @@ -478,9 +496,14 @@ def load_prepare_datasets( cfg, default_dataset_prepared_path, split="train", + processor=None, ) -> Tuple[Dataset, Dataset, List[Prompter]]: dataset, prompters = load_tokenized_prepared_datasets( - tokenizer, cfg, default_dataset_prepared_path, split=split + tokenizer, + cfg, + default_dataset_prepared_path, + split=split, + processor=processor, ) if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: @@ -546,6 +569,7 @@ def get_dataset_wrapper( d_base_type, dataset, d_prompt_style=None, + processor=None, ): dataset_wrapper = None dataset_prompter = None @@ -578,7 +602,11 @@ def get_dataset_wrapper( dataset, **ds_kwargs, ) - elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): + elif cfg.skip_prepare_dataset: + dataset_wrapper = dataset + elif ds_strategy := load( + config_dataset.type, tokenizer, cfg, config_dataset, processor=processor + ): dataset_prompter = UnsupportedPrompter() dataset_wrapper = TokenizedPromptDataset( ds_strategy, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index e18330199..c18af9760 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -28,12 +28,17 @@ from transformers import ( # noqa: F401 AddedToken, AutoConfig, AutoModelForCausalLM, + AutoModelForVision2Seq, + AutoProcessor, AutoTokenizer, AwqConfig, BitsAndBytesConfig, GPTQConfig, + LlavaForConditionalGeneration, + MllamaForConditionalGeneration, PreTrainedModel, PreTrainedTokenizerBase, + ProcessorMixin, ) from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled @@ -80,6 +85,9 @@ def get_module_class_from_name(module, name): def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]): + if cfg.is_multimodal: + model_config = model_config.text_config + quant_config_exists = ( hasattr(model_config, "quantization_config") and model_config.quantization_config @@ -299,11 +307,31 @@ def load_tokenizer(cfg): return tokenizer +def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): + processor_kwargs: Dict[str, Any] = {} # do we actually need this? 
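+    # currently unused; kept as an extension point in case a processor class
+    # ever needs extra keyword arguments passed through from the config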
+ + processor_cls = AutoProcessor + if cfg.processor_type: + processor_cls = getattr(transformers, cfg.processor_type) + + processor = processor_cls.from_pretrained( + cfg.processor_config, + trust_remote_code=cfg.trust_remote_code or False, + tokenizer=tokenizer, + **processor_kwargs, + ) + + return processor + + def load_model( cfg: DictDefault, tokenizer: PreTrainedTokenizerBase, + *, + processor: ProcessorMixin = None, # pylint: disable=unused-argument inference: bool = False, reference_model: bool = False, + **kwargs, # pylint: disable=unused-argument ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: """ Load a model for a given configuration and tokenizer. @@ -319,12 +347,23 @@ def load_model( plugin_manager = PluginManager.get_instance() plugin_manager.pre_model_load(cfg) + if cfg.is_multimodal: + text_model_config = model_config.text_config + else: + text_model_config = model_config + # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit if cfg.gradient_checkpointing == "unsloth": transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper + if hasattr(model_config, "model_type") and model_config.model_type == "mllama": + if cfg.flash_attention: + from axolotl.monkeypatch.attention.mllama import patch_mllama + + patch_mllama() + if hasattr(model_config, "model_type") and model_config.model_type == "btlm": if cfg.flash_attention: from axolotl.monkeypatch.btlm_attn_hijack_flash import ( @@ -461,6 +500,19 @@ def load_model( max_memory = cfg.max_memory device_map = cfg.device_map + AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name + if cfg.is_multimodal: + if model_config.model_type == "llava": + AutoModelLoader = ( # pylint: disable=invalid-name + LlavaForConditionalGeneration + ) + elif model_config.model_type == "mllama": + AutoModelLoader = ( # pylint: disable=invalid-name + MllamaForConditionalGeneration + ) + else: + AutoModelLoader = AutoModelForVision2Seq # pylint: disable=invalid-name + if cfg.gpu_memory_limit: gpu_memory_limit = ( str(cfg.gpu_memory_limit) + "GiB" @@ -478,7 +530,7 @@ def load_model( from accelerate import infer_auto_device_map with init_empty_weights(): - model_canvas = AutoModelForCausalLM.from_config( + model_canvas = AutoModelLoader.from_config( model_config, trust_remote_code=cfg.trust_remote_code or False ) model_canvas.tie_weights() @@ -633,6 +685,8 @@ def load_model( quantization_config = ( quantization_config or model_kwargs["quantization_config"] ) + if cfg.is_multimodal: + model_config.text_config = text_model_config model = load_sharded_model_quant( base_model, model_config, @@ -651,7 +705,9 @@ def load_model( if "device_map" in model_kwargs: del model_kwargs["device_map"] - model = AutoModelForCausalLM.from_pretrained( + if cfg.is_multimodal: + model_config.text_config = text_model_config + model = AutoModelLoader.from_pretrained( base_model, config=model_config, **model_kwargs, @@ -690,13 +746,17 @@ def load_model( and not cfg.trust_remote_code ): if cfg.gptq: - model = AutoModelForCausalLM.from_pretrained( + if cfg.is_multimodal: + model_config.text_config = text_model_config + model = AutoModelLoader.from_pretrained( base_model, config=model_config, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) else: + if cfg.is_multimodal: + model_config.text_config = text_model_config model = getattr(transformers, model_type).from_pretrained( base_model, config=model_config, @@ -707,21 +767,23 @@ def load_model( # Shouldn't be a problem most of the time. 
It will obviously error if the model doesn't support this
 # when training starts
     if (
-        hasattr(model_config, "max_seq_len")
-        and model_config.max_seq_len
+        hasattr(text_model_config, "max_seq_len")
+        and text_model_config.max_seq_len
         and cfg.sequence_len > text_model_config.max_seq_len
     ):
-        model_config.max_seq_len = cfg.sequence_len
+        text_model_config.max_seq_len = cfg.sequence_len
         LOG.warning(f"increasing context length to {cfg.sequence_len}")
     elif (
-        hasattr(model_config, "max_sequence_length")
-        and model_config.max_sequence_length
-        and cfg.sequence_len > model_config.max_sequence_length
+        hasattr(text_model_config, "max_sequence_length")
+        and text_model_config.max_sequence_length
+        and cfg.sequence_len > text_model_config.max_sequence_length
     ):
-        model_config.max_sequence_length = cfg.sequence_len
+        text_model_config.max_sequence_length = cfg.sequence_len
         LOG.warning(f"increasing context length to {cfg.sequence_len}")
     if cfg.gptq:
-        model = AutoModelForCausalLM.from_pretrained(
+        if cfg.is_multimodal:
+            model_config.text_config = text_model_config
+        model = AutoModelLoader.from_pretrained(
             base_model,
             config=model_config,
             trust_remote_code=cfg.trust_remote_code or False,
@@ -734,7 +796,9 @@
             if "device_map" in model_kwargs:
                 del model_kwargs["device_map"]

-            model = AutoModelForCausalLM.from_pretrained(
+            if cfg.is_multimodal:
+                model_config.text_config = text_model_config
+            model = AutoModelLoader.from_pretrained(
                 base_model,
                 config=model_config,
                 trust_remote_code=cfg.trust_remote_code or False,
@@ -1016,12 +1080,17 @@ def load_lora(model, cfg, inference=False, config_only=False):

     from peft import LoraConfig, get_peft_model

-    lora_target_modules = list(cfg.lora_target_modules or [])
+    lora_target_modules = cfg.lora_target_modules or []

     if cfg.lora_target_linear:
         linear_names = find_all_linear_names(model)
         LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
-        lora_target_modules = list(set(lora_target_modules + linear_names))
+        lora_target_modules_as_list = (
+            lora_target_modules
+            if isinstance(lora_target_modules, list)
+            else [lora_target_modules]
+        )
+        lora_target_modules = list(set(lora_target_modules_as_list + linear_names))

     lora_config_kwargs = {}
     loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
@@ -1040,6 +1109,7 @@
         lora_alpha=cfg.lora_alpha,
         target_modules=lora_target_modules,
         layers_to_transform=cfg.peft_layers_to_transform,
+        layers_pattern=cfg.peft_layers_pattern,
         lora_dropout=cfg.lora_dropout,
         fan_in_fan_out=cfg.lora_fan_in_fan_out,
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 89ae4e697..17276dd8e 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -306,7 +306,7 @@ def process_pretraining_datasets_for_packing(


 def calculate_total_num_steps(cfg, train_dataset, update=True):
-    if not cfg.total_num_tokens:
+    if not cfg.total_num_tokens and not cfg.skip_prepare_dataset:
         total_num_tokens = np.sum(
             train_dataset.data.column("input_ids")
             .to_pandas()
@@ -319,7 +319,11 @@

     skip_estimates = cfg.model_config_type == "mamba"

-    if not skip_estimates and not cfg.total_supervised_tokens:
+    if (
+        not skip_estimates
+        and not cfg.total_supervised_tokens
+        and not cfg.skip_prepare_dataset
+    ):
         total_supervised_tokens = (
             train_dataset.data.column("labels")
.to_pandas() @@ -478,13 +482,15 @@ def prepare_opinionated_env(cfg): os.environ["TOKENIZERS_PARALLELISM"] = "false" -def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): +def setup_trainer( + cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps +): if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]: - trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2] else: - trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder.train_dataset = train_dataset trainer_builder.eval_dataset = eval_dataset diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 28210b7ae..20533504c 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -73,7 +73,7 @@ class TestAssistantChatTemplateLlama3: strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + chat_template=chat_templates("llama3"), message_field_role="role", message_field_content="content", roles={ @@ -113,7 +113,7 @@ class TestAssistantChatTemplateLlama3: strategy = ChatTemplateStrategy( ChatTemplatePrompter( phi35_tokenizer, - chat_templates("phi_35"), + chat_template=chat_templates("phi_35"), message_field_role="role", message_field_content="content", roles={ @@ -171,7 +171,7 @@ class TestAssistantChatTemplateLlama3: strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + chat_template=chat_templates("llama3"), message_field_role="role", message_field_content="content", message_field_training="training", @@ -227,8 +227,11 @@ class TestSharegptChatTemplateLlama3: def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts") + # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none", @@ -277,8 +280,11 @@ class TestSharegptChatTemplateLlama3: def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 human prompts") + # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none", @@ -327,8 +333,11 @@ class TestSharegptChatTemplateLlama3: def test_llama3_system_human(self, llama3_tokenizer, basic_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts") + # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none", diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index f18fb3942..50429e3a2 
100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -34,7 +34,9 @@ class TestChatTemplateConfigurations: def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=True") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=True, sequence_len=512, @@ -77,7 +79,9 @@ class TestChatTemplateConfigurations: def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=False") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -118,7 +122,9 @@ class TestChatTemplateConfigurations: def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with assistant only") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -144,7 +150,9 @@ class TestChatTemplateConfigurations: def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with all roles") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=True, sequence_len=512, @@ -175,7 +183,9 @@ class TestChatTemplateConfigurations: def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with empty roles_to_train") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -194,7 +204,9 @@ class TestChatTemplateConfigurations: def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='all'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -219,7 +231,9 @@ class TestChatTemplateConfigurations: def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='turn'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -267,7 +281,9 @@ class TestChatTemplateConfigurations: def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='last'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, 
train_on_inputs=False, sequence_len=512, @@ -298,7 +314,9 @@ class TestChatTemplateConfigurations: def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='none'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -324,7 +342,9 @@ class TestChatTemplateConfigurations: LOG.info("Testing with drop_system_message=True") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), drop_system_message=True + llama3_tokenizer, + chat_template=chat_templates("llama3"), + drop_system_message=True, ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -350,7 +370,9 @@ class TestChatTemplateConfigurations: } strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), roles=custom_roles + llama3_tokenizer, + chat_template=chat_templates("llama3"), + roles=custom_roles, ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -402,7 +424,7 @@ class TestChatTemplateConfigurations: strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + chat_template=chat_templates("llama3"), message_field_training="train", message_field_training_detail="train_detail", ), From 4ca0a47cfb884f2d4421785982e64220b26f48df Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 9 Oct 2024 08:43:11 -0400 Subject: [PATCH 08/12] add 2.4.1 to base models (#1953) --- .github/workflows/base.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 9101fc2be..5e8c8fc33 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -30,6 +30,12 @@ jobs: python_version: "3.11" pytorch: 2.4.0 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + - cuda: "124" + cuda_version: 12.4.1 + cudnn_version: "" + python_version: "3.11" + pytorch: 2.4.1 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" steps: - name: Checkout uses: actions/checkout@v3 From e8d3da00814ec7773d33edd5643bb885d85686cb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 9 Oct 2024 11:53:56 -0400 Subject: [PATCH 09/12] upgrade pytorch from 2.4.0 => 2.4.1 (#1950) * upgrade pytorch from 2.4.0 => 2.4.1 * update xformers for updated pytorch version * handle xformers version case for torch==2.3.1 --- .github/workflows/base.yml | 2 +- .github/workflows/main.yml | 4 ++-- .github/workflows/nightlies.yml | 4 ++-- .github/workflows/tests-nightly.yml | 4 ++-- .github/workflows/tests.yml | 4 ++-- requirements.txt | 2 +- setup.py | 7 +++++++ 7 files changed, 17 insertions(+), 10 deletions(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 5e8c8fc33..1b24f2c97 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -28,7 +28,7 @@ jobs: cuda_version: 12.4.1 cudnn_version: "" python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - cuda: "124" cuda_version: 12.4.1 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 5a972f5f0..c27dbedef 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,7 +27,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -84,7 +84,7 @@ jobs: - 
cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 1d95a0983..17c76c24e 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -26,7 +26,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -83,7 +83,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 30ed397ce..8c9e1f49e 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -25,7 +25,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] - pytorch_version: ["2.3.1", "2.4.0"] + pytorch_version: ["2.3.1", "2.4.1"] timeout-minutes: 20 steps: @@ -91,7 +91,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 num_gpus: 1 axolotl_extras: nightly_build: "true" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c104e92c2..a798bdd5c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,7 +36,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] - pytorch_version: ["2.3.1", "2.4.0"] + pytorch_version: ["2.3.1", "2.4.1"] timeout-minutes: 20 steps: @@ -94,7 +94,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 num_gpus: 1 axolotl_extras: steps: diff --git a/requirements.txt b/requirements.txt index 123a4ee54..41bfdfbeb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ flash-attn==2.6.3 sentencepiece wandb einops -xformers==0.0.27 +xformers==0.0.28.post1 optimum==1.16.2 hf_transfer colorama diff --git a/setup.py b/setup.py index 1b64fadae..e939bc37e 100644 --- a/setup.py +++ b/setup.py @@ -49,10 +49,17 @@ def parse_requirements(): else: raise ValueError("Invalid version format") + if (major, minor) >= (2, 4): + if patch == 0: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.27") if (major, minor) >= (2, 3): if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.26.post1") + else: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.27") elif (major, minor) >= (2, 2): _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.25.post1") From a560593b1dbac3f3afcbe6bdf975c9c9e5a5afcc Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 10 Oct 2024 03:02:32 +0700 Subject: [PATCH 10/12] fix(log): update perplexity log to clarify from eval split (#1952) [skip ci] --- src/axolotl/utils/callbacks/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 73715b06a..acc2238a4 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer): references=[[r] for r in references], predictions=predictions, ) - scores[metric_name] = score + scores["eval_" + metric_name] = score return 
scores

    def predict_with_generate():


From dee77232feb5c7e41216e5586da3ec4407638846 Mon Sep 17 00:00:00 2001
From: aarush gupta
Date: Wed, 9 Oct 2024 13:03:16 -0700
Subject: [PATCH 11/12] fix type annotations (#1941) [skip ci]

---
 src/axolotl/monkeypatch/relora.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py
index e4352cbe3..9d246cb17 100644
--- a/src/axolotl/monkeypatch/relora.py
+++ b/src/axolotl/monkeypatch/relora.py
@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
 def reset_optimizer(
     optimizer: torch.optim.Optimizer,
     *,
-    reset_params: list[str],  # where str is the key to a torch.nn.Parameter
-    optimizer_state_keys: list[str],
+    reset_params: List[str],  # where str is the key to a torch.nn.Parameter
+    optimizer_state_keys: List[str],
     prune_ratio: float = 0.9,
 ):
     pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)

From 6d3caadf90a9d4faafe8e167441355d128c66537 Mon Sep 17 00:00:00 2001
From: Boris Feld
Date: Wed, 9 Oct 2024 22:03:37 +0200
Subject: [PATCH 12/12] Comet integration (#1939)

* Add first version of a Comet integration
* Remove debug prints
* Add test for Comet Configuration transformation to env variables
* Fix last lint warning
* Update Readme for Comet logging documentation
* Update Comet integration to be optional, update code and tests
* Add documentation for Comet configuration
* Add missing check

---
 .isort.cfg                                    |   2 +-
 README.md                                     |  18 ++-
 docs/config.qmd                               |  12 ++
 src/axolotl/cli/__init__.py                   |   3 +
 src/axolotl/core/trainer_builder.py           |  15 ++-
 src/axolotl/utils/__init__.py                 |   6 +-
 src/axolotl/utils/callbacks/__init__.py       |  11 +-
 src/axolotl/utils/callbacks/comet_.py         |  43 ++++++++
 src/axolotl/utils/comet_.py                   |  93 ++++++++++++++++
 .../config/models/input/v0_4_1/__init__.py    |  14 +++
 tests/test_validation.py                      | 103 ++++++++++++++++++
 11 files changed, 315 insertions(+), 5 deletions(-)
 create mode 100644 src/axolotl/utils/callbacks/comet_.py
 create mode 100644 src/axolotl/utils/comet_.py

diff --git a/.isort.cfg b/.isort.cfg
index 79067a7c9..e48779732 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -1,3 +1,3 @@
 [settings]
 profile=black
-known_third_party=wandb
+known_third_party=wandb,comet_ml

diff --git a/README.md b/README.md
index c84f1cb8c..f6f4e4e80 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ Features:
 - Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
 - Works with single GPU or multiple GPUs via FSDP or Deepspeed
 - Easily run with Docker locally or on the cloud
-- Log results and optionally checkpoints to wandb or mlflow
+- Log results and optionally checkpoints to wandb, mlflow or Comet
 - And more!

@@ -515,6 +515,22 @@ wandb_name:
 wandb_log_model:
 ```

+##### Comet Logging
+
+Make sure your `COMET_API_KEY` environment variable is set (recommended) or you log in to Comet with `comet login`.
+
+- comet options
+```yaml
+use_comet:
+comet_api_key:
+comet_workspace:
+comet_project_name:
+comet_experiment_key:
+comet_mode:
+comet_online:
+comet_experiment_config:
+```
+
 ##### Special Tokens

 It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better.
You can do this in axolotl like this:

diff --git a/docs/config.qmd b/docs/config.qmd
index e85999978..99a69a097 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -267,6 +267,18 @@ mlflow_tracking_uri: # URI to mlflow
 mlflow_experiment_name: # Your experiment name
 hf_mlflow_log_artifacts:  # set to true to copy each saved checkpoint on each save to mlflow artifact registry

+# Comet configuration if you're using it
+# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you log in to Comet with `comet login`.
+# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
+use_comet: # Enable or disable Comet integration.
+comet_api_key: # API key for Comet. Recommended to set via `comet login`.
+comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
+comet_project_name: # Project name in Comet. Defaults to Uncategorized.
+comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Defaults to a random key.
+comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). The default ("get_or_create") auto-selects based on configuration.
+comet_online: # Set to True to log data to the Comet server, or False for offline storage. Default is True.
+comet_experiment_config: # Dictionary of additional configuration settings; see the doc for more details.
+
 # Where to save the full-finetuned model to
 output_dir: ./completed-model

diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index a1d84b6a1..db975501a 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -31,6 +31,7 @@ from axolotl.integrations.base import PluginManager
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
 from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.comet_ import setup_comet_env_vars
 from axolotl.utils.config import (
     normalize_cfg_datasets,
     normalize_config,
@@ -421,6 +422,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):

     setup_mlflow_env_vars(cfg)

+    setup_comet_env_vars(cfg)
+
     return cfg

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 4893e63dc..b1ee519dc 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -48,7 +48,7 @@ from trl.trainer.utils import pad_to_length

 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
-from axolotl.utils import is_mlflow_available
+from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
@@ -1111,6 +1111,12 @@ class TrainerBuilderBase(abc.ABC):
             callbacks.append(
                 SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
             )
+        if self.cfg.use_comet and is_comet_available():
+            from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
+
+            callbacks.append(
+                SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
+            )

         return callbacks

@@ -1179,6 +1185,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 trainer, self.tokenizer, "mlflow"
             )
             callbacks.append(LogPredictionCallback(self.cfg))
+        if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
+            LogPredictionCallback = log_prediction_callback_factory(
+                trainer,
self.tokenizer, "comet_ml" + ) + callbacks.append(LogPredictionCallback(self.cfg)) if self.cfg.do_bench_eval: callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer)) @@ -1430,6 +1441,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): report_to.append("mlflow") if self.cfg.use_tensorboard: report_to.append("tensorboard") + if self.cfg.use_comet: + report_to.append("comet_ml") training_arguments_kwargs["report_to"] = report_to training_arguments_kwargs["run_name"] = ( diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 99dec79f1..91545009a 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -1,8 +1,12 @@ """ Basic utils for Axolotl """ -import importlib +import importlib.util def is_mlflow_available(): return importlib.util.find_spec("mlflow") is not None + + +def is_comet_available(): + return importlib.util.find_spec("comet_ml") is not None diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index acc2238a4..0bc781fcb 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -29,7 +29,7 @@ from transformers import ( ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy -from axolotl.utils import is_mlflow_available +from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig @@ -747,6 +747,15 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str): artifact_file="PredictionsVsGroundTruth.json", tracking_uri=tracking_uri, ) + elif logger == "comet_ml" and is_comet_available(): + import comet_ml + + experiment = comet_ml.get_running_experiment() + if experiment: + experiment.log_table( + f"{name} - Predictions vs Ground Truth.csv", + pd.DataFrame(table_data), + ) if is_main_process(): log_table_from_dataloader("Eval", eval_dataloader) diff --git a/src/axolotl/utils/callbacks/comet_.py b/src/axolotl/utils/callbacks/comet_.py new file mode 100644 index 000000000..b29f997a8 --- /dev/null +++ b/src/axolotl/utils/callbacks/comet_.py @@ -0,0 +1,43 @@ +"""Comet module for trainer callbacks""" + +import logging +from typing import TYPE_CHECKING + +import comet_ml +from transformers import TrainerCallback, TrainerControl, TrainerState + +from axolotl.utils.distributed import is_main_process + +if TYPE_CHECKING: + from axolotl.core.trainer_builder import AxolotlTrainingArguments + +LOG = logging.getLogger("axolotl.callbacks") + + +class SaveAxolotlConfigtoCometCallback(TrainerCallback): + """Callback to save axolotl config to comet""" + + def __init__(self, axolotl_config_path): + self.axolotl_config_path = axolotl_config_path + + def on_train_begin( + self, + args: "AxolotlTrainingArguments", # pylint: disable=unused-argument + state: TrainerState, # pylint: disable=unused-argument + control: TrainerControl, + **kwargs, # pylint: disable=unused-argument + ): + if is_main_process(): + try: + comet_experiment = comet_ml.start(source="axolotl") + comet_experiment.log_other("Created from", "axolotl") + comet_experiment.log_asset( + self.axolotl_config_path, + file_name="axolotl-config", + ) + LOG.info( + "The Axolotl config has been saved to the Comet Experiment under assets." 
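# NOTE (editor's sketch, not part of this patch): comet_ml.start() above picks
# up the COMET_START_* variables exported by setup_comet_env_vars (next file),
# which is why no mode/online arguments are passed here; log_asset then uploads
# the axolotl config under the experiment's assets. Elsewhere in the patch the
# running experiment can be fetched again, e.g. (illustrative metadata key):
#
#     import comet_ml
#
#     experiment = comet_ml.get_running_experiment()
#     if experiment:
#         experiment.log_other("stage", "finetune")  # hypothetical extra field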
+                )
+            except (FileNotFoundError, ConnectionError) as err:
+                LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
+        return control

diff --git a/src/axolotl/utils/comet_.py b/src/axolotl/utils/comet_.py
new file mode 100644
index 000000000..b4ecc80ad
--- /dev/null
+++ b/src/axolotl/utils/comet_.py
@@ -0,0 +1,93 @@
+"""Module for Comet utilities"""
+
+import logging
+import os
+
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.utils.comet_")
+
+COMET_ENV_MAPPING_OVERRIDE = {
+    "comet_mode": "COMET_START_MODE",
+    "comet_online": "COMET_START_ONLINE",
+}
+COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
+    "auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
+    "auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
+    "auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
+    "auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
+    "auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
+    "auto_log_co2": "COMET_AUTO_LOG_CO2",
+    "auto_metric_logging": "COMET_AUTO_LOG_METRICS",
+    "auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
+    "auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
+    "auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
+    "comet_disabled": "COMET_AUTO_LOG_DISABLE",
+    "display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
+    "distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
+    "log_code": "COMET_AUTO_LOG_CODE",
+    "log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
+    "log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
+    "log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
+    "log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
+    "log_env_host": "COMET_AUTO_LOG_ENV_HOST",
+    "log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
+    "log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
+    "log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
+    "log_graph": "COMET_AUTO_LOG_GRAPH",
+    "name": "COMET_START_EXPERIMENT_NAME",
+    "offline_directory": "COMET_OFFLINE_DIRECTORY",
+    "parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
+    "tags": "COMET_START_EXPERIMENT_TAGS",
+}
+
+
+def python_value_to_environ_value(python_value):
+    if isinstance(python_value, bool):
+        if python_value is True:
+            return "true"
+
+        return "false"
+
+    if isinstance(python_value, int):
+        return str(python_value)
+
+    if isinstance(python_value, list):  # Comet has only one list-of-strings parameter
+        return ",".join(map(str, python_value))
+
+    return python_value
+
+
+def setup_comet_env_vars(cfg: DictDefault):
+    # TODO: we need to convert the Axolotl configuration to environment variables,
+    # as the Transformers integration is called first and would otherwise create
+    # an Experiment before these settings apply.
+
+    for key in cfg.keys():
+        if key.startswith("comet_") and key != "comet_experiment_config":
+            value = cfg.get(key, "")
+
+            if value is not None and value != "":
+                env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
+                final_value = python_value_to_environ_value(value)
+                os.environ[env_variable_name] = final_value
+
+    if cfg.comet_experiment_config:
+        for key, value in cfg.comet_experiment_config.items():
+            if value is not None and value != "":
+                config_env_variable_name = (
+                    COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
+                )
+
+                if config_env_variable_name is None:
+                    LOG.warning(
+                        f"Unknown Comet Experiment Config name {key}, ignoring it"
+                    )
+                    continue
+
+                final_value = python_value_to_environ_value(value)
+                os.environ[config_env_variable_name] = final_value
+
+    # Enable comet if project name is present
+    if cfg.comet_project_name
and len(cfg.comet_project_name) > 0: + cfg.use_comet = True diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index fced5e639..76748191b 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -489,6 +489,19 @@ class WandbConfig(BaseModel): return data +class CometConfig(BaseModel): + """Comet configuration subset""" + + use_comet: Optional[bool] = None + comet_api_key: Optional[str] = None + comet_workspace: Optional[str] = None + comet_project_name: Optional[str] = None + comet_experiment_key: Optional[str] = None + comet_mode: Optional[str] = None + comet_online: Optional[bool] = None + comet_experiment_config: Optional[Dict[str, Any]] = None + + class GradioConfig(BaseModel): """Gradio configuration subset""" @@ -509,6 +522,7 @@ class AxolotlInputConfig( HyperparametersConfig, WandbConfig, MLFlowConfig, + CometConfig, LISAConfig, GradioConfig, RemappedParameters, diff --git a/tests/test_validation.py b/tests/test_validation.py index 35d0e265e..6e0d0ad2a 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -9,6 +9,7 @@ from typing import Optional import pytest from pydantic import ValidationError +from axolotl.utils import is_comet_available from axolotl.utils.config import validate_config from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities from axolotl.utils.dict import DictDefault @@ -1329,3 +1330,105 @@ class TestValidationWandb(BaseValidation): os.environ.pop("WANDB_PROJECT", None) os.environ.pop("WANDB_DISABLED", None) + + +@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed") +class TestValidationComet(BaseValidation): + """ + Validation test for comet + """ + + def test_comet_sets_env(self, minimal_cfg): + from axolotl.utils.comet_ import setup_comet_env_vars + + comet_config = { + "comet_api_key": "foo", + "comet_workspace": "some_workspace", + "comet_project_name": "some_project", + "comet_experiment_key": "some_experiment_key", + "comet_mode": "get_or_create", + "comet_online": False, + "comet_experiment_config": { + "auto_histogram_activation_logging": False, + "auto_histogram_epoch_rate": 2, + "auto_histogram_gradient_logging": True, + "auto_histogram_tensorboard_logging": False, + "auto_histogram_weight_logging": True, + "auto_log_co2": False, + "auto_metric_logging": True, + "auto_metric_step_rate": 15, + "auto_output_logging": False, + "auto_param_logging": True, + "comet_disabled": False, + "display_summary_level": 2, + "distributed_node_identifier": "some_distributed_node_identifier", + "log_code": True, + "log_env_cpu": False, + "log_env_details": True, + "log_env_disk": False, + "log_env_gpu": True, + "log_env_host": False, + "log_env_network": True, + "log_git_metadata": False, + "log_git_patch": True, + "log_graph": False, + "name": "some_name", + "offline_directory": "some_offline_directory", + "parse_args": True, + "tags": ["tag1", "tag2"], + }, + } + + cfg = DictDefault(comet_config) | minimal_cfg + + new_cfg = validate_config(cfg) + + setup_comet_env_vars(new_cfg) + + comet_env = { + key: value for key, value in os.environ.items() if key.startswith("COMET_") + } + + assert ( + len(comet_env) + == len(comet_config) + len(comet_config["comet_experiment_config"]) - 1 + ) + + assert comet_env == { + "COMET_API_KEY": "foo", + "COMET_AUTO_LOG_CLI_ARGUMENTS": "true", + "COMET_AUTO_LOG_CO2": "false", + "COMET_AUTO_LOG_CODE": "true", + 
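# NOTE (editor's sketch, not part of this patch): the string forms asserted in
# this dict come from python_value_to_environ_value above -- booleans are
# lowered to "true"/"false", ints go through str(), and the tags list is
# comma-joined:
#
#     python_value_to_environ_value(True)              # -> "true"
#     python_value_to_environ_value(15)                # -> "15"
#     python_value_to_environ_value(["tag1", "tag2"])  # -> "tag1,tag2"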
"COMET_AUTO_LOG_DISABLE": "false", + "COMET_AUTO_LOG_ENV_CPU": "false", + "COMET_AUTO_LOG_ENV_DETAILS": "true", + "COMET_AUTO_LOG_ENV_DISK": "false", + "COMET_AUTO_LOG_ENV_GPU": "true", + "COMET_AUTO_LOG_ENV_HOST": "false", + "COMET_AUTO_LOG_ENV_NETWORK": "true", + "COMET_AUTO_LOG_GIT_METADATA": "false", + "COMET_AUTO_LOG_GIT_PATCH": "true", + "COMET_AUTO_LOG_GRAPH": "false", + "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false", + "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2", + "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true", + "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false", + "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true", + "COMET_AUTO_LOG_METRIC_STEP_RATE": "15", + "COMET_AUTO_LOG_METRICS": "true", + "COMET_AUTO_LOG_OUTPUT_LOGGER": "false", + "COMET_AUTO_LOG_PARAMETERS": "true", + "COMET_DISPLAY_SUMMARY_LEVEL": "2", + "COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier", + "COMET_EXPERIMENT_KEY": "some_experiment_key", + "COMET_OFFLINE_DIRECTORY": "some_offline_directory", + "COMET_PROJECT_NAME": "some_project", + "COMET_START_EXPERIMENT_NAME": "some_name", + "COMET_START_EXPERIMENT_TAGS": "tag1,tag2", + "COMET_START_MODE": "get_or_create", + "COMET_START_ONLINE": "false", + "COMET_WORKSPACE": "some_workspace", + } + + for key in comet_env.keys(): + os.environ.pop(key, None)