Files
axolotl/scripts/cutcrossentropy_install.py
2025-09-26 20:35:08 -04:00

33 lines
899 B
Python

"""Script to output the correct installation command for cut-cross-entropy."""
import importlib.util
import sys
from shlex import quote

# Earliest torch release cut-cross-entropy supports; older torch gets no command.
MIN_TORCH_VERSION = "2.4.0"

# Pinned git commit of axolotl-ai-cloud/ml-cross-entropy, with the
# `transformers` extra. Embedded verbatim (in double quotes) in the command.
CCE_REQUIREMENT = (
    "cut-cross-entropy[transformers] @ "
    "git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
)


def module_available(name):
    """Return True if the import system can locate module *name*.

    For a dotted name, ``importlib.util.find_spec`` imports the parent
    package first. That runs the parent's ``__init__``, which can raise
    arbitrary errors. It also raises ModuleNotFoundError when the parent
    is missing or is not a package. Any such failure is reported as
    "not available" rather than propagated, so a broken install leads to
    a reinstall command instead of a crash.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except Exception:  # probing may execute third-party import-time code
        return False


def build_install_command(python_executable):
    """Return the ``&&``-joined ``uv pip`` command line for an interpreter.

    Args:
        python_executable: Path of the target interpreter. It is
            shell-quoted before being embedded in the commands.

    Returns:
        A command string containing, in order:

        * ``uv pip uninstall ... cut-cross-entropy``, only when
          ``cut_cross_entropy`` is locatable but its ``transformers``
          submodule is not;
        * ``uv pip install ... "<CCE_REQUIREMENT>"``.
    """
    python_path = quote(python_executable)
    commands = []
    # A copy lacking the `transformers` submodule is removed first.
    # Presumably uv would otherwise keep the already-installed distribution;
    # confirm if this logic changes.
    if module_available("cut_cross_entropy") and not module_available(
        "cut_cross_entropy.transformers"
    ):
        commands.append(f"uv pip uninstall --python {python_path} cut-cross-entropy")
    commands.append(f'uv pip install --python {python_path} "{CCE_REQUIREMENT}"')
    return " && ".join(commands)


def main():
    """Print the install command for ``sys.executable``.

    If torch is older than MIN_TORCH_VERSION, print an empty line instead.

    Raises:
        ImportError: if torch is not installed.
    """
    try:
        import torch
    except ImportError as exc:
        raise ImportError("Install torch via `pip install torch`") from exc
    from packaging.version import Version

    # NOTE(review): pre-release builds such as "2.4.0a0+<hash>" compare below
    # 2.4.0 and are skipped here; confirm that is intended.
    if Version(torch.__version__) < Version(MIN_TORCH_VERSION):
        print("")
        return
    print(build_install_command(sys.executable))


if __name__ == "__main__":
    main()