From 9344fa5e8c67a8c9fdbd6304a36d906a0bf8d5d6 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 26 Sep 2025 20:35:08 -0400 Subject: [PATCH] fix install scripts (?) --- docker/Dockerfile | 4 ++-- scripts/cutcrossentropy_install.py | 17 ++++++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 0680c0420..3e2403cee 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -35,8 +35,8 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ else \ uv pip install --python "$VENV_PYTHON" --no-build-isolation -e .[ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \ fi && \ - python scripts/unsloth_install.py | sh && \ - python scripts/cutcrossentropy_install.py | sh && \ + "$VENV_PYTHON" scripts/unsloth_install.py | sh && \ + "$VENV_PYTHON" scripts/cutcrossentropy_install.py | sh && \ uv pip install --python "$VENV_PYTHON" pytest # fix so that git fetch/pull from remote works with shallow clone diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index ca4351996..0364c03f6 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -2,6 +2,7 @@ import importlib.util import sys +from shlex import quote try: import torch @@ -18,12 +19,14 @@ if v < V("2.4.0"): cce_spec = importlib.util.find_spec("cut_cross_entropy") -UNINSTALL_PREFIX = "" -if cce_spec: - if not importlib.util.find_spec("cut_cross_entropy.transformers"): - UNINSTALL_PREFIX = "uv pip uninstall --system cut-cross-entropy && " +python_path = quote(sys.executable) -print( - UNINSTALL_PREFIX - + 'uv pip install --system "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"' +commands = [] +if cce_spec and not importlib.util.find_spec("cut_cross_entropy.transformers"): + commands.append(f"uv pip uninstall --python {python_path} cut-cross-entropy") + +commands.append( + f'uv pip install --python {python_path} "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"' ) + +print(" && ".join(commands))