Compare commits

...

4 Commits

Author SHA1 Message Date
NanoCode012
08c8f3f22f fix: total tokens and defaults in config 2025-12-02 21:38:10 +07:00
NanoCode012
76f0fe2621 fix: steps not allowed fractional 2025-12-02 21:30:15 +07:00
Yohan Na
c6ddcdd06a feat: add exaone4 chat template and update enums (#3279)
* feat: add exaone4 chat template and update enums

* fix: handle first message as system or tools in exaone4 chat template

* Update src/axolotl/utils/chat_templates/templates/exaone4.jinja

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* fix: lint

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-12-01 15:52:45 +07:00
github-actions[bot]
7fb6a947d9 chore: update pre-commit hooks (#3287)
Co-authored-by: SalmanMohammadi <25081738+SalmanMohammadi@users.noreply.github.com>
2025-12-01 15:03:14 +07:00
5 changed files with 135 additions and 6 deletions

View File

@@ -11,13 +11,13 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.3
rev: v0.14.7
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
rev: v1.19.0
hooks:
- id: mypy
additional_dependencies:
@@ -26,7 +26,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.8.6
rev: 1.9.2
hooks:
- id: bandit
args: [

View File

@@ -30,7 +30,7 @@ eval_sample_packing: true
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
warmup_steps: 0.1
warmup_ratio: 0.1
optimizer: adamw_8bit
lr_scheduler: cosine
@@ -44,7 +44,7 @@ resume_from_checkpoint:
sdp_attention: true
logging_steps: 1
save_strategy: best
save_strategy: epoch
eval_strategy: epoch
special_tokens:

View File

@@ -631,7 +631,9 @@ class AxolotlTrainer(
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
logs["total_tokens"] = int(self.state.total_tokens.item())
if hasattr(self.state, "total_tokens"):
logs["total_tokens"] = int(self.state.total_tokens.item())
del self._stored_metrics[train_eval]

View File

@@ -0,0 +1,126 @@
{%- if not skip_think is defined %}
{%- set skip_think = true %}
{%- endif %}
{%- set role_indicators = {
'user': '[|user|]\n',
'assistant': '[|assistant|]\n',
'system': '[|system|]\n',
'tool': '[|tool|]\n'
} %}
{%- set end_of_turn = '[|endofturn|]\n' %}
{%- macro available_tools(tools) %}
{{- "# Available Tools" }}
{{- "\nYou can use none, one, or multiple of the following tools by calling them as functions to help with the users query." }}
{{- "\nHere are the tools available to you in JSON format within <tool> and </tool> tags:\n" }}
{%- for tool in tools %}
{{- "<tool>" }}
{{- tool | tojson(ensure_ascii=False) | safe }}
{{- "</tool>\n" }}
{%- endfor %}
{{- "\nFor each function call you want to make, return a JSON object with function name and arguments within <tool_call> and </tool_call> tags, like:" }}
{{- "\n<tool_call>{\"name\": function_1_name, \"arguments\": {argument_1_name: argument_1_value, argument_2_name: argument_2_value}}</tool_call>" }}
{{- "\n<tool_call>{\"name\": function_2_name, \"arguments\": {...}}</tool_call>\n..." }}
{{- "\nNote that if no argument name is specified for a tool, you can just print the argument value directly, without the argument name or JSON formatting." }}
{%- endmacro %}
{%- set ns = namespace(last_query_index = messages|length - 1) %}
{%- for message in messages %}
{%- if message.role == "user" and message.content is string %}
{%- set ns.last_query_index = loop.index0 -%}
{%- endif %}
{%- endfor %}
{%- for i in range(messages | length) %}
{%- set msg = messages[i] %}
{%- set role = msg.role %}
{%- if role not in role_indicators %}
{{- raise_exception('Unknown role: ' ~ role) }}
{%- endif %}
{# ---- Case A: If the first message is "system", handle it here alone (without continue) ---- #}
{%- if i == 0 and role == 'system' %}
{{- role_indicators['system'] }}
{{- msg.content }}
{%- if tools is defined and tools %}
{{- "\n\n" }}{{- available_tools(tools) }}
{%- endif %}
{{- end_of_turn -}}
{%- else %}
{# ---- Case B: If the first message is tools instead of system, inject the system tools preamble ---- #}
{%- if i == 0 and tools is defined and tools %}
{{- role_indicators['system'] }}
{{- available_tools(tools) }}
{{- end_of_turn -}}
{%- endif %}
{%- endif %}
{%- if role == 'assistant' %}
{{- role_indicators['assistant'] }}
{%- if msg.content %}
{%- if "</think>" in msg.content %}
{%- set content = msg.content.split('</think>')[-1].strip() %}
{%- set reasoning_content = msg.content.split('</think>')[0].strip() %}
{%- if reasoning_content.startswith("<think>") %}
{%- set reasoning_content = reasoning_content[7:].strip() %}
{%- endif %}
{%- else %}
{%- set content = msg.content %}
{%- endif %}
{%- if msg.reasoning_content %}
{%- set reasoning_content = msg.reasoning_content %}
{%- endif %}
{%- if (not skip_think and loop.last) and reasoning_content is defined %}
{{- "<think>\n" }}
{{- reasoning_content}}
{{- "\n</think>\n\n" }}
{%- else %}
{{- "<think>\n\n</think>\n\n" }}
{%- endif %}
{{- content }}
{%- endif %}
{%- if msg.tool_calls %}
{%- if msg.content %}
{{- "\n" }}
{%- else %}
{{- "<think>\n\n</think>\n\n" }}
{%- endif %}
{%- for tool_call in msg.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- if tool_call.arguments is defined %}
{%- set arguments = tool_call.arguments %}
{%- elif tool_call.parameters is defined %}
{%- set arguments = tool_call.parameters %}
{%- else %}
{{- raise_exception('arguments or parameters are mandatory: ' ~ tool_call) }}
{%- endif %}
{{- "<tool_call>" }}{"name": "{{- tool_call.name }}", "arguments": {{ arguments | tojson(ensure_ascii=False) | safe }}}{{- "</tool_call>" }}
{%- if not loop.last %}
{{- "\n" }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- end_of_turn -}}
{%- elif role == "tool" %}
{%- if i == 0 or messages[i - 1].role != "tool" %}
{{- role_indicators['tool'] }}
{%- endif %}
{%- if msg.content is defined %}
{{- "<tool_result>" }}{"result": {{ msg.content | tojson(ensure_ascii=False) | safe }}}{{- "</tool_result>" }}
{%- endif %}
{%- if loop.last or messages[i + 1].role != "tool" %}
{{- end_of_turn -}}
{%- else %}
{{- "\n" }}
{%- endif %}
{%- else %}
{{- role_indicators[role] }}
{{- msg.content }}
{{- end_of_turn -}}
{%- endif %}
{% endfor %}
{%- if add_generation_prompt %}
{{- role_indicators['assistant'] }}
{%- if enable_thinking is defined and enable_thinking is true %}
{{- "<think>\n" }}
{%- else %}
{{- "<think>\n\n</think>\n\n" }}
{%- endif %}
{%- endif %}

View File

@@ -58,6 +58,7 @@ class ChatTemplate(str, Enum):
falcon_h1 = "falcon_h1"
tokenizer_default = "tokenizer_default"
exaone = "exaone"
exaone4 = "exaone4"
metharme = "metharme"
pixtral = "pixtral"
llava = "llava"