diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d8d6f73ea..1659ff1cf 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -83,7 +83,8 @@ jobs: run: | uv pip show --system torch uv pip install --system wheel - uv pip install --system --no-build-isolation -e ".[dev]" + printf "torch==${{ matrix.pytorch_version }}\n" > torch-constraints.txt + uv pip install --system --no-build-isolation -e ".[dev]" --constraints torch-constraints.txt set -o pipefail python scripts/unsloth_install.py | bash python scripts/cutcrossentropy_install.py | bash @@ -160,7 +161,8 @@ jobs: uv pip install --system dist/*.tar.gz python scripts/unsloth_install.py | sh python scripts/cutcrossentropy_install.py | sh - uv pip install --system ".[dev]" + printf "torch==${{ matrix.pytorch_version }}\n" > torch-constraints.txt + uv pip install --system ".[dev]" --constraints torch-constraints.txt - name: Make sure PyTorch version wasn't clobbered run: | diff --git a/scripts/unsloth_install.py b/scripts/unsloth_install.py index bbb360797..5bc3e0fd3 100644 --- a/scripts/unsloth_install.py +++ b/scripts/unsloth_install.py @@ -1,6 +1,23 @@ -"""Print the uv commands required to install Unsloth.""" +"""Print the uv commands required to install Unsloth without altering Torch.""" -UNSLOTH_BASE = "uv pip install --system unsloth-zoo==2025.9.12" -UNSLOTH_HF = 'uv pip install --system --no-deps "unsloth[huggingface]==2025.9.9"' +try: + import torch +except ImportError as error: + raise ImportError("Install torch via `pip install torch`") from error -print(f"{UNSLOTH_BASE} && {UNSLOTH_HF}") +from packaging.version import Version as V + +TORCH_MIN = V("2.6.0") +UNSLOTH_BASE = ( + "uv pip install --system --no-deps unsloth-zoo==2025.9.12" + ' && uv pip install --system --no-deps "unsloth[huggingface]==2025.9.9"' +) + +version = V(torch.__version__) +if version < TORCH_MIN: + raise RuntimeError( + f"Torch {version} detected, but Unsloth requires >= {TORCH_MIN}. " + "Upgrade your torch install and re-run this helper." + ) + +print(UNSLOTH_BASE)