Compare commits

1 commit

fix/diffus...coderabbit

| Author | SHA1 | Date |
|---|---|---|
| | 93600fa80d | |
```diff
@@ -11,13 +11,13 @@ repos:
       - id: no-commit-to-branch
         args: ['--branch', 'main']
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.14.7
+    rev: v0.14.3
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.19.0
+    rev: v1.18.2
     hooks:
       - id: mypy
         additional_dependencies:
@@ -26,7 +26,7 @@ repos:
             'pydantic>=2.5.3',
           ]
   - repo: https://github.com/PyCQA/bandit
-    rev: 1.9.2
+    rev: 1.8.6
    hooks:
      - id: bandit
        args: [
```
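Both hunks only move the pinned `rev` fields between releases of the same hooks. For reference, a minimal sketch (assuming PyYAML is installed and the config sits at the repository root) that lists each repo's pin:

```python
# Sketch: print the pinned revision of every hook repo in a
# pre-commit config. Path and output format are illustrative.
import yaml

with open(".pre-commit-config.yaml") as f:
    config = yaml.safe_load(f)

for repo in config.get("repos", []):
    # "local" and "meta" repos carry no rev pin.
    print(f"{repo.get('repo')}: {repo.get('rev', '<unpinned>')}")
```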
```diff
@@ -30,7 +30,7 @@ eval_sample_packing: true
 gradient_accumulation_steps: 4
 micro_batch_size: 4
 num_epochs: 1
-warmup_ratio: 0.1
+warmup_steps: 0.1

 optimizer: adamw_8bit
 lr_scheduler: cosine
@@ -44,7 +44,7 @@ resume_from_checkpoint:
 sdp_attention: true

 logging_steps: 1
-save_strategy: epoch
+save_strategy: best
 eval_strategy: epoch

 special_tokens:
```
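Note that the warmup change above swaps the key, not the value, and the two keys have different semantics: `warmup_ratio` is a fraction of total training steps, while `warmup_steps` is an absolute step count, so a value of `0.1` is only meaningful in the ratio form. A minimal sketch of the relationship, with illustrative numbers not taken from this diff:

```python
# Sketch: converting a warmup ratio into an absolute step count.
# total_steps is hypothetical; in practice it derives from dataset
# size, epochs, batch size, and gradient accumulation.
total_steps = 1000
warmup_ratio = 0.1
warmup_steps = int(total_steps * warmup_ratio)
print(warmup_steps)  # 100
```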
```diff
@@ -631,9 +631,7 @@ class AxolotlTrainer(
             logs["tokens_per_second_per_gpu"] = round(
                 self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
             )

-            if hasattr(self.state, "total_tokens"):
-                logs["total_tokens"] = int(self.state.total_tokens.item())
+            logs["total_tokens"] = int(self.state.total_tokens.item())

         del self._stored_metrics[train_eval]
```
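The logging hunk averages an accumulated tokens-per-second counter over the logging window before emitting it. A self-contained sketch of that pattern, with plain floats standing in for the trainer's tensor-valued state:

```python
# Sketch of the logging pattern in the hunk above; DummyState and
# its values are illustrative, not the real trainer state.
class DummyState:
    last_tokens_per_second = 40960.0  # accumulated over the window
    total_tokens = 1_234_567

state = DummyState()
logging_steps = 10
logs = {}

# Average the accumulated throughput over the logging window.
logs["tokens_per_second_per_gpu"] = round(
    state.last_tokens_per_second / logging_steps, 2
)

# Only emit total_tokens when the state actually tracks it.
if hasattr(state, "total_tokens"):
    logs["total_tokens"] = int(state.total_tokens)

print(logs)
```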
```diff
@@ -23,6 +23,29 @@ class LigerPlugin(BasePlugin):
         return "axolotl.integrations.liger.LigerArgs"

     def pre_model_load(self, cfg):
+        """
+        Apply LIGER runtime patches and integrations according to the provided configuration.
+
+        This hook inspects `cfg` and conditionally applies LIGER kernel patches, replacements, and model-specific integrations (rotary embeddings, normalization, GLU variants, and cross-entropy implementations) for the model type indicated by `cfg.model_config_type`. Behavior is driven entirely by the various `cfg.liger_*` flags; the method logs actions and warnings when support is experimental or unavailable.
+
+        Parameters:
+            cfg: Configuration object containing LIGER-related flags and model identification. Expected attributes include:
+                - model_config_type (str): Target model config type to determine which patches to apply.
+                - base_model (str): Base model identifier used when probing model modules (used for some model types).
+                - trust_remote_code (bool|None): Passed when loading remote model code (used for some model types).
+                - torch_compile (bool): If true, disable torch.compile optimizations for certain LIGER kernels.
+                - liger_cross_entropy (bool)
+                - liger_fused_linear_cross_entropy (bool)
+                - liger_use_token_scaling (bool)
+                - liger_rope (bool)
+                - liger_rms_norm (bool)
+                - liger_layer_norm (bool)
+                - liger_glu_activation (str|bool): Name or flag for GLU/SwiGLU activation selection.
+                (Other LIGER flags referenced by the code may also be consulted.)
+
+        Raises:
+            ValueError: If both `cfg.liger_cross_entropy` and `cfg.liger_fused_linear_cross_entropy` are enabled.
+        """
         if cfg.torch_compile:
             # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
             import liger_kernel.ops.fused_linear_cross_entropy
@@ -168,6 +191,22 @@ class LigerPlugin(BasePlugin):
                 rms_norm=cfg.liger_rms_norm,
                 layer_norm=cfg.liger_layer_norm,
             )
+        elif cfg.model_config_type == "qwen3_vl":
+            """
+            Apply Liger kernels for Qwen3 Vision-Language models.
+
+            Note: The parameter 'swiglu' is used instead of 'glu_activation' to match
+            the Liger kernel API for vision-language models.
+            """
+            from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl
+
+            apply_liger_kernel_to_qwen3_vl(
+                rope=cfg.liger_rope,
+                cross_entropy=cfg.liger_cross_entropy,
+                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+                rms_norm=cfg.liger_rms_norm,
+                swiglu=cfg.liger_glu_activation,  # Note: qwen3_vl uses swiglu parameter name
+            )
         elif cfg.model_config_type == "qwen3_moe":
             from axolotl.integrations.liger.models.qwen3_moe import (
                 apply_liger_kernel_to_qwen3_moe,
@@ -206,4 +245,4 @@ class LigerPlugin(BasePlugin):
         else:
             LOG.warning(
                 f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
             )
```
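The docstring's `Raises` section documents that `liger_cross_entropy` and `liger_fused_linear_cross_entropy` are mutually exclusive. A minimal sketch of such a check (the helper name is hypothetical, not part of the plugin):

```python
# Sketch: the mutual-exclusion check described in the docstring's
# Raises section. validate_liger_flags is a hypothetical helper.
def validate_liger_flags(cfg) -> None:
    if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
        raise ValueError(
            "liger_cross_entropy and liger_fused_linear_cross_entropy "
            "cannot both be enabled"
        )
```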
```diff
@@ -1,126 +0,0 @@
-{%- if not skip_think is defined %}
-{%- set skip_think = true %}
-{%- endif %}
-{%- set role_indicators = {
-    'user': '[|user|]\n',
-    'assistant': '[|assistant|]\n',
-    'system': '[|system|]\n',
-    'tool': '[|tool|]\n'
-} %}
-{%- set end_of_turn = '[|endofturn|]\n' %}
-{%- macro available_tools(tools) %}
-{{- "# Available Tools" }}
-{{- "\nYou can use none, one, or multiple of the following tools by calling them as functions to help with the user’s query." }}
-{{- "\nHere are the tools available to you in JSON format within <tool> and </tool> tags:\n" }}
-{%- for tool in tools %}
-{{- "<tool>" }}
-{{- tool | tojson(ensure_ascii=False) | safe }}
-{{- "</tool>\n" }}
-{%- endfor %}
-{{- "\nFor each function call you want to make, return a JSON object with function name and arguments within <tool_call> and </tool_call> tags, like:" }}
-{{- "\n<tool_call>{\"name\": function_1_name, \"arguments\": {argument_1_name: argument_1_value, argument_2_name: argument_2_value}}</tool_call>" }}
-{{- "\n<tool_call>{\"name\": function_2_name, \"arguments\": {...}}</tool_call>\n..." }}
-{{- "\nNote that if no argument name is specified for a tool, you can just print the argument value directly, without the argument name or JSON formatting." }}
-{%- endmacro %}
-{%- set ns = namespace(last_query_index = messages|length - 1) %}
-{%- for message in messages %}
-{%- if message.role == "user" and message.content is string %}
-{%- set ns.last_query_index = loop.index0 -%}
-{%- endif %}
-{%- endfor %}
-{%- for i in range(messages | length) %}
-{%- set msg = messages[i] %}
-{%- set role = msg.role %}
-{%- if role not in role_indicators %}
-{{- raise_exception('Unknown role: ' ~ role) }}
-{%- endif %}
-{# ---- Case A: If the first message is "system", handle it here alone (without continue) ---- #}
-{%- if i == 0 and role == 'system' %}
-{{- role_indicators['system'] }}
-{{- msg.content }}
-{%- if tools is defined and tools %}
-{{- "\n\n" }}{{- available_tools(tools) }}
-{%- endif %}
-{{- end_of_turn -}}
-{%- else %}
-{# ---- Case B: If the first message is tools instead of system, inject the system tools preamble ---- #}
-{%- if i == 0 and tools is defined and tools %}
-{{- role_indicators['system'] }}
-{{- available_tools(tools) }}
-{{- end_of_turn -}}
-{%- endif %}
-{%- endif %}
-{%- if role == 'assistant' %}
-{{- role_indicators['assistant'] }}
-{%- if msg.content %}
-{%- if "</think>" in msg.content %}
-{%- set content = msg.content.split('</think>')[-1].strip() %}
-{%- set reasoning_content = msg.content.split('</think>')[0].strip() %}
-{%- if reasoning_content.startswith("<think>") %}
-{%- set reasoning_content = reasoning_content[7:].strip() %}
-{%- endif %}
-{%- else %}
-{%- set content = msg.content %}
-{%- endif %}
-{%- if msg.reasoning_content %}
-{%- set reasoning_content = msg.reasoning_content %}
-{%- endif %}
-{%- if (not skip_think and loop.last) and reasoning_content is defined %}
-{{- "<think>\n" }}
-{{- reasoning_content}}
-{{- "\n</think>\n\n" }}
-{%- else %}
-{{- "<think>\n\n</think>\n\n" }}
-{%- endif %}
-{{- content }}
-{%- endif %}
-{%- if msg.tool_calls %}
-{%- if msg.content %}
-{{- "\n" }}
-{%- else %}
-{{- "<think>\n\n</think>\n\n" }}
-{%- endif %}
-{%- for tool_call in msg.tool_calls %}
-{%- if tool_call.function is defined %}
-{%- set tool_call = tool_call.function %}
-{%- endif %}
-{%- if tool_call.arguments is defined %}
-{%- set arguments = tool_call.arguments %}
-{%- elif tool_call.parameters is defined %}
-{%- set arguments = tool_call.parameters %}
-{%- else %}
-{{- raise_exception('arguments or parameters are mandatory: ' ~ tool_call) }}
-{%- endif %}
-{{- "<tool_call>" }}{"name": "{{- tool_call.name }}", "arguments": {{ arguments | tojson(ensure_ascii=False) | safe }}}{{- "</tool_call>" }}
-{%- if not loop.last %}
-{{- "\n" }}
-{%- endif %}
-{%- endfor %}
-{%- endif %}
-{{- end_of_turn -}}
-{%- elif role == "tool" %}
-{%- if i == 0 or messages[i - 1].role != "tool" %}
-{{- role_indicators['tool'] }}
-{%- endif %}
-{%- if msg.content is defined %}
-{{- "<tool_result>" }}{"result": {{ msg.content | tojson(ensure_ascii=False) | safe }}}{{- "</tool_result>" }}
-{%- endif %}
-{%- if loop.last or messages[i + 1].role != "tool" %}
-{{- end_of_turn -}}
-{%- else %}
-{{- "\n" }}
-{%- endif %}
-{%- else %}
-{{- role_indicators[role] }}
-{{- msg.content }}
-{{- end_of_turn -}}
-{%- endif %}
-{% endfor %}
-{%- if add_generation_prompt %}
-{{- role_indicators['assistant'] }}
-{%- if enable_thinking is defined and enable_thinking is true %}
-{{- "<think>\n" }}
-{%- else %}
-{{- "<think>\n\n</think>\n\n" }}
-{%- endif %}
-{%- endif %}
```
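The deleted file is a Jinja chat template (EXAONE-style `[|role|]` turn markers with `<think>` handling and `<tool_call>`/`<tool_result>` tags). Templates like this are typically exercised through `tokenizer.apply_chat_template`; a minimal sketch, assuming a Hugging Face tokenizer (the model id is a placeholder, not taken from this diff):

```python
# Sketch: rendering messages through a tokenizer's chat template.
# "some/model" is a placeholder, not a reference from this diff.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some/model")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return the rendered string
    add_generation_prompt=True,  # append the assistant turn marker
)
print(text)
```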
```diff
@@ -58,7 +58,6 @@ class ChatTemplate(str, Enum):
     falcon_h1 = "falcon_h1"
     tokenizer_default = "tokenizer_default"
     exaone = "exaone"
-    exaone4 = "exaone4"
     metharme = "metharme"
     pixtral = "pixtral"
     llava = "llava"
```