* bump hf deps * upgrade liger-kernel too * install cce from fork for transformers fix * fix reference to vocab size in gemma3 patch * use padding_idx instead of pad_token_id * remove fixed gemma3 patch * use updated cce fork * fix local mllama cce patches with docstring * add test for multipack with trainer setup and fix trainer for trainer refactor upstream * bump modal version * guard for iterable datasets * mllama model arch layout changed in latest transformers * fix batch sampler with drop_last * fix: address upstream vlm changes for lora * fix: update references to old lora target path * fix: remove mllama fa2 patch due to upstream fix * fix: lora kernel patch path for multimodal models * fix: removed mllama from quarto * run test for came optim on 2.6.0+ * fix fsdp2 patch and remove deprecated patch * make sure to set sequence_parallel_degree for grpo * Add SP test for GRPO * add sp to grpo config for trainer * use reward_funcs as kwarg to grpo trainer * fix the comprehension for reward funcs * reward funcs already passed in as args * init sp_group right before training * fix check for adding models to SP context * make sure to pass args to super * upgrade deepspeed * use updated trl and add reasoning flags for vllm * patch the worker --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
30 lines
792 B
Python
30 lines
792 B
Python
"""Script to output the correct installation command for cut-cross-entropy."""
|
|
|
|
import importlib.util
|
|
import sys
|
|
|
|
try:
|
|
import torch
|
|
except ImportError as exc:
|
|
raise ImportError("Install torch via `pip install torch`") from exc
|
|
from packaging.version import Version as V
|
|
|
|
v = V(torch.__version__)
|
|
|
|
# no cut-cross-entropy support for torch < 2.4.0
|
|
if v < V("2.4.0"):
|
|
print("")
|
|
sys.exit(0)
|
|
|
|
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
|
|
|
UNINSTALL_PREFIX = ""
|
|
if cce_spec:
|
|
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
|
|
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
|
|
|
|
print(
|
|
UNINSTALL_PREFIX
|
|
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
|
|
)
|