diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index a47a33aae..e67c3b182 100755 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -1,8 +1,8 @@ -"""Script to output the correct installation command for cut-cross-entropy.""" +"""Emit the install command for the Axolotl cut-cross-entropy fork.""" from __future__ import annotations -import importlib.util +import shutil import sys try: @@ -12,30 +12,29 @@ except ImportError as exc: # pragma: no cover - defensive from packaging.version import Version as V -USE_UV = "--uv" in sys.argv[1:] +USE_UV_FLAG = "--uv" in sys.argv[1:] +USE_PIP_FLAG = "--pip" in sys.argv[1:] -v = V(torch.__version__) +if USE_UV_FLAG and USE_PIP_FLAG: + raise SystemExit("Specify only one of --uv or --pip") -# no cut-cross-entropy support for torch < 2.4.0 -if v < V("2.4.0"): +if USE_PIP_FLAG: + use_uv = False +elif USE_UV_FLAG: + use_uv = True +else: + use_uv = shutil.which("uv") is not None + +if V(torch.__version__) < V("2.4.0"): print("") sys.exit(0) -cce_spec = importlib.util.find_spec("cut_cross_entropy") - -UNINSTALL_PREFIX = "" -if cce_spec: - if not importlib.util.find_spec("cut_cross_entropy.transformers"): - if USE_UV: - UNINSTALL_PREFIX = "uv pip uninstall --yes cut-cross-entropy && " - else: - UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && " - -installer = "uv pip install --system" if USE_UV else "pip install" +# No need to uninstall in CI runs; the environment is fresh. Just emit the install command. +installer = "uv pip install --system" if use_uv else "pip install" command = ( f"{installer} " '"cut-cross-entropy[transformers] ' '@ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"' ) -print(f"{UNINSTALL_PREFIX}{command}") +print(command)