diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py
index 7610b335a..827a63c07 100644
--- a/src/axolotl/common/architectures.py
+++ b/src/axolotl/common/architectures.py
@@ -11,4 +11,5 @@ MOE_ARCH_BLOCK = {
     ],
     "mixtral": "MixtralSparseMoeBlock",
     "qwen2_moe": "Qwen2MoeSparseMoeBlock",
+    "deepseek_v2": "DeepseekV2MoE",
 }
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index a2ce0e64f..904352010 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -25,12 +25,12 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
 ]
 
 
-def patch_for_multipack(model_type, model_name=None):
+def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
     if model_type == "gemmoe":
         patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
     elif model_type == "deepseek_v2":
         patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
-    elif hasattr(transformers, "modeling_flash_attention_utils"):
+    elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
         transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index 725934cf5..ca4334d75 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -26,6 +26,7 @@ def chat_templates(user_choice: str):
         "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
         "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
         "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+        "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}",
     }
 
     if user_choice in templates:
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 4fb020bd5..b765263ba 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -189,6 +189,7 @@ class ChatTemplate(str, Enum):
     cohere = "cohere"  # pylint: disable=invalid-name
     llama3 = "llama3"  # pylint: disable=invalid-name
     phi_3 = "phi_3"  # pylint: disable=invalid-name
+    deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name
 
 
 class LoftQConfig(BaseModel):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index f65da71d4..87f50d9a2 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -348,7 +348,11 @@ def load_model(
         and cfg.flash_attention
         and cfg.sample_packing
     ):
-        patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
+        patch_for_multipack(
+            cfg.model_config_type,
+            model_name=cfg.base_model,
+            is_remote_code=cfg.trust_remote_code,
+        )
 
     if cfg.is_llama_derived_model:
         from axolotl.monkeypatch.llama_attn_hijack_flash import (
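The new `MOE_ARCH_BLOCK` entry maps the `deepseek_v2` model type to the class name of its sparse-MoE block. Below is a minimal sketch of how such a mapping is typically consumed under DeepSpeed ZeRO-3, where MoE blocks must be flagged as "leaf" modules so their expert parameters are gathered as a unit; the helper name `mark_moe_leaf_modules` and the class lookup are illustrative, not the repo's exact code:

```python
from deepspeed.utils import set_z3_leaf_modules

from axolotl.common.architectures import MOE_ARCH_BLOCK


def mark_moe_leaf_modules(model, model_type):
    """Illustrative sketch: flag sparse-MoE blocks as ZeRO-3 leaf modules."""
    blocks = MOE_ARCH_BLOCK.get(model_type)
    if blocks is None:
        return
    # Entries are either a single class name or a list of names (e.g. "jamba")
    block_names = blocks if isinstance(blocks, list) else [blocks]
    # Resolve names like "DeepseekV2MoE" to the live classes on this model
    leaf_classes = {
        module.__class__
        for module in model.modules()
        if module.__class__.__name__ in block_names
    }
    if leaf_classes:
        set_z3_leaf_modules(model, list(leaf_classes))
```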
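The `is_remote_code` guard exists because `trust_remote_code` models run their own vendored `modeling_*.py`, so monkeypatching `transformers.modeling_flash_attention_utils._get_unpad_data` never reaches them; the patch has to be applied to the dynamically imported module itself, which is what `patch_remote` does for `deepseek_v2`. A hedged sketch of that idea (assuming `get_unpad_data` lives in `axolotl.monkeypatch.utils`; `patch_remote_sketch` is illustrative, not the repo's exact implementation):

```python
import importlib

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

from axolotl.monkeypatch.utils import get_unpad_data


def patch_remote_sketch(model_name, config_name, modeling_name):
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    # Instantiating the model (on empty weights, so nothing is materialized)
    # forces transformers to download and import the remote modeling_*.py.
    with init_empty_weights():
        AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    # The config class lives in configuration_*.py alongside modeling_*.py,
    # so the modeling module path can be derived by substitution.
    module_name = config.__class__.__module__.replace(config_name, modeling_name)
    modeling_arch = importlib.import_module(module_name)
    # Swap in axolotl's packing-aware unpad helper on the remote module.
    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
```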
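To see what the new `deepseek_v2` chat template produces, it can be rendered directly with jinja2. This is a standalone demo, assuming `chat_templates()` returns the raw Jinja string as the diff suggests; the `bos_token`/`eos_token` values below are stand-ins for the real tokenizer's special tokens:

```python
from jinja2 import Template

from axolotl.utils.chat_templates import chat_templates

template = Template(chat_templates("deepseek_v2"))

rendered = template.render(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ],
    bos_token="<s>",   # stand-in; use tokenizer.bos_token in practice
    eos_token="</s>",  # stand-in; use tokenizer.eos_token
    add_generation_prompt=True,
)
print(rendered)
# -> <s>You are a helpful assistant.\n\n<|User|>Hi<|Assistant|>Hello!</s><|Assistant|>
```

Note that the system message is emitted bare (followed by a blank line) rather than wrapped in role tokens, and the trailing `<|Assistant|>` only appears when `add_generation_prompt` is set.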