feat: add cut_cross_entropy (#2091)
* feat: add cut_cross_entropy * fix: add to input * fix: remove from setup.py * feat: refactor into an integration * chore: ignore lint * feat: add test for cce * fix: set max_steps for liger test * chore: Update base model following suggestion Co-authored-by: Wing Lian <wing.lian@gmail.com> * chore: update special_tokens following suggestion Co-authored-by: Wing Lian <wing.lian@gmail.com> * chore: remove with_temp_dir following comments * fix: plugins aren't loaded * chore: update quotes in error message * chore: lint * chore: lint * feat: enable FA on test * chore: refactor get_pytorch_version * fix: lock cce commit version * fix: remove subclassing UT * fix: downcast even if not using FA and config check * feat: add test to check different attentions * feat: add install to CI * chore: refactor to use parametrize for attention * fix: pytest not detecting test * feat: handle torch lower than 2.4 * fix args/kwargs to match docs * use release version cut-cross-entropy==24.11.4 * fix quotes * fix: use named params for clarity for modal builder * fix: handle install from pip * fix: test check only top level module install * fix: re-add import check * uninstall existing version if no transformers submodule in cce * more dataset fixtures into the cache --------- Co-authored-by: Wing Lian <wing.lian@gmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
28
scripts/cutcrossentropy_install.py
Normal file
28
scripts/cutcrossentropy_install.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Script to output the correct installation command for cut-cross-entropy."""
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError as exc:
|
||||
raise ImportError("Install torch via `pip install torch`") from exc
|
||||
from packaging.version import Version as V
|
||||
|
||||
v = V(torch.__version__)
|
||||
|
||||
# no cut-cross-entropy support for torch < 2.4.0
|
||||
if v < V("2.4.0"):
|
||||
print("")
|
||||
sys.exit(0)
|
||||
|
||||
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
||||
cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")
|
||||
|
||||
UNINSTALL_PREFIX = ""
|
||||
if cce_spec and not cce_spec_transformers:
|
||||
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
||||
)
|
||||
Reference in New Issue
Block a user