Compare commits

..

3 Commits

Author SHA1 Message Date
Wing Lian
dcd916b29b bump transformers 4.57.3 2025-12-02 10:33:44 -05:00
Yohan Na
c6ddcdd06a feat: add exaone4 chat template and update enums (#3279)
* feat: add exaone4 chat template and update enums

* fix: handle first message as system or tools in exaone4 chat template

* Update src/axolotl/utils/chat_templates/templates/exaone4.jinja

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* fix: lint

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-12-01 15:52:45 +07:00
github-actions[bot]
7fb6a947d9 chore: update pre-commit hooks (#3287)
Co-authored-by: SalmanMohammadi <25081738+SalmanMohammadi@users.noreply.github.com>
2025-12-01 15:03:14 +07:00
5 changed files with 132 additions and 44 deletions

View File

@@ -11,13 +11,13 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.3
rev: v0.14.7
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
rev: v1.19.0
hooks:
- id: mypy
additional_dependencies:
@@ -26,7 +26,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.8.6
rev: 1.9.2
hooks:
- id: bandit
args: [

View File

@@ -13,7 +13,7 @@ packaging==23.2
huggingface_hub>=0.36.0
peft>=0.18.0
tokenizers>=0.22.1
transformers==4.57.1
transformers==4.57.3
accelerate==1.11.0
datasets==4.4.1
deepspeed>=0.17.0

View File

@@ -23,29 +23,6 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
"""
Apply LIGER runtime patches and integrations according to the provided configuration.
This hook inspects `cfg` and conditionally applies LIGER kernel patches, replacements, and model-specific integrations (rotary embeddings, normalization, GLU variants, and cross-entropy implementations) for the model type indicated by `cfg.model_config_type`. Behavior is driven entirely by the various `cfg.liger_*` flags; the method logs actions and warnings when support is experimental or unavailable.
Parameters:
cfg: Configuration object containing LIGER-related flags and model identification. Expected attributes include:
- model_config_type (str): Target model config type to determine which patches to apply.
- base_model (str): Base model identifier used when probing model modules (used for some model types).
- trust_remote_code (bool|None): Passed when loading remote model code (used for some model types).
- torch_compile (bool): If true, disable torch.compile optimizations for certain LIGER kernels.
- liger_cross_entropy (bool)
- liger_fused_linear_cross_entropy (bool)
- liger_use_token_scaling (bool)
- liger_rope (bool)
- liger_rms_norm (bool)
- liger_layer_norm (bool)
- liger_glu_activation (str|bool): Name or flag for GLU/SwiGLU activation selection.
(Other LIGER flags referenced by the code may also be consulted.)
Raises:
ValueError: If both `cfg.liger_cross_entropy` and `cfg.liger_fused_linear_cross_entropy` are enabled.
"""
if cfg.torch_compile:
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
import liger_kernel.ops.fused_linear_cross_entropy
@@ -191,22 +168,6 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_vl":
"""
Apply Liger kernels for Qwen3 Vision-Language models.
Note: The parameter 'swiglu' is used instead of 'glu_activation' to match
the Liger kernel API for vision-language models.
"""
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl
apply_liger_kernel_to_qwen3_vl(
rope=cfg.liger_rope,
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
rms_norm=cfg.liger_rms_norm,
swiglu=cfg.liger_glu_activation, # Note: qwen3_vl uses swiglu parameter name
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
@@ -245,4 +206,4 @@ class LigerPlugin(BasePlugin):
else:
LOG.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
)
)

View File

@@ -0,0 +1,126 @@
{%- if not skip_think is defined %}
{%- set skip_think = true %}
{%- endif %}
{%- set role_indicators = {
'user': '[|user|]\n',
'assistant': '[|assistant|]\n',
'system': '[|system|]\n',
'tool': '[|tool|]\n'
} %}
{%- set end_of_turn = '[|endofturn|]\n' %}
{%- macro available_tools(tools) %}
{{- "# Available Tools" }}
{{- "\nYou can use none, one, or multiple of the following tools by calling them as functions to help with the users query." }}
{{- "\nHere are the tools available to you in JSON format within <tool> and </tool> tags:\n" }}
{%- for tool in tools %}
{{- "<tool>" }}
{{- tool | tojson(ensure_ascii=False) | safe }}
{{- "</tool>\n" }}
{%- endfor %}
{{- "\nFor each function call you want to make, return a JSON object with function name and arguments within <tool_call> and </tool_call> tags, like:" }}
{{- "\n<tool_call>{\"name\": function_1_name, \"arguments\": {argument_1_name: argument_1_value, argument_2_name: argument_2_value}}</tool_call>" }}
{{- "\n<tool_call>{\"name\": function_2_name, \"arguments\": {...}}</tool_call>\n..." }}
{{- "\nNote that if no argument name is specified for a tool, you can just print the argument value directly, without the argument name or JSON formatting." }}
{%- endmacro %}
{%- set ns = namespace(last_query_index = messages|length - 1) %}
{%- for message in messages %}
{%- if message.role == "user" and message.content is string %}
{%- set ns.last_query_index = loop.index0 -%}
{%- endif %}
{%- endfor %}
{%- for i in range(messages | length) %}
{%- set msg = messages[i] %}
{%- set role = msg.role %}
{%- if role not in role_indicators %}
{{- raise_exception('Unknown role: ' ~ role) }}
{%- endif %}
{# ---- Case A: If the first message is "system", handle it here alone (without continue) ---- #}
{%- if i == 0 and role == 'system' %}
{{- role_indicators['system'] }}
{{- msg.content }}
{%- if tools is defined and tools %}
{{- "\n\n" }}{{- available_tools(tools) }}
{%- endif %}
{{- end_of_turn -}}
{%- else %}
{# ---- Case B: If the first message is tools instead of system, inject the system tools preamble ---- #}
{%- if i == 0 and tools is defined and tools %}
{{- role_indicators['system'] }}
{{- available_tools(tools) }}
{{- end_of_turn -}}
{%- endif %}
{%- endif %}
{%- if role == 'assistant' %}
{{- role_indicators['assistant'] }}
{%- if msg.content %}
{%- if "</think>" in msg.content %}
{%- set content = msg.content.split('</think>')[-1].strip() %}
{%- set reasoning_content = msg.content.split('</think>')[0].strip() %}
{%- if reasoning_content.startswith("<think>") %}
{%- set reasoning_content = reasoning_content[7:].strip() %}
{%- endif %}
{%- else %}
{%- set content = msg.content %}
{%- endif %}
{%- if msg.reasoning_content %}
{%- set reasoning_content = msg.reasoning_content %}
{%- endif %}
{%- if (not skip_think and loop.last) and reasoning_content is defined %}
{{- "<think>\n" }}
{{- reasoning_content}}
{{- "\n</think>\n\n" }}
{%- else %}
{{- "<think>\n\n</think>\n\n" }}
{%- endif %}
{{- content }}
{%- endif %}
{%- if msg.tool_calls %}
{%- if msg.content %}
{{- "\n" }}
{%- else %}
{{- "<think>\n\n</think>\n\n" }}
{%- endif %}
{%- for tool_call in msg.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- if tool_call.arguments is defined %}
{%- set arguments = tool_call.arguments %}
{%- elif tool_call.parameters is defined %}
{%- set arguments = tool_call.parameters %}
{%- else %}
{{- raise_exception('arguments or parameters are mandatory: ' ~ tool_call) }}
{%- endif %}
{{- "<tool_call>" }}{"name": "{{- tool_call.name }}", "arguments": {{ arguments | tojson(ensure_ascii=False) | safe }}}{{- "</tool_call>" }}
{%- if not loop.last %}
{{- "\n" }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- end_of_turn -}}
{%- elif role == "tool" %}
{%- if i == 0 or messages[i - 1].role != "tool" %}
{{- role_indicators['tool'] }}
{%- endif %}
{%- if msg.content is defined %}
{{- "<tool_result>" }}{"result": {{ msg.content | tojson(ensure_ascii=False) | safe }}}{{- "</tool_result>" }}
{%- endif %}
{%- if loop.last or messages[i + 1].role != "tool" %}
{{- end_of_turn -}}
{%- else %}
{{- "\n" }}
{%- endif %}
{%- else %}
{{- role_indicators[role] }}
{{- msg.content }}
{{- end_of_turn -}}
{%- endif %}
{% endfor %}
{%- if add_generation_prompt %}
{{- role_indicators['assistant'] }}
{%- if enable_thinking is defined and enable_thinking is true %}
{{- "<think>\n" }}
{%- else %}
{{- "<think>\n\n</think>\n\n" }}
{%- endif %}
{%- endif %}

View File

@@ -58,6 +58,7 @@ class ChatTemplate(str, Enum):
falcon_h1 = "falcon_h1"
tokenizer_default = "tokenizer_default"
exaone = "exaone"
exaone4 = "exaone4"
metharme = "metharme"
pixtral = "pixtral"
llava = "llava"