This commit is contained in:
Dan Saunders
2025-10-01 16:15:06 -04:00
parent f782957002
commit 00e0238501
2 changed files with 21 additions and 12 deletions

29
scripts/cutcrossentropy_install.py Normal file → Executable file
View File

@@ -1,15 +1,19 @@
"""Script to output the correct installation command for cut-cross-entropy."""
from __future__ import annotations
import importlib.util
import sys
from shlex import quote
try:
import torch
except ImportError as exc:
except ImportError as exc: # pragma: no cover - defensive
raise ImportError("Install torch via `pip install torch`") from exc
from packaging.version import Version as V
USE_UV = "--uv" in sys.argv[1:]
v = V(torch.__version__)
# no cut-cross-entropy support for torch < 2.4.0
@@ -19,14 +23,19 @@ if v < V("2.4.0"):
cce_spec = importlib.util.find_spec("cut_cross_entropy")
python_path = quote(sys.executable)
UNINSTALL_PREFIX = ""
if cce_spec:
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
if USE_UV:
UNINSTALL_PREFIX = "uv pip uninstall --system --yes cut-cross-entropy && "
else:
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
commands = []
if cce_spec and not importlib.util.find_spec("cut_cross_entropy.transformers"):
commands.append(f"uv pip uninstall --python {python_path} cut-cross-entropy")
commands.append(
f'uv pip install --python {python_path} "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"'
installer = "uv pip --system" if USE_UV else "pip"
command = (
f"{installer} install "
'"cut-cross-entropy[transformers] '
'@ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"'
)
print(" && ".join(commands))
print(f"{UNINSTALL_PREFIX}{command}")