Feat: add kimi linear support (#3257)
* feat: add custom kimi linear patch [skip ci]
* feat: add configuration file and fix import [skip ci]
* fix: hijack tokenizer temporarily [skip ci]
* chore: remove accidental commit
* fix: attempt patch kimi remote
* fix: kwargs passed
* fix: device for tensor
* fix: aux loss calculation
* feat: cleaned up patches order
* fix: remove duplicate tokenizer patch
* chore: add debug logs
* chore: add debug logs
* chore: debug
* Revert "chore: add debug logs" This reverts commit da372a5f67.
* Revert "chore: add debug logs" This reverts commit 97d1de1d7c.
* fix: KeyError: 'tokenization_kimi'
* fix: support remote_model_id in cce patch
* feat: add config preload patch
* fix: use standard aux loss calc and updated modeling
* fix: import
* feat: add kimi-linear docs and example
* chore: add note about moe kernels
* feat: update cce to include kimi-linear
* chore: lint
* chore: update main readme
* fix: patch mechanism to address comments
* chore: lint
* fix: tests
* chore: cleanup comment
@@ -29,7 +29,7 @@
|
|||||||
|
|
||||||
## 🎉 Latest Updates
|
## 🎉 Latest Updates
|
||||||
|
|
||||||
- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
|
- 2025/12: Axolotl now includes support for [Kimi-Linear](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/kimi-linear), [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
|
||||||
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
|
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
|
||||||
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
|
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
|
||||||
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
|
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@242b245\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
47
examples/kimi-linear/README.md
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Finetune MoonshotAI's Kimi Linear with Axolotl
|
||||||
|
|
||||||
|
[Kimi Linear](https://huggingface.co/collections/moonshotai/kimi-linear-a3b) is a MoE model (48B total, 3B active) by MoonshotAI using a hybrid linear attention architecture to achieve a 1M token context length. It uses Kimi Delta Attention (KDA), a refined version of Gated DeltaNet that reduces KV cache size by up to 75% and boosts decoding throughput by up to 6x for long contexts.
|
||||||
|
|
||||||
|
This guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.
|
||||||
|
|
||||||
|
**Note:** Axolotl uses experimental training code for Kimi Linear because the original modeling code is inference-only.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
|
2. Install Cut Cross Entropy (CCE) by following the [docs](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy).
|
||||||
|
|
||||||
|
3. Run the finetuning example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl train examples/kimi-linear/kimi-48b-lora.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 98.7 GiB of VRAM.
|
||||||
|
|
||||||
|
Let us know how it goes. Happy finetuning!
|
||||||
|
|
||||||
|
### TIPS
|
||||||
|
|
||||||
|
- Kimi Linear requires `trust_remote_code: true`.
|
||||||
|
- You can run a full fine-tune by removing `adapter: lora` and `load_in_8bit: true` from the config.
|
||||||
|
- Read more on how to load your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
- The dataset format follows the OpenAI Messages format, as described [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
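
For the dataset tip, a minimal sketch of what one training record might look like may help. The key names `messages`, `role`, and `content` follow the usual OpenAI Messages convention and are an assumption here; the linked docs are the authority on the exact fields Axolotl reads. The conversation text and the `to_jsonl` helper are made up for illustration.

```python
import json

# Hypothetical two-turn conversation in the OpenAI Messages layout.
record = {
    "messages": [
        {"role": "user", "content": "What is linear attention?"},
        {
            "role": "assistant",
            "content": "An attention variant whose cost grows linearly with sequence length.",
        },
    ]
}


def to_jsonl(records):
    """Serialize records as JSON Lines: one JSON object per line."""
    return "\n".join(json.dumps(r, ensure_ascii=False) for r in records)


jsonl_text = to_jsonl([record])
print(jsonl_text)
```

Writing `jsonl_text` to a file gives a local dataset whose `path` can replace the example dataset in the config, assuming the loader is pointed at a JSON Lines file.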
|
||||||
|
|
||||||
|
## Optimization Guides
|
||||||
|
|
||||||
|
See 👉 [docs](https://docs.axolotl.ai/docs/optimizations.html).
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
Kimi Linear support in Axolotl is not yet compatible with the MoE kernels from transformers v5.
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [Kimi Linear Paper](https://huggingface.co/papers/2510.26692)
|
||||||
|
- [Kimi Linear GitHub](https://github.com/MoonshotAI/Kimi-Linear)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl Website](https://axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
81
examples/kimi-linear/kimi-48b-lora.yaml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
base_model: moonshotai/Kimi-Linear-48B-A3B-Instruct
|
||||||
|
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
split: train
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.2
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_target_modules:
|
||||||
|
- gate_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 2
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
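
As a rough aid for reading the example config, the sketch below works out the batch arithmetic it implies. The values are copied from the YAML above; `num_gpus = 1` is an assumption (the config does not state a device count), and the `alpha / r` factor is the standard LoRA scaling, not something specific to this commit.

```python
# Values copied from the example config.
micro_batch_size = 2
gradient_accumulation_steps = 2
sequence_len = 2048
lora_r = 16
lora_alpha = 32

# Assumption: a single GPU; multiply by the number of data-parallel ranks otherwise.
num_gpus = 1

# Sequences consumed per optimizer step.
sequences_per_step = micro_batch_size * gradient_accumulation_steps * num_gpus

# With sample_packing and pad_to_sequence_len enabled, each sequence is
# packed or padded to sequence_len, so this is the token budget per step.
tokens_per_step = sequences_per_step * sequence_len

# Standard LoRA scaling factor applied to the low-rank update.
lora_scaling = lora_alpha / lora_r

print(sequences_per_step, tokens_per_step, lora_scaling)
```

Under these assumptions each optimizer step sees 4 packed sequences, or 8192 tokens, and the LoRA update is scaled by 2.0.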
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@242b245"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@242b245"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -54,6 +54,7 @@ plugins:
|
|||||||
- granitemoehybrid
|
- granitemoehybrid
|
||||||
- hunyuan_v1_dense
|
- hunyuan_v1_dense
|
||||||
- hunyuan_v1_moe
|
- hunyuan_v1_moe
|
||||||
|
- kimi_linear
|
||||||
- lfm2
|
- lfm2
|
||||||
- lfm2_moe
|
- lfm2_moe
|
||||||
- lfm2_vl
|
- lfm2_vl
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@242b245"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -96,7 +96,11 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# The patch checks model_type internally
|
# The patch checks model_type internally
|
||||||
cce_patch(cfg.model_config_type)
|
|
||||||
|
cce_patch(
|
||||||
|
cfg.model_config_type,
|
||||||
|
remote_model_id=cfg.base_model if cfg.trust_remote_code else None,
|
||||||
|
)
|
||||||
|
|
||||||
def patch_llama_like(
|
def patch_llama_like(
|
||||||
self,
|
self,
|
||||||
@@ -107,7 +111,9 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
"""
|
"""
|
||||||
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
||||||
|
|
||||||
def patch_generic(maybe_model, patch_options, model_type: str):
|
def patch_generic(
|
||||||
|
maybe_model, patch_options, model_type: str, remote_model_id: str | None
|
||||||
|
):
|
||||||
import cut_cross_entropy.transformers.llama
|
import cut_cross_entropy.transformers.llama
|
||||||
from cut_cross_entropy.transformers.llama import cce_forward
|
from cut_cross_entropy.transformers.llama import cce_forward
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,48 @@ PLUGIN_MANAGER = PluginManager.get_instance()
|
|||||||
class PatchManager:
|
class PatchManager:
|
||||||
"""Manages the application of patches during the model loading process."""
|
"""Manages the application of patches during the model loading process."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply_pre_config_load_patches(cfg: DictDefault):
|
||||||
|
"""
|
||||||
|
Apply patches that must be set up before config loading.
|
||||||
|
This is for patches that intercept remote code loading from HuggingFace,
|
||||||
|
which needs to be in place before AutoConfig.from_pretrained() is called.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Configuration dictionary with model and training settings.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
hasattr(cfg, "base_model_config")
|
||||||
|
and cfg.base_model_config
|
||||||
|
and "kimi-linear" in cfg.base_model_config.lower()
|
||||||
|
):
|
||||||
|
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||||
|
patch_kimi_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_kimi_config()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply_pre_tokenizer_load_patches(cfg: DictDefault):
|
||||||
|
"""
|
||||||
|
Apply patches that must be set up before tokenizer loading.
|
||||||
|
This is for patches that intercept remote code loading from HuggingFace,
|
||||||
|
which needs to be in place before AutoTokenizer.from_pretrained() is called.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Configuration dictionary with model and training settings.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
hasattr(cfg, "tokenizer_config")
|
||||||
|
and cfg.tokenizer_config
|
||||||
|
and "kimi-linear" in cfg.tokenizer_config.lower()
|
||||||
|
):
|
||||||
|
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||||
|
patch_kimi_tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_kimi_tokenizer()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
@@ -190,6 +232,13 @@ class PatchManager:
|
|||||||
|
|
||||||
apply_mistral_tokenizer_image_patch()
|
apply_mistral_tokenizer_image_patch()
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "kimi_linear":
|
||||||
|
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||||
|
patch_kimi_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_kimi_model()
|
||||||
|
|
||||||
def _apply_fp8_patches(self):
|
def _apply_fp8_patches(self):
|
||||||
"""Apply patches for FP8 support."""
|
"""Apply patches for FP8 support."""
|
||||||
if self.cfg.fp8:
|
if self.cfg.fp8:
|
||||||
|
|||||||
@@ -124,6 +124,11 @@ def modify_tokenizer_files(
|
|||||||
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||||
"""Load and configure the tokenizer based on the provided config."""
|
"""Load and configure the tokenizer based on the provided config."""
|
||||||
|
|
||||||
|
# Apply patches that need to be in place before tokenizer loading
|
||||||
|
from axolotl.loaders.patch_manager import PatchManager
|
||||||
|
|
||||||
|
PatchManager.apply_pre_tokenizer_load_patches(cfg)
|
||||||
|
|
||||||
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
||||||
"""Load mistral-common tokenizer"""
|
"""Load mistral-common tokenizer"""
|
||||||
from axolotl.utils.mistral import HFMistralTokenizer
|
from axolotl.utils.mistral import HFMistralTokenizer
|
||||||
|
|||||||
148
src/axolotl/monkeypatch/models/kimi_linear/configuration_kimi.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""
|
||||||
|
Kimi-Linear configuration.
|
||||||
|
|
||||||
|
Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/configuration_kimi.py
|
||||||
|
Revision: 6e163f3
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class KimiLinearConfig(PretrainedConfig):
|
||||||
|
model_type = "kimi_linear"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_type="kimi_linear",
|
||||||
|
vocab_size=163840,
|
||||||
|
hidden_size=4096,
|
||||||
|
head_dim=None,
|
||||||
|
intermediate_size=11008,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=None,
|
||||||
|
hidden_act="silu",
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
rope_scaling=None,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
moe_intermediate_size: Optional[int] = None,
|
||||||
|
moe_renormalize: bool = True,
|
||||||
|
moe_router_activation_func: str = "sigmoid",
|
||||||
|
num_experts: Optional[int] = None,
|
||||||
|
num_experts_per_token: Optional[int] = None,
|
||||||
|
num_shared_experts: int = 0,
|
||||||
|
routed_scaling_factor: float = 1.0,
|
||||||
|
first_k_dense_replace: int = 0,
|
||||||
|
moe_layer_freq: int = 1,
|
||||||
|
use_grouped_topk: bool = True,
|
||||||
|
num_expert_group: int = 1,
|
||||||
|
topk_group: int = 1,
|
||||||
|
q_lora_rank: Optional[int] = None,
|
||||||
|
kv_lora_rank: Optional[int] = None,
|
||||||
|
qk_nope_head_dim: Optional[int] = None,
|
||||||
|
qk_rope_head_dim: Optional[int] = None,
|
||||||
|
v_head_dim: Optional[int] = None,
|
||||||
|
mla_use_nope: Optional[bool] = False,
|
||||||
|
num_nextn_predict_layers: int = 0,
|
||||||
|
linear_attn_config: Optional[dict] = None,
|
||||||
|
router_aux_loss_coef: float = 0.01,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.model_type = model_type
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.head_dim = (
|
||||||
|
head_dim if head_dim is not None else hidden_size // num_attention_heads
|
||||||
|
)
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
|
||||||
|
self.q_lora_rank = q_lora_rank
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.v_head_dim = v_head_dim
|
||||||
|
self.mla_use_nope = mla_use_nope
|
||||||
|
# moe config
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.num_experts_per_token = num_experts_per_token
|
||||||
|
self.moe_renormalize = moe_renormalize
|
||||||
|
self.num_shared_experts = num_shared_experts
|
||||||
|
self.routed_scaling_factor = routed_scaling_factor
|
||||||
|
self.moe_router_activation_func = moe_router_activation_func
|
||||||
|
assert self.moe_router_activation_func in ("softmax", "sigmoid")
|
||||||
|
self.moe_intermediate_size = moe_intermediate_size
|
||||||
|
self.first_k_dense_replace = first_k_dense_replace
|
||||||
|
self.moe_layer_freq = moe_layer_freq
|
||||||
|
self.use_grouped_topk = use_grouped_topk
|
||||||
|
self.num_expert_group = num_expert_group
|
||||||
|
self.topk_group = topk_group
|
||||||
|
self.num_nextn_predict_layers = num_nextn_predict_layers
|
||||||
|
self.router_aux_loss_coef = router_aux_loss_coef
|
||||||
|
|
||||||
|
if linear_attn_config is not None:
|
||||||
|
assert linear_attn_config["kda_layers"] is not None
|
||||||
|
assert linear_attn_config["full_attn_layers"] is not None
|
||||||
|
self.linear_attn_config = linear_attn_config
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_mla(self):
|
||||||
|
return (
|
||||||
|
self.q_lora_rank is not None
|
||||||
|
or self.kv_lora_rank is not None
|
||||||
|
or self.qk_nope_head_dim is not None
|
||||||
|
or self.qk_rope_head_dim is not None
|
||||||
|
or self.v_head_dim is not None
|
||||||
|
or self.mla_use_nope is True
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_moe(self):
|
||||||
|
return self.num_experts is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_linear_attn(self) -> bool:
|
||||||
|
return not (
|
||||||
|
self.linear_attn_config is None
|
||||||
|
or (
|
||||||
|
isinstance(self.linear_attn_config, dict)
|
||||||
|
and self.linear_attn_config["kda_layers"] is not None
|
||||||
|
and len(self.linear_attn_config["kda_layers"]) == 0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_kda_layer(self, layer_idx: int):
|
||||||
|
return (
|
||||||
|
self.linear_attn_config is not None
|
||||||
|
and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
|
||||||
|
)
|
||||||
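
One detail of `is_kda_layer` that is easy to miss: `layer_idx` is 0-based, while the entries of `linear_attn_config["kda_layers"]` are compared against `layer_idx + 1`, so they are 1-based. The sketch below re-implements the check as a free function to illustrate this; the four-layer `toy_config` is hypothetical and not taken from the released model config.

```python
from typing import Optional


def is_kda_layer(linear_attn_config: Optional[dict], layer_idx: int) -> bool:
    """Mirror of the config method: 0-based layer_idx, 1-based kda_layers entries."""
    return (
        linear_attn_config is not None
        and (layer_idx + 1) in linear_attn_config["kda_layers"]
    )


# Hypothetical 4-layer layout: three KDA layers followed by one full-attention layer.
toy_config = {"kda_layers": [1, 2, 3], "full_attn_layers": [4]}

layer_kinds = ["kda" if is_kda_layer(toy_config, i) else "full" for i in range(4)]
print(layer_kinds)
```

With this layout, layer indices 0 through 2 are KDA layers and index 3 uses full attention; a config with no `linear_attn_config` reports every layer as non-KDA.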
1361
src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,85 @@
|
|||||||
|
import importlib.resources
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
|
||||||
|
|
||||||
|
|
||||||
|
def get_patch_file_path(package_dot_path: str, filename: str) -> Path | None:
|
||||||
|
"""
|
||||||
|
Gets the absolute path to a patch file using importlib.resources.files.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return importlib.resources.files(package_dot_path) / filename
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _load_local_module(module_name: str, filename: str):
|
||||||
|
"""Helper to load a local module if not already loaded."""
|
||||||
|
if module_name in sys.modules:
|
||||||
|
return sys.modules[module_name]
|
||||||
|
|
||||||
|
patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename)
|
||||||
|
if patch_path and patch_path.exists():
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, patch_path)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules[module_name] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_get_class_in_module():
|
||||||
|
"""
|
||||||
|
Core patch function that hijacks Transformers' dynamic module loading.
|
||||||
|
"""
|
||||||
|
from transformers.dynamic_module_utils import get_class_in_module
|
||||||
|
|
||||||
|
if hasattr(get_class_in_module, "_axolotl_patched"):
|
||||||
|
return
|
||||||
|
|
||||||
|
original_get_class_in_module = get_class_in_module
|
||||||
|
|
||||||
|
# Mapping of module path patterns to (module_name, filename)
|
||||||
|
KIMI_MODULE_MAP = {
|
||||||
|
"configuration_kimi": ("configuration_kimi", "configuration_kimi.py"),
|
||||||
|
"modeling_kimi": ("modeling_kimi", "modeling_kimi.py"),
|
||||||
|
"tokenization_kimi": ("tokenization_kimi", "tokenization_kimi.py"),
|
||||||
|
}
|
||||||
|
|
||||||
|
def patched_get_class_in_module(class_name, module_path, **kwargs):
|
||||||
|
"""Patched version that returns our local modules instead of remote ones."""
|
||||||
|
for pattern, (module_name, filename) in KIMI_MODULE_MAP.items():
|
||||||
|
if pattern in module_path:
|
||||||
|
module = _load_local_module(module_name, filename)
|
||||||
|
if module:
|
||||||
|
return getattr(module, class_name)
|
||||||
|
break # Pattern matched but file not found, fall through
|
||||||
|
|
||||||
|
return original_get_class_in_module(class_name, module_path, **kwargs)
|
||||||
|
|
||||||
|
import transformers.dynamic_module_utils
|
||||||
|
|
||||||
|
transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module
|
||||||
|
patched_get_class_in_module._axolotl_patched = True
|
||||||
|
|
||||||
|
|
||||||
|
def patch_kimi():
|
||||||
|
"""
|
||||||
|
Apply all Kimi patches.
|
||||||
|
Must be called BEFORE loading config/tokenizer/model.
|
||||||
|
"""
|
||||||
|
_patch_get_class_in_module()
|
||||||
|
LOG.info("Kimi patches applied successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
# Keep these for backward compatibility if needed
|
||||||
|
patch_kimi_config = patch_kimi
|
||||||
|
patch_kimi_tokenizer = patch_kimi
|
||||||
|
patch_kimi_model = patch_kimi
|
||||||
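
The patch above follows a common monkeypatching pattern: wrap a library function, redirect calls whose module path matches a known pattern, fall back to the original otherwise, and mark the wrapper so the patch is applied only once. The sketch below shows that pattern in isolation. Everything in it (`fake_lib`, `lookup`, `LOCAL_OVERRIDES`, the `_patched` marker) is a stand-in invented for illustration; it does not call Transformers.

```python
import types

# Stand-in for a library module that exposes a lookup function (hypothetical).
fake_lib = types.SimpleNamespace()


def original_lookup(class_name, module_path):
    return f"remote:{module_path}:{class_name}"


fake_lib.lookup = original_lookup

# Module-path pattern -> label of the local replacement.
LOCAL_OVERRIDES = {"modeling_kimi": "local"}


def install_patch(lib):
    """Wrap lib.lookup once; later calls are no-ops thanks to the marker attribute."""
    if getattr(lib.lookup, "_patched", False):
        return
    original = lib.lookup

    def patched(class_name, module_path):
        for pattern, source in LOCAL_OVERRIDES.items():
            if pattern in module_path:
                return f"{source}:{pattern}:{class_name}"
        # No pattern matched: defer to the original implementation.
        return original(class_name, module_path)

    patched._patched = True
    lib.lookup = patched


install_patch(fake_lib)
install_patch(fake_lib)  # second call does nothing

overridden = fake_lib.lookup("SomeModel", "transformers_modules.x.modeling_kimi")
passthrough = fake_lib.lookup("Other", "some.other.module")
print(overridden, passthrough)
```

Capturing `original` inside `install_patch` keeps the fallback pointing at the unpatched function, which is why the idempotency guard matters: without it, a second install would wrap the wrapper.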
357
src/axolotl/monkeypatch/models/kimi_linear/tokenization_kimi.py
Normal file
@@ -0,0 +1,357 @@
|
|||||||
|
"""
|
||||||
|
Adapted Kimi-Linear tokenizer to use proper template defaults and misc fixes.
|
||||||
|
|
||||||
|
Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/tokenization_kimi.py
|
||||||
|
Revision: 919416f
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from logging import getLogger
|
||||||
|
from pathlib import Path
|
||||||
|
from shutil import copyfile
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
from tiktoken.load import load_tiktoken_bpe
|
||||||
|
from tokenizers import AddedToken
|
||||||
|
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||||
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
|
||||||
|
|
||||||
|
|
||||||
|
class TikTokenTokenizer(PreTrainedTokenizer):
|
||||||
|
"""
|
||||||
|
Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.
|
||||||
|
|
||||||
|
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
||||||
|
this superclass for more information regarding those methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_file (`str`):
|
||||||
|
The path to the Tiktoken model file.
|
||||||
|
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[BOS]"`):
|
||||||
|
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
|
||||||
|
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[EOS]"`):
|
||||||
|
The end of sequence token.
|
||||||
|
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `None`):
|
||||||
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
|
token instead. The second to last item in special_tokens.
|
||||||
|
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `None`):
|
||||||
|
The token used for padding, for example when batching sequences of different lengths.
|
||||||
|
additional_special_tokens (list of `str`, *optional*):
|
||||||
|
A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
|
||||||
|
skipped when decoding if `skip_special_tokens` is set to `True`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
|
||||||
|
model_input_names = ["input_ids", "attention_mask"]
|
||||||
|
|
||||||
|
special_tokens: Dict[str, int]
|
||||||
|
|
||||||
|
num_reserved_special_tokens = 256
|
||||||
|
|
||||||
|
pat_str = "|".join(
|
||||||
|
[
|
||||||
|
r"""[\p{Han}]+""",
|
||||||
|
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
|
||||||
|
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
|
||||||
|
r"""\p{N}{1,3}""",
|
||||||
|
r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
|
||||||
|
r"""\s*[\r\n]+""",
|
||||||
|
r"""\s+(?!\S)""",
|
||||||
|
r"""\s+""",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file,
|
||||||
|
bos_token: Union[str, AddedToken] = "[BOS]", # nosec: B107
|
||||||
|
eos_token: Union[str, AddedToken] = "[EOS]", # nosec: B107
|
||||||
|
unk_token: Union[str, AddedToken, None] = None,
|
||||||
|
pad_token: Union[str, AddedToken, None] = None,
|
||||||
|
additional_special_tokens: Optional[List[str]] = None,
|
||||||
|
        added_tokens_decoder: Optional[dict] = None,
        **kwargs,
    ):
        assert os.path.isfile(vocab_file), vocab_file

        if additional_special_tokens is None:
            additional_special_tokens = [
                "<|im_end|>",
                "<|im_user|>",
                "<|im_assistant|>",
                "<|start_header_id|>",
                "<|end_header_id|>",
                "[EOT]",
                "<|im_system|>",
                "<|im_middle|>",
            ]

        special_tokens_mapping = {
            i: added_tokens_decoder[i].content for i in added_tokens_decoder
        }

        self.vocab_file = vocab_file
        mergeable_ranks = load_tiktoken_bpe(vocab_file)
        num_base_tokens = len(mergeable_ranks)
        self.special_tokens = {
            special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
            for i in range(
                num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2
            )
        }

        self.model = tiktoken.Encoding(
            name=Path(vocab_file).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {vocab_file}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens[str(bos_token)]
        self.eos_id: int = self.special_tokens[str(eos_token)]
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

        self.pad_id: int = self.special_tokens[str(pad_token)]
        self.unk_id: int = self.special_tokens[str(unk_token)]

        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        self.decoder = {}
        for i in range(self.n_words):
            # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
            decoding = "".join(
                [
                    self.byte_encoder[ord(char)]
                    for char in self.model.decode_single_token_bytes(i).decode(
                        "latin-1"
                    )
                ]
            )
            self.decoder[i] = decoding

        self.encoder = {}
        for i in range(self.n_words):
            if i in self.decoder:
                self.encoder[self.decoder[i]] = i

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.all_special_ids_set = set(self.all_special_ids)

    def encode(
        self, text: str, allow_special_tokens: bool = True, **kwargs
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            text (str): The input string to be encoded.
            allow_special_tokens (bool): If True, special-token strings found in
                `text` are encoded as their special token IDs; if False, they
                are encoded as ordinary text.
            **kwargs: If any are given, encoding is delegated to `super().encode`.

        Returns:
            list[int]: A list of token IDs.
        """
        # If there are other args, we call super().encode because there is a lot of code
        # to handle those args. super().encode will eventually call _tokenize and
        # _convert_token_to_id.
        # NOTE: our encode method is not compatible with the super().encode method,
        # e.g. split_special_tokens' default is True in our encode method.
        if len(kwargs) > 0:
            # logger.warning(f"Calling super().encode with {kwargs}")
            return super().encode(text, **kwargs)

        assert type(text) is str

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        texts = self.pre_tokenizer_process(text)

        all_substrs = []
        for text in texts:
            substrs = (
                substr
                for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
                for substr in self._split_whitespaces_or_nonwhitespaces(
                    text[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
                )
            )
            all_substrs.extend(substrs)

        t: List[int] = []
        for substr in all_substrs:
            if allow_special_tokens:
                t.extend(
                    # special-token strings are encoded as their special token IDs
                    self.model.encode(
                        substr,
                        allowed_special="all",
                    )
                )
            else:
                t.extend(
                    # we should consider special token as a common token
                    self.model.encode(
                        substr,
                        disallowed_special=(),
                    )
                )

        return t

    def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
        """
        Decodes a token ID or a list of token IDs into a string.

        Args:
            token_ids (Union[int, List[int]]): The token ID(s) to be decoded.

        Returns:
            str: The decoded string.
        """
        # If there are other args, we call super().decode because there is a lot of code
        # to handle those args. super().decode will eventually call
        # convert_tokens_to_string and _convert_id_to_token.
        if len(kwargs) > 0:
            return super().decode(token_ids, **kwargs)

        if type(token_ids) is int:
            token_ids = [token_ids]

        return self.model.decode(cast(List[int], token_ids))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]

    def pre_tokenizer_process(self, text: str) -> List[str]:
        """
        Pre-tokenizes the input text into a list of text chunks.
        This method is used to split the input text into smaller chunks for internal processing.
        The default implementation returns the whole text as a single chunk.
        """
        return [text]

    """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """

    @property
    def vocab_size(self) -> int:
        return self.n_words

    def get_vocab(self) -> Dict[str, int]:
        return self.encoder

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        return [self.decoder[t] for t in self.encode(text)]

    def _convert_token_to_id(self, token: str) -> int:
        return self.encoder.get(token, self.unk_id)

    def _convert_id_to_token(self, index: int) -> str:
        return self.decoder.get(index)

    @staticmethod
    def clean_up_tokenization(out_string: str) -> str:
        return out_string

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        text = "".join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            "utf-8", "replace"
        )
        return text

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            raise ValueError(
                f"vocabulary path ({save_directory}) should be a directory"
            )
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)

    def apply_chat_template(
        self,
        conversation,
        tools: Optional[list[dict]] = None,
        tokenize: bool = True,
        add_generation_prompt: bool = False,
        **kwargs,
    ):
        tools = deep_sort_dict(tools)
        return super().apply_chat_template(
            conversation,
            tools=tools,
            tokenize=tokenize,
            add_generation_prompt=add_generation_prompt,
            **kwargs,
        )


def deep_sort_dict(obj: Any) -> Any:
    if isinstance(obj, dict):
        return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
    if isinstance(obj, list):
        return [deep_sort_dict(item) for item in obj]
    return obj

@@ -151,6 +151,11 @@ def normalize_config(cfg):
     if not cfg.base_model_config:
         cfg.base_model_config = cfg.base_model

+    # Apply pre-config load patches (e.g., for Kimi Linear remote code patching)
+    from axolotl.loaders.patch_manager import PatchManager
+
+    PatchManager.apply_pre_config_load_patches(cfg)
+
     model_config = load_model_config(cfg)

     cfg.tokenizer_config = (