diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 52816c1c8..e3bcec5e5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -87,7 +87,7 @@ jobs: uv pip install --system --no-build-isolation -e ".[dev]" --constraints torch-constraints.txt set -o pipefail python scripts/unsloth_install.py | bash - python scripts/cutcrossentropy_install.py | bash + python scripts/cutcrossentropy_install.py --uv | bash - name: Make sure PyTorch version wasn't clobbered run: | @@ -160,7 +160,7 @@ jobs: python -m build --sdist uv pip install --system dist/*.tar.gz python scripts/unsloth_install.py | sh - python scripts/cutcrossentropy_install.py | sh + python scripts/cutcrossentropy_install.py --uv | sh printf "torch==${{ matrix.pytorch_version }}\n" > torch-constraints.txt uv pip install --system ".[dev]" --constraints torch-constraints.txt diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py old mode 100644 new mode 100755 index fcfe152fa..31f3f223a --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -1,15 +1,19 @@ """Script to output the correct installation command for cut-cross-entropy.""" +from __future__ import annotations + import importlib.util import sys -from shlex import quote try: import torch -except ImportError as exc: +except ImportError as exc: # pragma: no cover - defensive raise ImportError("Install torch via `pip install torch`") from exc + from packaging.version import Version as V +USE_UV = "--uv" in sys.argv[1:] + v = V(torch.__version__) # no cut-cross-entropy support for torch < 2.4.0 @@ -19,14 +23,19 @@ if v < V("2.4.0"): cce_spec = importlib.util.find_spec("cut_cross_entropy") -python_path = quote(sys.executable) +UNINSTALL_PREFIX = "" +if cce_spec: + if not importlib.util.find_spec("cut_cross_entropy.transformers"): + if USE_UV: + UNINSTALL_PREFIX = "uv pip uninstall --system --yes cut-cross-entropy && " + else: + UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && " -commands = [] -if cce_spec and not importlib.util.find_spec("cut_cross_entropy.transformers"): - commands.append(f"uv pip uninstall --python {python_path} cut-cross-entropy") - -commands.append( - f'uv pip install --python {python_path} "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"' +installer = "uv pip --system" if USE_UV else "pip" +command = ( + f"{installer} install " + '"cut-cross-entropy[transformers] ' + '@ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"' ) -print(" && ".join(commands)) +print(f"{UNINSTALL_PREFIX}{command}")