Files
axolotl/scripts/unsloth_install.py
Dan Saunders 79ddaebe9a Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
2025-08-23 23:37:33 -04:00

41 lines
1.1 KiB
Python

# noqa
import sys
try:
import torch
except ImportError as error:
raise ImportError("Install torch via `pip install torch`") from error
from packaging.version import Version as V
use_uv = "--uv" in sys.argv[1:]
v = V(torch.__version__)
cuda = str(torch.version.cuda)
try:
is_ampere = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
is_ampere = False
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V("2.1.0"):
raise RuntimeError(f"Torch = {v} too old!")
elif v <= V("2.1.1"):
x = "cu{}{}-torch211"
elif v <= V("2.1.2"):
x = "cu{}{}-torch212"
elif v < V("2.3.0"):
x = "cu{}{}-torch220"
elif v < V("2.4.0"):
x = "cu{}{}-torch230"
elif v < V("2.5.0"):
x = "cu{}{}-torch240"
elif v < V("2.6.0"):
x = "cu{}{}-torch250"
else:
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
uv_prefix = "uv " if use_uv else ""
print(
f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
)