diff --git a/README.md b/README.md
index 285867215..13a5a9243 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@

 ## 🎉 Latest Updates

-- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
+- 2025/12: Axolotl now includes support for [Kimi-Linear](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/kimi-linear), [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
 - 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
 - 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
 - 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb
index 77a4154e2..133c3db79 100644
--- a/examples/colab-notebooks/colab-axolotl-example.ipynb
+++ b/examples/colab-notebooks/colab-axolotl-example.ipynb
@@ -40,7 +40,7 @@
     "%%capture\n",
     "# This step can take ~5-10 minutes to install dependencies\n",
     "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\""
+    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@242b245\""
    ]
   },
   {
diff --git a/examples/kimi-linear/README.md b/examples/kimi-linear/README.md
new file mode 100644
index 000000000..250f8bea1
--- /dev/null
+++ b/examples/kimi-linear/README.md
@@ -0,0 +1,47 @@
+# Finetune MoonshotAI's Kimi Linear with Axolotl
+
+[Kimi Linear](https://huggingface.co/collections/moonshotai/kimi-linear-a3b) is a MoE model (48B total, 3B active) by MoonshotAI using a hybrid linear attention architecture to achieve a 1M token context length. It uses Kimi Delta Attention (KDA), a refined version of Gated DeltaNet that reduces KV cache size by up to 75% and boosts decoding throughput by up to 6x for long contexts.
+
+This guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.
+
+**Note:** Axolotl uses experimental training code for Kimi Linear, as the upstream modeling code is inference-only.
+
+## Getting started
+
+1. 
Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
+
+2. Install CCE following the [docs](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy).
+
+3. Run the finetuning example:
+
+    ```bash
+    axolotl train examples/kimi-linear/kimi-48b-lora.yaml
+    ```
+
+This config uses about 98.7 GiB of VRAM.
+
+Let us know how it goes. Happy finetuning!
+
+### TIPS
+
+- Kimi Linear requires `trust_remote_code: true`.
+- You can run a full finetune by removing the `adapter: lora` and `load_in_8bit: true` lines from the config.
+- Read more on loading your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
+- The dataset format follows the OpenAI Messages format, as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
+
+## Optimization Guides
+
+See 👉 [docs](https://docs.axolotl.ai/docs/optimizations.html).
+
+## Limitations
+
+This is not yet compatible with the MoE kernels from transformers v5.
+
+## Related Resources
+
+- [Kimi Linear Paper](https://huggingface.co/papers/2510.26692)
+- [Kimi Linear GitHub](https://github.com/MoonshotAI/Kimi-Linear)
+- [Axolotl Docs](https://docs.axolotl.ai)
+- [Axolotl Website](https://axolotl.ai)
+- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
+- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
diff --git a/examples/kimi-linear/kimi-48b-lora.yaml b/examples/kimi-linear/kimi-48b-lora.yaml
new file mode 100644
index 000000000..8e855dd72
--- /dev/null
+++ b/examples/kimi-linear/kimi-48b-lora.yaml
@@ -0,0 +1,81 @@
+base_model: moonshotai/Kimi-Linear-48B-A3B-Instruct
+
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+trust_remote_code: true
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+    split: train
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.2
+output_dir: ./outputs/lora-out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+lora_r: 16
+lora_alpha: 32
+lora_dropout: 0.05
+lora_fan_in_fan_out:
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 2
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_ratio: 0.1
+evals_per_epoch: 2
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py
index ec5c6d475..e902bb0ac 100644
--- a/scripts/cutcrossentropy_install.py
+++ b/scripts/cutcrossentropy_install.py
@@ -29,5 +29,5 @@
 UV_PREFIX = "uv " if USE_UV else ""
 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@242b245"'
 )
diff --git 
a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 2c5b0f6e5..b28382542 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@242b245" ``` ## Usage @@ -54,6 +54,7 @@ plugins: - granitemoehybrid - hunyuan_v1_dense - hunyuan_v1_moe +- kimi_linear - lfm2 - lfm2_moe - lfm2_vl diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 98a1659b1..0d1588f99 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -35,7 +35,7 @@ LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( "Please install Axolotl's fork of cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`' + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@242b245"`' ) @@ -96,7 +96,11 @@ class CutCrossEntropyPlugin(BasePlugin): ) # The patch checks model_type internally - cce_patch(cfg.model_config_type) + + cce_patch( + cfg.model_config_type, + remote_model_id=cfg.base_model if cfg.trust_remote_code else None, + ) def patch_llama_like( self, @@ -107,7 +111,9 @@ class CutCrossEntropyPlugin(BasePlugin): """ from cut_cross_entropy.transformers.patch import PATCH_FNS - def patch_generic(maybe_model, patch_options, model_type: str): + def patch_generic( + maybe_model, patch_options, model_type: str, remote_model_id: str | None + ): import cut_cross_entropy.transformers.llama from cut_cross_entropy.transformers.llama import cce_forward diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 81e4dd786..c31982262 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -26,6 +26,48 @@ PLUGIN_MANAGER = PluginManager.get_instance() class PatchManager: """Manages the application of patches during the model loading process.""" + @staticmethod + def apply_pre_config_load_patches(cfg: DictDefault): + """ + Apply patches that must be set up before config loading. + This is for patches that intercept remote code loading from HuggingFace, + which needs to be in place before AutoConfig.from_pretrained() is called. + + Args: + cfg: Configuration dictionary with model and training settings. + """ + if ( + hasattr(cfg, "base_model_config") + and cfg.base_model_config + and "kimi-linear" in cfg.base_model_config.lower() + ): + from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import ( + patch_kimi_config, + ) + + patch_kimi_config() + + @staticmethod + def apply_pre_tokenizer_load_patches(cfg: DictDefault): + """ + Apply patches that must be set up before tokenizer loading. + This is for patches that intercept remote code loading from HuggingFace, + which needs to be in place before AutoTokenizer.from_pretrained() is called. + + Args: + cfg: Configuration dictionary with model and training settings. 
+ """ + if ( + hasattr(cfg, "tokenizer_config") + and cfg.tokenizer_config + and "kimi-linear" in cfg.tokenizer_config.lower() + ): + from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import ( + patch_kimi_tokenizer, + ) + + patch_kimi_tokenizer() + def __init__( self, cfg: DictDefault, @@ -190,6 +232,13 @@ class PatchManager: apply_mistral_tokenizer_image_patch() + if self.cfg.model_config_type == "kimi_linear": + from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import ( + patch_kimi_model, + ) + + patch_kimi_model() + def _apply_fp8_patches(self): """Apply patches for FP8 support.""" if self.cfg.fp8: diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index 48856116c..170ebf333 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -124,6 +124,11 @@ def modify_tokenizer_files( def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: """Load and configure the tokenizer based on the provided config.""" + # Apply patches that need to be in place before tokenizer loading + from axolotl.loaders.patch_manager import PatchManager + + PatchManager.apply_pre_tokenizer_load_patches(cfg) + def _load_mistral_common_tokenizer(cfg: DictDefault): """Load mistral-common tokenizer""" from axolotl.utils.mistral import HFMistralTokenizer diff --git a/src/axolotl/monkeypatch/models/kimi_linear/__init__.py b/src/axolotl/monkeypatch/models/kimi_linear/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/kimi_linear/configuration_kimi.py b/src/axolotl/monkeypatch/models/kimi_linear/configuration_kimi.py new file mode 100644 index 000000000..1dd0e6702 --- /dev/null +++ b/src/axolotl/monkeypatch/models/kimi_linear/configuration_kimi.py @@ -0,0 +1,148 @@ +""" +Kimi-Linear configuration. 
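+Vendored so Axolotl's Kimi Linear monkeypatch can substitute it for the
+dynamically loaded remote-code module.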
+ +Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/configuration_kimi.py +Revision: 6e163f3 +""" + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class KimiLinearConfig(PretrainedConfig): + model_type = "kimi_linear" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + model_type="kimi_linear", + vocab_size=163840, + hidden_size=4096, + head_dim=None, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + rope_theta=10000.0, + rope_scaling=None, + tie_word_embeddings=False, + moe_intermediate_size: Optional[int] = None, + moe_renormalize: bool = True, + moe_router_activation_func: str = "sigmoid", + num_experts: Optional[int] = None, + num_experts_per_token: Optional[int] = None, + num_shared_experts: int = 0, + routed_scaling_factor: float = 1.0, + first_k_dense_replace: int = 0, + moe_layer_freq: int = 1, + use_grouped_topk: bool = True, + num_expert_group: int = 1, + topk_group: int = 1, + q_lora_rank: Optional[int] = None, + kv_lora_rank: Optional[int] = None, + qk_nope_head_dim: Optional[int] = None, + qk_rope_head_dim: Optional[int] = None, + v_head_dim: Optional[int] = None, + mla_use_nope: Optional[bool] = False, + num_nextn_predict_layers: int = 0, + linear_attn_config: Optional[dict] = None, + router_aux_loss_coef: float = 0.01, + **kwargs, + ): + self.model_type = model_type + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.head_dim = ( + head_dim if head_dim is not None else hidden_size // num_attention_heads + ) + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.mla_use_nope = mla_use_nope + # moe config + self.num_experts = num_experts + self.num_experts_per_token = num_experts_per_token + self.moe_renormalize = moe_renormalize + self.num_shared_experts = num_shared_experts + self.routed_scaling_factor = routed_scaling_factor + self.moe_router_activation_func = moe_router_activation_func + assert self.moe_router_activation_func in ("softmax", "sigmoid") + self.moe_intermediate_size = moe_intermediate_size + self.first_k_dense_replace = first_k_dense_replace + self.moe_layer_freq = moe_layer_freq + self.use_grouped_topk = use_grouped_topk + self.num_expert_group = num_expert_group + self.topk_group = topk_group + self.num_nextn_predict_layers = num_nextn_predict_layers + self.router_aux_loss_coef = router_aux_loss_coef + + if linear_attn_config is not None: + assert linear_attn_config["kda_layers"] is not None + assert linear_attn_config["full_attn_layers"] is not None + self.linear_attn_config = linear_attn_config + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + 
eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+    @property
+    def is_mla(self):
+        return (
+            self.q_lora_rank is not None
+            or self.kv_lora_rank is not None
+            or self.qk_nope_head_dim is not None
+            or self.qk_rope_head_dim is not None
+            or self.v_head_dim is not None
+            or self.mla_use_nope is True
+        )
+
+    @property
+    def is_moe(self):
+        return self.num_experts is not None
+
+    @property
+    def is_linear_attn(self) -> bool:
+        return not (
+            self.linear_attn_config is None
+            or (
+                isinstance(self.linear_attn_config, dict)
+                and self.linear_attn_config["kda_layers"] is not None
+                and len(self.linear_attn_config["kda_layers"]) == 0
+            )
+        )
+
+    def is_kda_layer(self, layer_idx: int):
+        return (
+            self.linear_attn_config is not None
+            and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
+        )
diff --git a/src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py b/src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py
new file mode 100644
index 000000000..42a11ec36
--- /dev/null
+++ b/src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py
@@ -0,0 +1,1361 @@
+"""
+Adapted Kimi-Linear modeling to make the MoE block differentiable for training.
+
+Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/modeling_kimi.py
+Revision: 6e163f3
+"""
+
+import math
+from collections.abc import Callable
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import transformers
+from einops import rearrange
+from packaging import version
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache
+from transformers.generation import GenerationMixin
+from transformers.masking_utils import create_causal_mask
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    MoeCausalLMOutputWithPast,
+)
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+    TransformersKwargs,
+    can_return_tuple,
+    logging,
+)
+from transformers.utils.generic import OutputRecorder
+
+try:
+    from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
+    from fla.modules import FusedRMSNormGated, ShortConvolution
+    from fla.ops.kda import chunk_kda, fused_recurrent_kda
+    from fla.ops.kda.gate import fused_kda_gate
+except ImportError as err:
+    raise ImportError(
+        "Please run `pip uninstall fla-core flash-linear-attention -y && pip install git+https://github.com/fla-org/flash-linear-attention@v0.4.0`"
+    ) from err
+
+from axolotl.monkeypatch.models.kimi_linear.configuration_kimi import KimiLinearConfig
+
+assert version.parse(transformers.__version__) >= version.parse("4.56.0"), (
+    "Please upgrade transformers to >= 4.56.0"
+)
+
+logger = logging.get_logger(__name__)
+
+
+def load_balancing_loss_func(
+    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+    """Standard Switch Transformer load balancing loss."""
+    if gate_logits is None or not isinstance(gate_logits, tuple):
+        return 0
+
+    # Concatenate all layer logits
+    concatenated_gate_logits = torch.cat(
+        [layer_gate for layer_gate in gate_logits], dim=0
+    )
+
+    routing_weights = 
F.softmax(concatenated_gate_logits, dim=-1) + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + expert_mask = F.one_hot(selected_experts, num_experts) + + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + router_prob_per_expert = torch.mean(routing_weights, dim=0) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class KimiDynamicCache: + """ + Dynamic cache for Kimi model. + Inspired by Qwen3-Next + """ + + is_compileable = False + + def __init__(self, config: KimiLinearConfig): + super().__init__() + self.config = config + + if config.linear_attn_config is not None: + self.layer_types = [] + for i in range(config.num_hidden_layers): + if config.is_kda_layer(i): + self.layer_types.append("linear_attention") + else: + self.layer_types.append("full_attention") + else: + self.layer_types = ["full_attention"] * config.num_hidden_layers + + self.transformer_layers = [ + i + for i in range(config.num_hidden_layers) + if self.layer_types[i] == "full_attention" + ] + + linear_layers = [ + i + for i in range(config.num_hidden_layers) + if self.layer_types[i] == "linear_attention" + ] + self.last_linear_layer = linear_layers[-1] if linear_layers else -1 + + self.conv_states = [None for _ in range(config.num_hidden_layers)] + self.recurrent_states = [None for _ in range(config.num_hidden_layers)] + self.key_cache = [None for _ in range(config.num_hidden_layers)] + self.value_cache = [None for _ in range(config.num_hidden_layers)] + + def __len__(self): + return len(self.layer_types) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.key_cache[layer_idx] is None: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx], key_states], dim=2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=2 + ) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx] is not None: + device = self.key_cache[layer_idx].device + beam_idx = beam_idx.to(device) + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( + 0, beam_idx + ) + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select( + 0, beam_idx + ) + + if self.conv_states[layer_idx] is not None: + device = self.conv_states[layer_idx][0].device + beam_idx = beam_idx.to(device) + q_conv, k_conv, v_conv = self.conv_states[layer_idx] + self.conv_states[layer_idx] = ( + q_conv.index_select(0, beam_idx), + k_conv.index_select(0, beam_idx), + v_conv.index_select(0, beam_idx), + ) + self.recurrent_states[layer_idx] = self.recurrent_states[ + layer_idx + ].index_select(0, beam_idx) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = ( + self.transformer_layers[0] + if layer_idx not in self.transformer_layers + else layer_idx + ) + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_mask_sizes( + self, cache_position: torch.Tensor, layer_idx: int + ) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. + """ + kv_offset = 0 + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length(layer_idx) + kv_length = query_length + past_seen_tokens + return kv_length, kv_offset + + @property + def has_previous_state(self): + """We have a previous state if the last linear (conv) layer was already updated.""" + if self.last_linear_layer == -1: + return False + return self.conv_states[self.last_linear_layer] is not None + + +class KimiRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + KimiRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(KimiRMSNorm) + + +class KimiBlockSparseMLP(nn.Module): + def __init__( + self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None + ): + super().__init__() + self.config = config + self.ffn_dim = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + self.hidden_dim = config.hidden_size if hidden_size is None else hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # gate + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) # down + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) # up + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3( + hidden_states + ) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class KimiMLP(nn.Module): + def __init__( + self, config: KimiLinearConfig, hidden_size=None, intermediate_size=None + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class KimiMLAAttention(nn.Module): + """ + Multi-Latent Attention adapted from deepseek-v3 + """ + + def __init__(self, config: KimiLinearConfig, layer_idx: int): + nn.Module.__init__(self) + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + self.rope_theta = config.rope_theta + self.attention_dropout = getattr(config, "attention_dropout", 0.0) + + try: + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.use_nope = config.mla_use_nope + self.scaling = self.q_head_dim ** (-0.5) + except Exception as e: + raise ValueError( + f"Kimi MLA config is not found or not properly formatted: {e}" + ) from e + + assert self.q_lora_rank is None + self.q_proj = nn.Linear( + self.hidden_size, + self.num_heads * self.q_head_dim, + bias=False, + ) + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + ) + self.kv_a_layernorm = KimiRMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + ) + self.is_causal = True + assert self.use_nope + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.q_head_dim) + key_shape = ( + batch_size, + seq_length, + -1, + self.qk_nope_head_dim + self.v_head_dim, + ) + + q_states = self.q_proj(hidden_states) 
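+        # Note: mla_use_nope is asserted in __init__, so the rope-sized slices
+        # split off below are recombined without applying a rotary embedding.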
+        q_states = q_states.view(query_shape).transpose(1, 2)
+        q_pass, q_rot = torch.split(
+            q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        k_pass, k_rot = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+
+        k_pass = (
+            self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
+        )
+        k_pass, value_states = torch.split(
+            k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+
+        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
+        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
+
+        query_states = torch.cat((q_pass, q_rot), dim=-1)
+        key_states = torch.cat((k_pass, k_rot), dim=-1)
+
+        if past_key_values is not None:
+            key_states, value_states = past_key_values.update(
+                key_states, value_states, self.layer_idx
+            )
+
+        if (
+            self.config._attn_implementation == "flash_attention_2"
+            and self.q_head_dim != self.v_head_dim
+        ):
+            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[
+                self.config._attn_implementation
+            ]
+
+        attn_output, _ = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        if (
+            self.config._attn_implementation == "flash_attention_2"
+            and self.q_head_dim != self.v_head_dim
+        ):
+            attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output
+
+
+class KimiDeltaAttention(nn.Module):
+    def __init__(self, config: KimiLinearConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.mode = "chunk"
+
+        self.hidden_size = config.hidden_size
+        self.conv_size = config.linear_attn_config["short_conv_kernel_size"]
+        self.head_dim = config.linear_attn_config["head_dim"]
+        self.num_heads = config.linear_attn_config["num_heads"]
+        self.head_k_dim = self.head_dim
+        self.num_k_heads = self.num_heads
+
+        self.layer_idx = layer_idx
+
+        assert self.mode in ["chunk", "fused_recurrent"], (
+            f"Unsupported mode `{self.mode}`.
+ ) + + projection_k_size = self.head_k_dim * self.num_k_heads + projection_size = self.head_dim * self.num_heads + + self.q_proj = nn.Linear(self.hidden_size, projection_k_size, bias=False) + self.k_proj = nn.Linear(self.hidden_size, projection_k_size, bias=False) + self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False) + + self.q_conv1d = ShortConvolution( + hidden_size=projection_k_size, + kernel_size=self.conv_size, + activation="silu", + ) + self.k_conv1d = ShortConvolution( + hidden_size=projection_k_size, kernel_size=self.conv_size, activation="silu" + ) + self.v_conv1d = ShortConvolution( + hidden_size=projection_size, kernel_size=self.conv_size, activation="silu" + ) + + self.A_log = torch.nn.Parameter( + torch.log( + torch.empty(self.num_heads, dtype=torch.float32).uniform_(1, 16) + ).view(1, 1, -1, 1) + ) + + self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False) + self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False) + + self.dt_bias = nn.Parameter(torch.empty(projection_size, dtype=torch.float32)) + + self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False) + + self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False) + self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False) + + self.o_norm = FusedRMSNormGated( + self.head_dim, eps=config.rms_norm_eps, activation="sigmoid" + ) + self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + cache_params: Optional[KimiDynamicCache] = None, + **kwargs: Unpack[dict], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + if attention_mask.dim() != 2: + attention_mask = kwargs.get("padding_mask", None) + + if attention_mask is not None and attention_mask.dim() != 2: + raise ValueError( + "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] " + "(0 = padding). 3D masks are not supported here." + ) + use_cache = cache_params is not None + batch_size, q_len, _ = hidden_states.shape + mode = "fused_recurrent" if q_len <= 64 else self.mode + if self.training: + assert mode == "chunk", "Only chunk mode is supported in training." + + cu_seqlens = kwargs.get("cu_seqlens", None) + indices = None + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis( + rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices + ).unsqueeze(0) + + conv_state_q, conv_state_k, conv_state_v = None, None, None + recurrent_state = None + if cache_params is not None: + if cache_params.conv_states[self.layer_idx] is not None: + conv_state_q, conv_state_k, conv_state_v = cache_params.conv_states[ + self.layer_idx + ] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + g = self.f_b_proj(self.f_a_proj(hidden_states)) + g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias) + beta = self.b_proj(hidden_states).float().sigmoid() + + q, k = map( + lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_k_dim), (q, k) + ) + v = rearrange(v, "... (h d) -> ... h d", d=self.head_dim) + + if mode == "chunk": + o, recurrent_state = chunk_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + else: + o, recurrent_state = fused_recurrent_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = recurrent_state + cache_params.conv_states[self.layer_idx] = ( + conv_state_q, + conv_state_k, + conv_state_v, + ) + + g = self.g_b_proj(self.g_a_proj(hidden_states)) + g = rearrange(g, "... (h d) -> ... h d", d=self.head_dim) + o = self.o_norm(o, g) + + o = rearrange(o, "b t h d -> b t (h d)") + o = self.o_proj(o) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + + return o + + +class KimiMoEGate(nn.Module): + """ + MoE Gate that returns router logits. + Routing decisions are made in KimiSparseMoeBlock. 
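+    Returning the raw logits lets transformers' OutputRecorder capture them for
+    the load-balancing auxiliary loss.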
+    """
+
+    def __init__(self, config: KimiLinearConfig):
+        super().__init__()
+        self.config = config
+        self.num_experts = config.num_experts
+        self.gating_dim = config.hidden_size
+
+        self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
+        self.e_score_correction_bias = nn.Parameter(torch.zeros((self.num_experts,)))
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        import torch.nn.init as init
+
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            hidden_states: [batch_size, seq_len, hidden_dim]
+
+        Returns:
+            router_logits: [batch_size * seq_len, num_experts]
+        """
+        _, _, h = hidden_states.shape
+        hidden_states = hidden_states.view(-1, h)
+        router_logits = F.linear(
+            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
+        )
+        return router_logits
+
+
+class KimiSparseMoeBlock(nn.Module):
+    """
+    MoE block adapted from Deepseek-V3.
+    Returns only hidden_states - router_logits captured by OutputRecorder. 
+ """ + + def __init__(self, config: KimiLinearConfig): + super().__init__() + self.config = config + self.hidden_dim = config.hidden_size + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_token + self.moe_renormalize = config.moe_renormalize + self.routed_scaling_factor = config.routed_scaling_factor + self.num_expert_group = getattr(config, "num_expert_group", 1) + self.topk_group = getattr(config, "topk_group", 1) + + self.experts = nn.ModuleList( + [ + KimiBlockSparseMLP( + config, intermediate_size=config.moe_intermediate_size + ) + for _ in range(config.num_experts) + ] + ) + self.gate = KimiMoEGate(config) + + if config.num_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.num_shared_experts + self.shared_experts = KimiMLP( + config=config, intermediate_size=intermediate_size + ) + + def route_tokens_to_experts( + self, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute routing decisions from router logits. + + Args: + router_logits: [num_tokens, num_experts] + + Returns: + topk_idx: [num_tokens, top_k] + topk_weight: [num_tokens, top_k] + """ + num_tokens = router_logits.shape[0] + + if self.training: + # Training: use softmax for standard aux loss compatibility + scores = F.softmax(router_logits, dim=-1, dtype=torch.float32) + topk_weight, topk_idx = torch.topk(scores, self.top_k, dim=-1, sorted=False) + else: + # Inference: use original sigmoid + group selection + scores = router_logits.sigmoid() + scores_for_choice = scores + self.gate.e_score_correction_bias.unsqueeze(0) + + # Group-based selection + group_scores = ( + scores_for_choice.view(num_tokens, self.num_expert_group, -1) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + num_tokens, + self.num_expert_group, + self.num_experts // self.num_expert_group, + ) + .reshape(num_tokens, -1) + ) + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + topk_weight = scores.gather(1, topk_idx) + + # Normalize and scale + if self.top_k > 1 and self.moe_renormalize: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = topk_weight * self.routed_scaling_factor + + return topk_idx, topk_weight + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass returning only hidden_states. + Router logits are captured by OutputRecorder for aux loss. + """ + identity = hidden_states + batch_size, seq_len, hidden_dim = hidden_states.shape + num_tokens = batch_size * seq_len + + # Flatten for routing + hidden_states_flat = hidden_states.view(num_tokens, hidden_dim) + + # Get router logits - OutputRecorder captures this! 
+ router_logits = self.gate(hidden_states) + + # Get routing decisions + topk_idx, topk_weight = self.route_tokens_to_experts(router_logits) + + if self.training: + final_hidden_states = self._training_forward( + hidden_states_flat, topk_idx, topk_weight, num_tokens, hidden_dim + ) + else: + final_hidden_states = self._inference_forward( + hidden_states_flat, topk_idx, topk_weight + ) + + final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim) + + # Add shared experts if present + if self.config.num_shared_experts is not None: + final_hidden_states = final_hidden_states + self.shared_experts(identity) + + return final_hidden_states + + def _training_forward( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weight: torch.Tensor, + num_tokens: int, + hidden_dim: int, + ) -> torch.Tensor: + """ + Differentiable training forward using scatter-gather pattern. + """ + # Flatten expert indices: [num_tokens * top_k] + flat_topk_idx = topk_idx.view(-1) + + # Sort by expert index to group tokens going to same expert + sorted_indices = torch.argsort(flat_topk_idx) + inverse_permutation = torch.argsort(sorted_indices) + + # Each token appears top_k times (once per expert choice) + token_indices = torch.arange( + num_tokens, device=hidden_states.device + ).repeat_interleave(self.top_k) + + # Gather tokens and weights in sorted order + shuffled_tokens = hidden_states[token_indices[sorted_indices]] + shuffled_weights = topk_weight.view(-1)[sorted_indices].unsqueeze(-1) + + # Count tokens per expert + tokens_per_expert = F.one_hot(flat_topk_idx, num_classes=self.num_experts).sum( + dim=0 + ) + + # Process each expert's batch + expert_outputs = [] + current_pos = 0 + for i in range(self.num_experts): + num_tokens_for_expert = tokens_per_expert[i].item() + if num_tokens_for_expert == 0: + continue + + expert_input = shuffled_tokens[ + current_pos : current_pos + num_tokens_for_expert + ] + expert_output = self.experts[i](expert_input) + expert_outputs.append(expert_output) + current_pos += num_tokens_for_expert + + # Concatenate all outputs + if expert_outputs: + concatenated_outputs = torch.cat(expert_outputs, dim=0) + else: + concatenated_outputs = torch.zeros( + num_tokens * self.top_k, + hidden_dim, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + # Apply weights while still in sorted order + weighted_outputs = concatenated_outputs * shuffled_weights + + # Unsort back to original token order + unshuffled_outputs = weighted_outputs[inverse_permutation] + + # Sum contributions from all top_k experts for each token + final_hidden_states = unshuffled_outputs.view( + num_tokens, self.top_k, hidden_dim + ).sum(dim=1) + + return final_hidden_states + + @torch.no_grad() + def _inference_forward( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weight: torch.Tensor, + ) -> torch.Tensor: + """ + Optimized inference forward (original implementation). 
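+        Tokens are argsorted by expert id so each expert runs once over a
+        contiguous slice; outputs are scattered back to the original order.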
+ """ + cnts = topk_idx.new_zeros((topk_idx.shape[0], len(self.experts))) + cnts.scatter_(1, topk_idx, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_idx.view(-1).argsort() + sorted_tokens = hidden_states[idxs // topk_idx.shape[1]] + + tokens_per_expert_list = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert_list): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i] + tokens_for_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_idx.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +class KimiDecoderLayer(nn.Module): + def __init__(self, config: KimiLinearConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.config = config + if config.is_kda_layer(layer_idx): + self.is_linear_attn = True + self.self_attn = KimiDeltaAttention(config=config, layer_idx=layer_idx) + elif config.is_mla: + self.is_linear_attn = False + self.self_attn = KimiMLAAttention(config=config, layer_idx=layer_idx) + else: + raise NotImplementedError + if ( + config.num_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % getattr(config, "moe_layer_freq", 1) == 0 + ): + self.block_sparse_moe = KimiSparseMoeBlock(config) + else: + self.mlp = KimiMLP(config) + self.input_layernorm = KimiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = KimiRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + if self.is_linear_attn is False: + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + else: + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + cache_params=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if hasattr(self, "block_sparse_moe"): + hidden_states = self.block_sparse_moe(hidden_states) + else: + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class KimiPreTrainedModel(PreTrainedModel): + config_class = KimiLinearConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["KimiDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _can_record_outputs = { + "router_logits": OutputRecorder(KimiMoEGate, index=0), + "hidden_states": KimiDecoderLayer, + "attentions": KimiMLAAttention, + } + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class KimiLinearModel(KimiPreTrainedModel): + def __init__(self, config: KimiLinearConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + KimiDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = KimiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + if getattr(config, "_attn_implementation", None) is not None: + if config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Ignoring the provided attention implementation {config._attn_implementation}" + ) + logger.warning_once("Using flash_attention_2 backend instead.") + config._attn_implementation = "flash_attention_2" + else: + config._attn_implementation = "flash_attention_2" + + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def _update_linear_attn_mask(self, attention_mask, cache_position): + """ + NOTE: Left-padding is used for linear attention mask. + No need for zeroing states when + 1. Cached forward + 2. 
Attending to all inputs + """ + linear_attn_mask = attention_mask + if cache_position[0] > 0 or ( + attention_mask is not None and torch.all(attention_mask == 1) + ): + linear_attn_mask = None + return linear_attn_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) and (inputs_embeds is None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + # Get inputs_embeds + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = KimiDynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position: torch.Tensor = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) + + hidden_states = inputs_embeds + if past_key_values is not None: + assert isinstance(past_key_values, KimiDynamicCache) + + for decoder_layer in self.layers: + layer_mask = ( + linear_attn_mask if decoder_layer.is_linear_attn else causal_mask + ) + + hidden_states = decoder_layer( + hidden_states, + attention_mask=layer_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +class KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = KimiLinearModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + generation_mode: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). 
+
+
+class KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = KimiLinearModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        generation_mode: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, KimiLinearForCausalLM
+
+        >>> model = KimiLinearForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        if generation_mode:
+            # Only the last position is needed for next-token logits.
+            hidden_states = hidden_states[:, -1:]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+        aux_loss = None
+        if kwargs.get("output_router_logits", False):
+            aux_loss = load_balancing_loss_func(
+                outputs.router_logits,
+                num_experts=self.config.num_experts,
+                top_k=self.config.num_experts_per_token,
+                attention_mask=attention_mask,
+            )
+            if loss is not None:
+                loss = loss + self.config.router_aux_loss_coef * aux_loss
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
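Slicing the hidden states before the LM head is the point of `generation_mode`: it avoids materializing a full `(batch, seq, vocab)` logits tensor when only the last position matters. A minimal sketch with toy sizes (the real hidden and vocab dimensions are much larger):

```python
import torch
import torch.nn as nn

hidden = torch.randn(1, 2048, 64)            # (batch, seq, hidden)
lm_head = nn.Linear(64, 32_000, bias=False)

full = lm_head(hidden)                        # (1, 2048, 32000): ~262 MB in fp32
last = lm_head(hidden[:, -1:])                # (1, 1, 32000):    ~128 KB

# Projecting only the last position gives the same next-token logits.
assert torch.allclose(full[:, -1:], last)
```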
diff --git a/src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py b/src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py
new file mode 100644
index 000000000..f9d1546d6
--- /dev/null
+++ b/src/axolotl/monkeypatch/models/kimi_linear/patch_kimi_linear.py
@@ -0,0 +1,85 @@
+"""Patches for loading Axolotl's local Kimi Linear modules in place of the hub's remote code."""
+
+import importlib.resources
+import importlib.util
+import sys
+from pathlib import Path
+from typing import Optional
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
+
+
+def get_patch_file_path(package_dot_path: str, filename: str) -> Optional[Path]:
+    """
+    Gets the absolute path to a patch file using importlib.resources.files.
+    """
+    try:
+        return importlib.resources.files(package_dot_path) / filename
+    except ModuleNotFoundError:
+        return None
+
+
+def _load_local_module(module_name: str, filename: str):
+    """Helper to load a local module if not already loaded."""
+    if module_name in sys.modules:
+        return sys.modules[module_name]
+
+    patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename)
+    if patch_path and patch_path.exists():
+        spec = importlib.util.spec_from_file_location(module_name, patch_path)
+        module = importlib.util.module_from_spec(spec)
+        sys.modules[module_name] = module
+        spec.loader.exec_module(module)
+        return module
+    return None
+
+
+def _patch_get_class_in_module():
+    """
+    Core patch function that hijacks Transformers' dynamic module loading.
+    """
+    from transformers.dynamic_module_utils import get_class_in_module
+
+    if hasattr(get_class_in_module, "_axolotl_patched"):
+        return
+
+    original_get_class_in_module = get_class_in_module
+
+    # Mapping of module path patterns to (module_name, filename)
+    KIMI_MODULE_MAP = {
+        "configuration_kimi": ("configuration_kimi", "configuration_kimi.py"),
+        "modeling_kimi": ("modeling_kimi", "modeling_kimi.py"),
+        "tokenization_kimi": ("tokenization_kimi", "tokenization_kimi.py"),
+    }
+
+    def patched_get_class_in_module(class_name, module_path, **kwargs):
+        """Patched version that returns our local modules instead of remote ones."""
+        for pattern, (module_name, filename) in KIMI_MODULE_MAP.items():
+            if pattern in module_path:
+                module = _load_local_module(module_name, filename)
+                if module:
+                    return getattr(module, class_name)
+                break  # Pattern matched but file not found, fall through
+
+        return original_get_class_in_module(class_name, module_path, **kwargs)
+
+    import transformers.dynamic_module_utils
+
+    transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module
+    patched_get_class_in_module._axolotl_patched = True
+
+
+def patch_kimi():
+    """
+    Apply all Kimi patches.
+    Must be called BEFORE loading the config, tokenizer, or model.
+    """
+    _patch_get_class_in_module()
+    LOG.info("Kimi patches applied successfully!")
+
+
+# Keep these for backward compatibility if needed
+patch_kimi_config = patch_kimi
+patch_kimi_tokenizer = patch_kimi
+patch_kimi_model = patch_kimi
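Axolotl triggers this automatically via `PatchManager.apply_pre_config_load_patches` (see the `normalize_config` hunk at the end of this diff), but the ordering constraint is easiest to see in a manual sketch. Assuming the module path from this diff and the model id from the example config:

```python
# Minimal usage sketch: the patch must be applied before any from_pretrained
# call resolves trust_remote_code classes.
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import patch_kimi
from transformers import AutoConfig, AutoTokenizer

patch_kimi()

model_id = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# configuration_kimi / tokenization_kimi / modeling_kimi now resolve to
# Axolotl's local, training-ready copies instead of the hub versions.
```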
diff --git a/src/axolotl/monkeypatch/models/kimi_linear/tokenization_kimi.py b/src/axolotl/monkeypatch/models/kimi_linear/tokenization_kimi.py
new file mode 100644
index 000000000..83f7ab4ae
--- /dev/null
+++ b/src/axolotl/monkeypatch/models/kimi_linear/tokenization_kimi.py
@@ -0,0 +1,357 @@
+"""
+Adapted Kimi-Linear tokenizer to use proper template defaults and misc fixes.
+
+Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/tokenization_kimi.py
+Revision: 919416f
+"""
+
+import os
+from logging import getLogger
+from pathlib import Path
+from shutil import copyfile
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
+
+import tiktoken
+from tiktoken.load import load_tiktoken_bpe
+from tokenizers import AddedToken
+from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+logger = getLogger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
+
+
+class TikTokenTokenizer(PreTrainedTokenizer):
+    """
+    Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`], which contains most of the main methods. Users should refer
+    to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            The path to the Tiktoken model file.
+        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[BOS]"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[EOS]"`):
+            The end of sequence token.
+        unk_token (`str` or `tokenizers.AddedToken`, *optional*):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
+            The token used for padding, for example when batching sequences of different lengths.
+        additional_special_tokens (list of `str`, *optional*):
+            A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
+            skipped when decoding if `skip_special_tokens` is set to `True`.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+
+    model_input_names = ["input_ids", "attention_mask"]
+
+    special_tokens: Dict[str, int]
+
+    num_reserved_special_tokens = 256
+
+    pat_str = "|".join(
+        [
+            r"""[\p{Han}]+""",
+            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
+            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
+            r"""\p{N}{1,3}""",
+            r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
+            r"""\s*[\r\n]+""",
+            r"""\s+(?!\S)""",
+            r"""\s+""",
+        ]
+    )
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token: Union[str, AddedToken] = "[BOS]",  # nosec: B107
+        eos_token: Union[str, AddedToken] = "[EOS]",  # nosec: B107
+        unk_token: Union[str, AddedToken, None] = None,
+        pad_token: Union[str, AddedToken, None] = None,
+        additional_special_tokens: Optional[List[str]] = None,
+        added_tokens_decoder: Optional[dict] = None,
+        **kwargs,
+    ):
+        assert os.path.isfile(vocab_file), vocab_file
+
+        if additional_special_tokens is None:
+            additional_special_tokens = [
+                "<|im_end|>",
+                "<|im_user|>",
+                "<|im_assistant|>",
+                "<|start_header_id|>",
+                "<|end_header_id|>",
+                "[EOT]",
+                "<|im_system|>",
+                "<|im_middle|>",
+            ]
+
+        special_tokens_mapping = {
+            i: added_tokens_decoder[i].content for i in added_tokens_decoder
+        }
+
+        self.vocab_file = vocab_file
+        mergeable_ranks = load_tiktoken_bpe(vocab_file)
+        num_base_tokens = len(mergeable_ranks)
+        self.special_tokens = {
+            special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
+            for i in range(
+                num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2
+            )
+        }
+
+        self.model = tiktoken.Encoding(
+            name=Path(vocab_file).name,
+            pat_str=self.pat_str,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+        logger.info(f"Reloaded tiktoken model from {vocab_file}")
+
+        self.n_words: int = self.model.n_vocab
+        # BOS / EOS token IDs
+        self.bos_id: int = self.special_tokens[str(bos_token)]
+        self.eos_id: int = self.special_tokens[str(eos_token)]
+        logger.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+
+        self.pad_id: int = self.special_tokens[str(pad_token)]
+        self.unk_id: int = self.special_tokens[str(unk_token)]
+
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+
+        self.decoder = {}
+        for i in range(self.n_words):
+            # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
+            decoding = "".join(
+                [
+                    self.byte_encoder[ord(char)]
+                    for char in self.model.decode_single_token_bytes(i).decode(
+                        "latin-1"
+                    )
+                ]
+            )
+            self.decoder[i] = decoding
+
+        self.encoder = {}
+        for i in range(self.n_words):
+            if i in self.decoder:
+                self.encoder[self.decoder[i]] = i
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            additional_special_tokens=additional_special_tokens,
+            **kwargs,
+        )
+        self.all_special_ids_set = set(self.all_special_ids)
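The reserved-token construction in `__init__` is easiest to see with concrete numbers. A standalone sketch, with an invented base vocab size and `added_tokens_decoder` contents (the real values come from the model's tokenizer config):

```python
# Special-token IDs start right after the base BPE vocabulary and fill
# num_reserved_special_tokens + 2 slots; any slot without an entry in
# added_tokens_decoder becomes a <|reserved_token_N|> placeholder.
num_base_tokens = 100  # stand-in for len(mergeable_ranks)
special_tokens_mapping = {100: "[BOS]", 101: "[EOS]", 102: "<|im_end|>"}
special_tokens = {
    special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
    for i in range(num_base_tokens, num_base_tokens + 256 + 2)
}
assert special_tokens["[BOS]"] == 100
assert special_tokens["<|reserved_token_103|>"] == 103
assert len(special_tokens) == 258
```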
+    def encode(
+        self, text: str, allow_special_tokens: bool = True, **kwargs
+    ) -> List[int]:
+        """
+        Encodes a string into a list of token IDs.
+
+        Args:
+            text (str): The input string to be encoded.
+
+        Returns:
+            list[int]: A list of token IDs.
+        """
+        # If any other kwargs are passed, defer to super().encode, which has a
+        # lot of code to handle them; super().encode eventually calls
+        # _tokenize and _convert_token_to_id.
+        # NOTE: our encode method is not compatible with super().encode,
+        # e.g. split_special_tokens effectively defaults to True here.
+        if len(kwargs) > 0:
+            # logger.warning(f"Calling super().encode with {kwargs}")
+            return super().encode(text, **kwargs)
+
+        assert isinstance(text, str)
+
+        # The tiktoken tokenizer can handle <=400k chars without
+        # pyo3_runtime.PanicException.
+        TIKTOKEN_MAX_ENCODE_CHARS = 400_000
+
+        # https://github.com/openai/tiktoken/issues/195
+        # Here we iterate over subsequences and split if we exceed the limit
+        # of max consecutive non-whitespace or whitespace characters.
+        MAX_NO_WHITESPACES_CHARS = 25_000
+
+        texts = self.pre_tokenizer_process(text)
+
+        all_substrs = []
+        for text in texts:
+            substrs = (
+                substr
+                for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
+                for substr in self._split_whitespaces_or_nonwhitespaces(
+                    text[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
+                )
+            )
+            all_substrs.extend(substrs)
+
+        t: List[int] = []
+        for substr in all_substrs:
+            if allow_special_tokens:
+                t.extend(
+                    # we should consider special tokens as common tokens
+                    self.model.encode(
+                        substr,
+                        allowed_special="all",
+                    )
+                )
+            else:
+                t.extend(
+                    # we should consider special tokens as common tokens
+                    self.model.encode(
+                        substr,
+                        disallowed_special=(),
+                    )
+                )
+
+        return t
+
+    def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
+        """
+        Decodes a list of token IDs into a string.
+
+        Args:
+            token_ids (List[int]): The list of token IDs to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        # If any other kwargs are passed, defer to super().decode, which has a
+        # lot of code to handle them; super().decode eventually calls
+        # convert_tokens_to_string and _convert_id_to_token.
+        if len(kwargs) > 0:
+            return super().decode(token_ids, **kwargs)
+
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+
+        return self.model.decode(cast(List[int], token_ids))
+
+    @staticmethod
+    def _split_whitespaces_or_nonwhitespaces(
+        s: str, max_consecutive_slice_len: int
+    ) -> Iterator[str]:
+        """
+        Splits the string `s` so that each substring contains no more than
+        `max_consecutive_slice_len` consecutive whitespaces or consecutive
+        non-whitespaces.
+        """
+        current_slice_len = 0
+        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
+        slice_start = 0
+
+        for i in range(len(s)):
+            is_now_space = s[i].isspace()
+
+            if current_slice_is_space ^ is_now_space:
+                current_slice_len = 1
+                current_slice_is_space = is_now_space
+            else:
+                current_slice_len += 1
+                if current_slice_len > max_consecutive_slice_len:
+                    yield s[slice_start:i]
+                    slice_start = i
+                    current_slice_len = 1
+        yield s[slice_start:]
+
+    def pre_tokenizer_process(self, text: str) -> List[str]:
+        """
+        Pre-tokenizes the input text into a list of chunks for internal
+        processing. The default implementation returns the text unchanged as
+        a single chunk.
+        """
+        return [text]
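A quick check of the splitter's contract (toy limit of 3; `encode` uses 25,000), assuming the module path from this diff:

```python
from axolotl.monkeypatch.models.kimi_linear.tokenization_kimi import (
    TikTokenTokenizer,
)

# Note that a yielded slice may carry the tail of the previous run; the only
# guarantee is "no more than N consecutive same-class characters per slice",
# and the concatenation of all slices reproduces the input.
runs = list(
    TikTokenTokenizer._split_whitespaces_or_nonwhitespaces("aaaa    bbbb", 3)
)
assert runs == ["aaa", "a   ", " bbb", "b"]
assert "".join(runs) == "aaaa    bbbb"
```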
+
+    """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
+
+    @property
+    def vocab_size(self) -> int:
+        return self.n_words
+
+    def get_vocab(self) -> Dict[str, int]:
+        return self.encoder
+
+    def _tokenize(self, text: str, **kwargs) -> List[str]:
+        return [self.decoder[t] for t in self.encode(text)]
+
+    def _convert_token_to_id(self, token: str) -> int:
+        return self.encoder.get(token, self.unk_id)
+
+    def _convert_id_to_token(self, index: int) -> Optional[str]:
+        return self.decoder.get(index)
+
+    @staticmethod
+    def clean_up_tokenization(out_string: str) -> str:
+        return out_string
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        text = "".join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode(
+            "utf-8", "replace"
+        )
+        return text
+
+    def save_vocabulary(
+        self, save_directory: str, filename_prefix: Optional[str] = None
+    ) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            raise ValueError(
+                f"vocabulary path ({save_directory}) should be a directory"
+            )
+        out_vocab_file = os.path.join(
+            save_directory,
+            (filename_prefix + "-" if filename_prefix else "")
+            + VOCAB_FILES_NAMES["vocab_file"],
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(
+            out_vocab_file
+        ) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
+
+    def apply_chat_template(
+        self,
+        conversation,
+        tools: Optional[list[dict]] = None,
+        tokenize: bool = True,
+        add_generation_prompt: bool = False,
+        **kwargs,
+    ):
+        # Sort tool schemas recursively so the rendered prompt is deterministic.
+        tools = deep_sort_dict(tools)
+        return super().apply_chat_template(
+            conversation,
+            tools=tools,
+            tokenize=tokenize,
+            add_generation_prompt=add_generation_prompt,
+            **kwargs,
+        )
+
+
+def deep_sort_dict(obj: Any) -> Any:
+    if isinstance(obj, dict):
+        return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
+    if isinstance(obj, list):
+        return [deep_sort_dict(item) for item in obj]
+    return obj
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 7a2bbd6f9..8b35ed406 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -151,6 +151,11 @@ def normalize_config(cfg):
     if not cfg.base_model_config:
         cfg.base_model_config = cfg.base_model
 
+    # Apply pre-config load patches (e.g., for Kimi Linear remote code patching)
+    from axolotl.loaders.patch_manager import PatchManager
+
+    PatchManager.apply_pre_config_load_patches(cfg)
+
     model_config = load_model_config(cfg)
 
     cfg.tokenizer_config = (