fix?
This commit is contained in:
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -87,7 +87,7 @@ jobs:
|
|||||||
uv pip install --system --no-build-isolation -e ".[dev]" --constraints torch-constraints.txt
|
uv pip install --system --no-build-isolation -e ".[dev]" --constraints torch-constraints.txt
|
||||||
set -o pipefail
|
set -o pipefail
|
||||||
python scripts/unsloth_install.py | bash
|
python scripts/unsloth_install.py | bash
|
||||||
python scripts/cutcrossentropy_install.py | bash
|
python scripts/cutcrossentropy_install.py --uv | bash
|
||||||
|
|
||||||
- name: Make sure PyTorch version wasn't clobbered
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
run: |
|
run: |
|
||||||
@@ -160,7 +160,7 @@ jobs:
|
|||||||
python -m build --sdist
|
python -m build --sdist
|
||||||
uv pip install --system dist/*.tar.gz
|
uv pip install --system dist/*.tar.gz
|
||||||
python scripts/unsloth_install.py | sh
|
python scripts/unsloth_install.py | sh
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py --uv | sh
|
||||||
printf "torch==${{ matrix.pytorch_version }}\n" > torch-constraints.txt
|
printf "torch==${{ matrix.pytorch_version }}\n" > torch-constraints.txt
|
||||||
uv pip install --system ".[dev]" --constraints torch-constraints.txt
|
uv pip install --system ".[dev]" --constraints torch-constraints.txt
|
||||||
|
|
||||||
|
|||||||
29
scripts/cutcrossentropy_install.py
Normal file → Executable file
29
scripts/cutcrossentropy_install.py
Normal file → Executable file
@@ -1,15 +1,19 @@
|
|||||||
"""Script to output the correct installation command for cut-cross-entropy."""
|
"""Script to output the correct installation command for cut-cross-entropy."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import sys
|
import sys
|
||||||
from shlex import quote
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
except ImportError as exc:
|
except ImportError as exc: # pragma: no cover - defensive
|
||||||
raise ImportError("Install torch via `pip install torch`") from exc
|
raise ImportError("Install torch via `pip install torch`") from exc
|
||||||
|
|
||||||
from packaging.version import Version as V
|
from packaging.version import Version as V
|
||||||
|
|
||||||
|
USE_UV = "--uv" in sys.argv[1:]
|
||||||
|
|
||||||
v = V(torch.__version__)
|
v = V(torch.__version__)
|
||||||
|
|
||||||
# no cut-cross-entropy support for torch < 2.4.0
|
# no cut-cross-entropy support for torch < 2.4.0
|
||||||
@@ -19,14 +23,19 @@ if v < V("2.4.0"):
|
|||||||
|
|
||||||
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
||||||
|
|
||||||
python_path = quote(sys.executable)
|
UNINSTALL_PREFIX = ""
|
||||||
|
if cce_spec:
|
||||||
|
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
|
||||||
|
if USE_UV:
|
||||||
|
UNINSTALL_PREFIX = "uv pip uninstall --system --yes cut-cross-entropy && "
|
||||||
|
else:
|
||||||
|
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
|
||||||
|
|
||||||
commands = []
|
installer = "uv pip --system" if USE_UV else "pip"
|
||||||
if cce_spec and not importlib.util.find_spec("cut_cross_entropy.transformers"):
|
command = (
|
||||||
commands.append(f"uv pip uninstall --python {python_path} cut-cross-entropy")
|
f"{installer} install "
|
||||||
|
'"cut-cross-entropy[transformers] '
|
||||||
commands.append(
|
'@ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"'
|
||||||
f'uv pip install --python {python_path} "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(" && ".join(commands))
|
print(f"{UNINSTALL_PREFIX}{command}")
|
||||||
|
|||||||
Reference in New Issue
Block a user