* feat: add custom kimi linear patch [skip ci] * feat: add configuration file and fix import [skip ci] * fix: hijack tokenizer temporarily [skip ci] * chore: remove accidental commit * fix: attempt patch kimi remote * fix: kwargs passed * fix: device for tensor * fix: aux loss calculation * feat: cleaned up patches order * fix: remove duplicate tokenizer patch * chore: add debug logs * chore: add debug logs * chore: debug * Revert "chore: add debug logs" This reverts commit da372a5f67. * Revert "chore: add debug logs" This reverts commit 97d1de1d7c. * fix: KeyError: 'tokenization_kimi' * fix: support remote_model_id in cce patch * feat: add config preload patch * fix: use standard aux loss calc and updated modeling * fix: import * feat: add kimi-linear docs and example * chore: add note about moe kernels * feat: update cce to include kimi-linear * chore: lint * chore: update main readme * fix: patch mechanism to address comments * chore: lint * fix: tests * chore: cleanup comment
34 lines
874 B
Python
34 lines
874 B
Python
"""Script to output the correct installation command for cut-cross-entropy.

Prints a single shell command on stdout, or an empty line when the installed
torch is older than the minimum supported version. Pass ``--uv`` to emit
``uv pip install`` instead of ``pip install``.
"""

import importlib.util
import sys

# Repository and pinned commit of the cut-cross-entropy fork to install.
CCE_REPO_URL = "https://github.com/axolotl-ai-cloud/ml-cross-entropy.git"
CCE_GIT_REF = "242b245"

# no cut-cross-entropy support for torch < 2.4.0
MIN_TORCH_VERSION = "2.4.0"


def needs_reinstall(package: str = "cut_cross_entropy") -> bool:
    """Return True when an existing install of *package* must be removed first.

    An installed *package* without a ``transformers`` submodule is uninstalled
    before the pinned version is installed. Returns False when *package* is
    not installed at all.

    Args:
        package: Top-level import name of the package to inspect.
    """
    if importlib.util.find_spec(package) is None:
        return False
    try:
        # Looking up a dotted name imports the parent package. On a broken
        # install that import can raise arbitrary errors, and it raises
        # ModuleNotFoundError when the parent is not a package. Treat any such
        # failure as "reinstall" instead of crashing the installer script.
        return importlib.util.find_spec(f"{package}.transformers") is None
    except Exception:  # pylint: disable=broad-except
        return True


def build_install_command(use_uv: bool = False, uninstall_first: bool = False) -> str:
    """Return the shell command that installs the pinned cut-cross-entropy.

    Args:
        use_uv: Emit ``uv pip install`` instead of ``pip install``.
        uninstall_first: Prepend ``pip uninstall -y cut-cross-entropy && ``.

    Returns:
        The command string, suitable for execution by a shell.
    """
    # NOTE(review): the uninstall step always uses plain `pip`, even with
    # --uv; confirm pip is available in uv-managed environments.
    uninstall_prefix = "pip uninstall -y cut-cross-entropy && " if uninstall_first else ""
    uv_prefix = "uv " if use_uv else ""
    return (
        uninstall_prefix
        + f'{uv_prefix}pip install "cut-cross-entropy[transformers] @ git+{CCE_REPO_URL}@{CCE_GIT_REF}"'
    )


def main(argv=None) -> None:
    """Print the install command, or an empty line if torch is too old.

    Args:
        argv: Command-line arguments excluding the program name; defaults to
            ``sys.argv[1:]``. Only ``--uv`` is recognised.

    Raises:
        ImportError: If torch is not installed.
    """
    args = sys.argv[1:] if argv is None else argv

    # Imported lazily so the helpers above stay importable without torch.
    try:
        import torch
    except ImportError as exc:
        raise ImportError("Install torch via `pip install torch`") from exc
    from packaging.version import Version as V

    if V(torch.__version__) < V(MIN_TORCH_VERSION):
        print("")
        return

    print(
        build_install_command(
            use_uv="--uv" in args,
            uninstall_first=needs_reinstall(),
        )
    )


if __name__ == "__main__":
    main()