chore: minor optim changes (add apollo, improve docs, remove lion-pytorch) (#2444)
* feat: add apollo-torch
* chore: update optimizer list
* fix: deleted accidental requirements file
* fix: remove mention of deprecated lion_pytorch
This commit is contained in:
@@ -506,7 +506,7 @@ lr_div_factor: # Learning rate div factor
|
|||||||
|
|
||||||
# Specify optimizer
|
# Specify optimizer
|
||||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||||
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
|
||||||
#
|
#
|
||||||
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
||||||
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
||||||
@@ -516,25 +516,48 @@ lr_div_factor: # Learning rate div factor
|
|||||||
# - adamw_torch
|
# - adamw_torch
|
||||||
# - adamw_torch_fused
|
# - adamw_torch_fused
|
||||||
# - adamw_torch_xla
|
# - adamw_torch_xla
|
||||||
|
# - adamw_torch_npu_fused
|
||||||
# - adamw_apex_fused
|
# - adamw_apex_fused
|
||||||
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||||
# - adafactor
|
# - adafactor
|
||||||
# - adamw_anyprecision
|
# - adamw_anyprecision
|
||||||
|
# - adamw_torch_4bit
|
||||||
|
# - ademamix
|
||||||
# - sgd
|
# - sgd
|
||||||
# - adagrad
|
# - adagrad
|
||||||
# - adamw_bnb_8bit
|
# - adamw_bnb_8bit
|
||||||
|
# - adamw_8bit # alias for adamw_bnb_8bit
|
||||||
|
# - ademamix_8bit
|
||||||
# - lion_8bit
|
# - lion_8bit
|
||||||
# - lion_32bit
|
# - lion_32bit
|
||||||
# - paged_adamw_32bit
|
# - paged_adamw_32bit
|
||||||
# - paged_adamw_8bit
|
# - paged_adamw_8bit
|
||||||
|
# - paged_ademamix_32bit
|
||||||
|
# - paged_ademamix_8bit
|
||||||
# - paged_lion_32bit
|
# - paged_lion_32bit
|
||||||
# - paged_lion_8bit
|
# - paged_lion_8bit
|
||||||
|
# - rmsprop
|
||||||
|
# - rmsprop_bnb
|
||||||
|
# - rmsprop_bnb_8bit
|
||||||
|
# - rmsprop_bnb_32bit
|
||||||
# - galore_adamw
|
# - galore_adamw
|
||||||
# - galore_adamw_8bit
|
# - galore_adamw_8bit
|
||||||
# - galore_adafactor
|
# - galore_adafactor
|
||||||
# - galore_adamw_layerwise
|
# - galore_adamw_layerwise
|
||||||
# - galore_adamw_8bit_layerwise
|
# - galore_adamw_8bit_layerwise
|
||||||
# - galore_adafactor_layerwise
|
# - galore_adafactor_layerwise
|
||||||
|
# - lomo
|
||||||
|
# - adalomo
|
||||||
|
# - grokadamw
|
||||||
|
# - schedule_free_adamw
|
||||||
|
# - schedule_free_sgd
|
||||||
|
# - apollo_adamw
|
||||||
|
# - apollo_adamw_layerwise
|
||||||
|
#
|
||||||
|
# Additional custom optimizers include:
|
||||||
|
# - optimi_adamw
|
||||||
|
# - ao_adamw_8bit
|
||||||
|
# - ao_adamw_fp8
|
||||||
optimizer:
|
optimizer:
|
||||||
# Dictionary of arguments to pass to the optimizer
|
# Dictionary of arguments to pass to the optimizer
|
||||||
optim_args:
|
optim_args:
|
||||||
|
|||||||
@@ -1,315 +0,0 @@
|
|||||||
accelerate==0.34.1
|
|
||||||
addict==2.4.0
|
|
||||||
aiofiles==23.2.1
|
|
||||||
aiohttp==3.9.0
|
|
||||||
aiosignal==1.3.1
|
|
||||||
aiostream==0.5.2
|
|
||||||
alembic==1.13.1
|
|
||||||
annotated-types==0.6.0
|
|
||||||
annoy==1.17.3
|
|
||||||
ansible==6.7.0
|
|
||||||
ansible-core==2.13.13
|
|
||||||
ansible-vault==2.1.0
|
|
||||||
anyio==3.7.1
|
|
||||||
appdirs==1.4.4
|
|
||||||
art==6.0
|
|
||||||
asgiref==3.7.2
|
|
||||||
async-timeout==4.0.2
|
|
||||||
attrdict==2.0.1
|
|
||||||
attrs==22.2.0
|
|
||||||
awscli==1.32.75
|
|
||||||
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
|
|
||||||
backoff==2.2.1
|
|
||||||
base58==2.1.1
|
|
||||||
beartype==0.17.2
|
|
||||||
bitnet==0.2.1
|
|
||||||
bitsandbytes==0.42.0
|
|
||||||
bittensor==6.7.0
|
|
||||||
black==23.7.0
|
|
||||||
blinker==1.7.0
|
|
||||||
boto3==1.34.75
|
|
||||||
botocore==1.34.75
|
|
||||||
cachetools==5.3.3
|
|
||||||
cachy==0.1.1
|
|
||||||
certifi==2023.7.22
|
|
||||||
cffi==1.16.0
|
|
||||||
cfgv==3.3.1
|
|
||||||
chai-guanaco==1.2.4
|
|
||||||
charset-normalizer==3.2.0
|
|
||||||
cleo==0.6.8
|
|
||||||
click==8.1.7
|
|
||||||
cloudpickle==2.0.0
|
|
||||||
cohere==4.11.2
|
|
||||||
colorama==0.4.4
|
|
||||||
coloredlogs==15.0.1
|
|
||||||
CoLT5-attention==0.10.20
|
|
||||||
contextlib2==21.6.0
|
|
||||||
contourpy==1.2.0
|
|
||||||
cryptography==41.0.3
|
|
||||||
cycler==0.12.1
|
|
||||||
cytoolz==0.12.3
|
|
||||||
databricks-cli==0.18.0
|
|
||||||
dataclasses-json==0.5.7
|
|
||||||
datasets==2.11.0
|
|
||||||
ddt==1.6.0
|
|
||||||
decorator==5.1.1
|
|
||||||
deepspeed==0.15.0
|
|
||||||
# Editable Git install with no remote (dialogpt==0.1)
|
|
||||||
-e /Users/wing/Projects/ml/dialogpt/src
|
|
||||||
dill==0.3.6
|
|
||||||
distlib==0.3.6
|
|
||||||
docker==7.0.0
|
|
||||||
docker-pycreds==0.4.0
|
|
||||||
docstring-parser==0.15
|
|
||||||
docutils==0.16
|
|
||||||
ecdsa==0.18.0
|
|
||||||
einops==0.7.0
|
|
||||||
einops-exts==0.0.4
|
|
||||||
einx==0.1.3
|
|
||||||
entrypoints==0.4
|
|
||||||
eth-hash==0.6.0
|
|
||||||
eth-keys==0.5.0
|
|
||||||
eth-typing==4.0.0
|
|
||||||
eth-utils==2.3.1
|
|
||||||
evaluate==0.4.0
|
|
||||||
exceptiongroup==1.1.1
|
|
||||||
fastapi==0.109.2
|
|
||||||
fastcore==1.5.29
|
|
||||||
ffmpy==0.4.0
|
|
||||||
filelock==3.12.2
|
|
||||||
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
|
|
||||||
fire==0.5.0
|
|
||||||
first==2.0.2
|
|
||||||
flake8==7.0.0
|
|
||||||
Flask==3.0.1
|
|
||||||
fonttools==4.47.2
|
|
||||||
frozendict==2.4.1
|
|
||||||
frozenlist==1.3.3
|
|
||||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
|
||||||
fsspec==2023.6.0
|
|
||||||
fuzzywuzzy==0.18.0
|
|
||||||
gitdb==4.0.10
|
|
||||||
GitPython==3.1.31
|
|
||||||
google-pasta==0.2.0
|
|
||||||
gradio==4.42.0
|
|
||||||
gradio_client==1.3.0
|
|
||||||
greenlet==2.0.2
|
|
||||||
grpclib==0.4.7
|
|
||||||
gunicorn==21.2.0
|
|
||||||
h11==0.14.0
|
|
||||||
h2==4.1.0
|
|
||||||
hpack==4.0.0
|
|
||||||
httpcore==0.17.3
|
|
||||||
httpx==0.24.1
|
|
||||||
huggingface-hub==0.23.4
|
|
||||||
humanfriendly==10.0
|
|
||||||
hyperframe==6.0.1
|
|
||||||
identify==2.5.24
|
|
||||||
idna==3.4
|
|
||||||
immutables==0.20
|
|
||||||
importlib-metadata==6.7.0
|
|
||||||
importlib-resources==6.1.1
|
|
||||||
inflection==0.5.1
|
|
||||||
iniconfig==2.0.0
|
|
||||||
itsdangerous==2.1.2
|
|
||||||
Jinja2==3.1.2
|
|
||||||
jmespath==1.0.1
|
|
||||||
joblib==1.3.2
|
|
||||||
jsonlines==3.1.0
|
|
||||||
jsonschema==2.6.0
|
|
||||||
kiwisolver==1.4.5
|
|
||||||
langchain==0.0.144
|
|
||||||
Levenshtein==0.24.0
|
|
||||||
libcst==1.1.0
|
|
||||||
liger-kernel==0.0.0
|
|
||||||
lion-pytorch==0.1.2
|
|
||||||
llama-cpp-python==0.1.36
|
|
||||||
llvmlite==0.40.1
|
|
||||||
local-attention==1.9.0
|
|
||||||
loguru==0.7.0
|
|
||||||
Mako==1.3.2
|
|
||||||
Markdown==3.5.2
|
|
||||||
markdown-it-py==3.0.0
|
|
||||||
markdown2==2.4.10
|
|
||||||
MarkupSafe==2.1.2
|
|
||||||
marshmallow==3.19.0
|
|
||||||
marshmallow-enum==1.5.1
|
|
||||||
matplotlib==3.8.2
|
|
||||||
mccabe==0.7.0
|
|
||||||
mdurl==0.1.2
|
|
||||||
MEGABYTE-pytorch==0.0.7
|
|
||||||
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
|
|
||||||
mlflow==2.10.0
|
|
||||||
modal==0.62.77
|
|
||||||
more-itertools==10.2.0
|
|
||||||
mpmath==1.2.1
|
|
||||||
msgpack==1.0.7
|
|
||||||
msgpack-numpy-opentensor==0.5.0
|
|
||||||
multidict==6.0.4
|
|
||||||
multiprocess==0.70.14
|
|
||||||
munch==2.5.0
|
|
||||||
mypy==1.3.0
|
|
||||||
mypy-extensions==1.0.0
|
|
||||||
nest-asyncio==1.6.0
|
|
||||||
netaddr==0.10.1
|
|
||||||
networkx==3.0rc1
|
|
||||||
nh3==0.2.14
|
|
||||||
nodeenv==1.8.0
|
|
||||||
nomic==2.0.2
|
|
||||||
numba==0.57.1
|
|
||||||
numexpr==2.8.4
|
|
||||||
numpy==1.24.4
|
|
||||||
oauthlib==3.2.2
|
|
||||||
openai==0.27.4
|
|
||||||
openapi==1.1.0
|
|
||||||
openapi-schema-pydantic==1.2.4
|
|
||||||
optimum==1.8.6
|
|
||||||
orjson==3.10.7
|
|
||||||
packaging==23.1
|
|
||||||
pandas==2.0.0
|
|
||||||
parameterized==0.9.0
|
|
||||||
password-strength==0.0.3.post2
|
|
||||||
pastel==0.1.1
|
|
||||||
pathos==0.3.0
|
|
||||||
pathspec==0.11.1
|
|
||||||
pathtools==0.1.2
|
|
||||||
peft==0.11.1
|
|
||||||
pendulum==3.0.0
|
|
||||||
Pillow==9.5.0
|
|
||||||
pip-tools==1.11.0
|
|
||||||
platformdirs==3.2.0
|
|
||||||
pluggy==1.4.0
|
|
||||||
poetry==0.7.1
|
|
||||||
pox==0.3.2
|
|
||||||
ppft==1.7.6.6
|
|
||||||
pre-commit==3.3.2
|
|
||||||
prettytable==3.10.0
|
|
||||||
prompt-toolkit==3.0.39
|
|
||||||
protobuf==3.20.2
|
|
||||||
protobuf3-to-dict==0.1.5
|
|
||||||
psutil==5.9.5
|
|
||||||
psycopg==3.1.18
|
|
||||||
PuLP==2.8.0
|
|
||||||
py==1.11.0
|
|
||||||
py-bip39-bindings==0.1.11
|
|
||||||
py-cpuinfo==9.0.0
|
|
||||||
py-ed25519-zebra-bindings==1.0.1
|
|
||||||
py-sr25519-bindings==0.2.0
|
|
||||||
pyarrow==11.0.0
|
|
||||||
pyasn1==0.6.0
|
|
||||||
pycodestyle==2.11.1
|
|
||||||
pycparser==2.21
|
|
||||||
pycryptodome==3.20.0
|
|
||||||
pydantic==2.5.3
|
|
||||||
pydantic_core==2.14.6
|
|
||||||
pydub==0.25.1
|
|
||||||
pyfiglet==0.8.post1
|
|
||||||
pyflakes==3.2.0
|
|
||||||
Pygments==2.15.1
|
|
||||||
PyJWT==2.8.0
|
|
||||||
pylev==1.4.0
|
|
||||||
PyNaCl==1.5.0
|
|
||||||
pynvml==11.5.0
|
|
||||||
pyparsing==2.4.7
|
|
||||||
pyrsistent==0.14.11
|
|
||||||
pytest==8.0.2
|
|
||||||
pytest-asyncio==0.23.4
|
|
||||||
python-dateutil==2.8.2
|
|
||||||
python-dotenv==1.0.1
|
|
||||||
python-Levenshtein==0.24.0
|
|
||||||
python-multipart==0.0.9
|
|
||||||
pytz==2023.3
|
|
||||||
PyYAML==6.0.1
|
|
||||||
querystring-parser==1.2.4
|
|
||||||
rapidfuzz==3.6.1
|
|
||||||
regex==2023.6.3
|
|
||||||
requests==2.31.0
|
|
||||||
requests-toolbelt==0.8.0
|
|
||||||
resolvelib==0.8.1
|
|
||||||
responses==0.18.0
|
|
||||||
retry==0.9.2
|
|
||||||
rich==13.7.0
|
|
||||||
rsa==4.7.2
|
|
||||||
ruff==0.6.3
|
|
||||||
s3transfer==0.10.1
|
|
||||||
safetensors==0.4.5
|
|
||||||
sagemaker==2.148.0
|
|
||||||
scalecodec==1.2.7
|
|
||||||
schedulefree==1.2.1
|
|
||||||
schema==0.7.5
|
|
||||||
scikit-learn==1.4.0
|
|
||||||
scipy==1.9.3
|
|
||||||
seaborn==0.13.2
|
|
||||||
semantic-version==2.10.0
|
|
||||||
sentencepiece==0.2.0
|
|
||||||
sentry-sdk==1.19.1
|
|
||||||
setproctitle==1.3.2
|
|
||||||
shellingham==1.5.4
|
|
||||||
shortuuid==1.0.11
|
|
||||||
shtab==1.6.5
|
|
||||||
sigtools==4.0.1
|
|
||||||
six==1.16.0
|
|
||||||
skypilot==0.4.1
|
|
||||||
smdebug-rulesconfig==1.0.1
|
|
||||||
smmap==5.0.0
|
|
||||||
sniffio==1.3.0
|
|
||||||
SQLAlchemy==1.4.47
|
|
||||||
sqlparse==0.4.4
|
|
||||||
starlette==0.36.3
|
|
||||||
substrate-interface==1.5.2
|
|
||||||
svgwrite==1.4.3
|
|
||||||
sympy==1.11.1
|
|
||||||
synchronicity==0.6.7
|
|
||||||
tabulate==0.9.0
|
|
||||||
tblib==1.7.0
|
|
||||||
tenacity==8.2.2
|
|
||||||
tensor-parallel==2.0.0
|
|
||||||
termcolor==2.2.0
|
|
||||||
text2art==0.2.0
|
|
||||||
threadpoolctl==3.2.0
|
|
||||||
tiktoken==0.6.0
|
|
||||||
time-machine==2.14.1
|
|
||||||
timm==0.9.16
|
|
||||||
tokenizers==0.19.1
|
|
||||||
tokenmonster==1.1.12
|
|
||||||
toml==0.9.6
|
|
||||||
tomli==2.0.1
|
|
||||||
tomlkit==0.12.0
|
|
||||||
toolz==0.12.1
|
|
||||||
torch==2.2.0
|
|
||||||
torchdata==0.6.1
|
|
||||||
torchdiffeq==0.2.3
|
|
||||||
TorchFix==0.4.0
|
|
||||||
torchtext==0.15.2
|
|
||||||
torchvision==0.17.0
|
|
||||||
tqdm==4.66.2
|
|
||||||
transformers==4.44.2
|
|
||||||
trl==0.9.6
|
|
||||||
typer==0.12.5
|
|
||||||
types-certifi==2021.10.8.3
|
|
||||||
types-requests==2.31.0.20240125
|
|
||||||
types-setuptools==69.0.0.20240125
|
|
||||||
types-toml==0.10.8.7
|
|
||||||
typing==3.7.4.3
|
|
||||||
typing-inspect==0.8.0
|
|
||||||
typing_extensions==4.9.0
|
|
||||||
tyro==0.5.18
|
|
||||||
tzdata==2023.3
|
|
||||||
unique-names-generator==1.0.2
|
|
||||||
urllib3==2.2.2
|
|
||||||
uvicorn==0.22.0
|
|
||||||
vector_quantize_pytorch==1.14.1
|
|
||||||
virtualenv==20.23.0
|
|
||||||
voyager==2.0.2
|
|
||||||
wandb==0.16.2
|
|
||||||
watchfiles==0.21.0
|
|
||||||
wavedrom==2.0.3.post3
|
|
||||||
wcwidth==0.2.6
|
|
||||||
websocket-client==1.7.0
|
|
||||||
websockets==12.0
|
|
||||||
Werkzeug==3.0.1
|
|
||||||
wonderwords==2.2.0
|
|
||||||
xxhash==3.2.0
|
|
||||||
yarl==1.8.2
|
|
||||||
zetascale==2.2.7
|
|
||||||
zipp==3.15.0
|
|
||||||
12
setup.py
12
setup.py
@@ -16,9 +16,7 @@ def parse_requirements():
|
|||||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||||
lines = [r.strip() for r in requirements_file.readlines()]
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
for line in lines:
|
for line in lines:
|
||||||
is_extras = (
|
is_extras = "deepspeed" in line or "mamba-ssm" in line
|
||||||
"deepspeed" in line or "mamba-ssm" in line or "lion-pytorch" in line
|
|
||||||
)
|
|
||||||
if line.startswith("--extra-index-url"):
|
if line.startswith("--extra-index-url"):
|
||||||
# Handle custom index URLs
|
# Handle custom index URLs
|
||||||
_, url = line.split()
|
_, url = line.split()
|
||||||
@@ -135,15 +133,15 @@ setup(
|
|||||||
"mlflow": [
|
"mlflow": [
|
||||||
"mlflow",
|
"mlflow",
|
||||||
],
|
],
|
||||||
"lion-pytorch": [
|
|
||||||
"lion-pytorch==0.1.2",
|
|
||||||
],
|
|
||||||
"galore": [
|
"galore": [
|
||||||
"galore_torch",
|
"galore_torch",
|
||||||
],
|
],
|
||||||
|
"apollo": [
|
||||||
|
"apollo-torch",
|
||||||
|
],
|
||||||
"optimizers": [
|
"optimizers": [
|
||||||
"galore_torch",
|
"galore_torch",
|
||||||
"lion-pytorch==0.1.2",
|
"apollo-torch",
|
||||||
"lomo-optim==0.1.1",
|
"lomo-optim==0.1.1",
|
||||||
"torch-optimi==0.2.1",
|
"torch-optimi==0.2.1",
|
||||||
],
|
],
|
||||||
|
|||||||
Reference in New Issue
Block a user