From e2da821e67072109cfd9d0e83bab524618099d42 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 27 Mar 2025 05:14:07 +0700 Subject: [PATCH] chore: minor optim changes (add apollo, improve docs, remove lion-pytorch) (#2444) * feat: add apollo-torch * chore: update optimizer list * fix: deleted accidental requirements file * fix: remove mention of deprecated lion_pytorch --- docs/config.qmd | 27 +++- requirements_env.txt | 315 ------------------------------------------- setup.py | 12 +- 3 files changed, 30 insertions(+), 324 deletions(-) delete mode 100644 requirements_env.txt diff --git a/docs/config.qmd b/docs/config.qmd index 2a79a0126..8620ab27d 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -506,7 +506,7 @@ lr_div_factor: # Learning rate div factor # Specify optimizer # Valid values are driven by the Transformers OptimizerNames class, see: -# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134 +# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189 # # Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of # torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used @@ -516,25 +516,48 @@ lr_div_factor: # Learning rate div factor # - adamw_torch # - adamw_torch_fused # - adamw_torch_xla +# - adamw_torch_npu_fused # - adamw_apex_fused -# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1) +# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1) # - adafactor # - adamw_anyprecision +# - adamw_torch_4bit +# - ademamix # - sgd # - adagrad # - adamw_bnb_8bit +# - adamw_8bit # alias for adamw_bnb_8bit +# - ademamix_8bit # - lion_8bit # - lion_32bit # - paged_adamw_32bit # - paged_adamw_8bit +# - paged_ademamix_32bit +# - paged_ademamix_8bit # - paged_lion_32bit # - paged_lion_8bit +# - rmsprop +# - rmsprop_bnb +# - rmsprop_bnb_8bit +# - rmsprop_bnb_32bit # - galore_adamw # - galore_adamw_8bit # - galore_adafactor # - galore_adamw_layerwise # - galore_adamw_8bit_layerwise # - galore_adafactor_layerwise +# - lomo +# - adalomo +# - grokadamw +# - schedule_free_adamw +# - schedule_free_sgd +# - apollo_adamw +# - apollo_adamw_layerwise +# +# Additional custom optimizers include: +# - optimi_adamw +# - ao_adamw_8bit +# - ao_adamw_fp8 optimizer: # Dictionary of arguments to pass to the optimizer optim_args: diff --git a/requirements_env.txt b/requirements_env.txt deleted file mode 100644 index f8acbf73c..000000000 --- a/requirements_env.txt +++ /dev/null @@ -1,315 +0,0 @@ -accelerate==0.34.1 -addict==2.4.0 -aiofiles==23.2.1 -aiohttp==3.9.0 -aiosignal==1.3.1 -aiostream==0.5.2 -alembic==1.13.1 -annotated-types==0.6.0 -annoy==1.17.3 -ansible==6.7.0 -ansible-core==2.13.13 -ansible-vault==2.1.0 -anyio==3.7.1 -appdirs==1.4.4 -art==6.0 -asgiref==3.7.2 -async-timeout==4.0.2 -attrdict==2.0.1 -attrs==22.2.0 -awscli==1.32.75 --e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl -backoff==2.2.1 -base58==2.1.1 -beartype==0.17.2 -bitnet==0.2.1 -bitsandbytes==0.42.0 -bittensor==6.7.0 -black==23.7.0 -blinker==1.7.0 -boto3==1.34.75 -botocore==1.34.75 -cachetools==5.3.3 -cachy==0.1.1 -certifi==2023.7.22 -cffi==1.16.0 -cfgv==3.3.1 -chai-guanaco==1.2.4 -charset-normalizer==3.2.0 -cleo==0.6.8 -click==8.1.7 -cloudpickle==2.0.0 -cohere==4.11.2 -colorama==0.4.4 -coloredlogs==15.0.1 -CoLT5-attention==0.10.20 -contextlib2==21.6.0 -contourpy==1.2.0 -cryptography==41.0.3 -cycler==0.12.1 -cytoolz==0.12.3 -databricks-cli==0.18.0 -dataclasses-json==0.5.7 -datasets==2.11.0 -ddt==1.6.0 -decorator==5.1.1 -deepspeed==0.15.0 -# Editable Git install with no remote (dialogpt==0.1) --e /Users/wing/Projects/ml/dialogpt/src -dill==0.3.6 -distlib==0.3.6 -docker==7.0.0 -docker-pycreds==0.4.0 -docstring-parser==0.15 -docutils==0.16 -ecdsa==0.18.0 -einops==0.7.0 -einops-exts==0.0.4 -einx==0.1.3 -entrypoints==0.4 -eth-hash==0.6.0 -eth-keys==0.5.0 -eth-typing==4.0.0 -eth-utils==2.3.1 -evaluate==0.4.0 -exceptiongroup==1.1.1 -fastapi==0.109.2 -fastcore==1.5.29 -ffmpy==0.4.0 -filelock==3.12.2 --e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet -fire==0.5.0 -first==2.0.2 -flake8==7.0.0 -Flask==3.0.1 -fonttools==4.47.2 -frozendict==2.4.1 -frozenlist==1.3.3 -fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe -fsspec==2023.6.0 -fuzzywuzzy==0.18.0 -gitdb==4.0.10 -GitPython==3.1.31 -google-pasta==0.2.0 -gradio==4.42.0 -gradio_client==1.3.0 -greenlet==2.0.2 -grpclib==0.4.7 -gunicorn==21.2.0 -h11==0.14.0 -h2==4.1.0 -hpack==4.0.0 -httpcore==0.17.3 -httpx==0.24.1 -huggingface-hub==0.23.4 -humanfriendly==10.0 -hyperframe==6.0.1 -identify==2.5.24 -idna==3.4 -immutables==0.20 -importlib-metadata==6.7.0 -importlib-resources==6.1.1 -inflection==0.5.1 -iniconfig==2.0.0 -itsdangerous==2.1.2 -Jinja2==3.1.2 -jmespath==1.0.1 -joblib==1.3.2 -jsonlines==3.1.0 -jsonschema==2.6.0 -kiwisolver==1.4.5 -langchain==0.0.144 -Levenshtein==0.24.0 -libcst==1.1.0 -liger-kernel==0.0.0 -lion-pytorch==0.1.2 -llama-cpp-python==0.1.36 -llvmlite==0.40.1 -local-attention==1.9.0 -loguru==0.7.0 -Mako==1.3.2 -Markdown==3.5.2 -markdown-it-py==3.0.0 -markdown2==2.4.10 -MarkupSafe==2.1.2 -marshmallow==3.19.0 -marshmallow-enum==1.5.1 -matplotlib==3.8.2 -mccabe==0.7.0 -mdurl==0.1.2 -MEGABYTE-pytorch==0.0.7 --e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit -mlflow==2.10.0 -modal==0.62.77 -more-itertools==10.2.0 -mpmath==1.2.1 -msgpack==1.0.7 -msgpack-numpy-opentensor==0.5.0 -multidict==6.0.4 -multiprocess==0.70.14 -munch==2.5.0 -mypy==1.3.0 -mypy-extensions==1.0.0 -nest-asyncio==1.6.0 -netaddr==0.10.1 -networkx==3.0rc1 -nh3==0.2.14 -nodeenv==1.8.0 -nomic==2.0.2 -numba==0.57.1 -numexpr==2.8.4 -numpy==1.24.4 -oauthlib==3.2.2 -openai==0.27.4 -openapi==1.1.0 -openapi-schema-pydantic==1.2.4 -optimum==1.8.6 -orjson==3.10.7 -packaging==23.1 -pandas==2.0.0 -parameterized==0.9.0 -password-strength==0.0.3.post2 -pastel==0.1.1 -pathos==0.3.0 -pathspec==0.11.1 -pathtools==0.1.2 -peft==0.11.1 -pendulum==3.0.0 -Pillow==9.5.0 -pip-tools==1.11.0 -platformdirs==3.2.0 -pluggy==1.4.0 -poetry==0.7.1 -pox==0.3.2 -ppft==1.7.6.6 -pre-commit==3.3.2 -prettytable==3.10.0 -prompt-toolkit==3.0.39 -protobuf==3.20.2 -protobuf3-to-dict==0.1.5 -psutil==5.9.5 -psycopg==3.1.18 -PuLP==2.8.0 -py==1.11.0 -py-bip39-bindings==0.1.11 -py-cpuinfo==9.0.0 -py-ed25519-zebra-bindings==1.0.1 -py-sr25519-bindings==0.2.0 -pyarrow==11.0.0 -pyasn1==0.6.0 -pycodestyle==2.11.1 -pycparser==2.21 -pycryptodome==3.20.0 -pydantic==2.5.3 -pydantic_core==2.14.6 -pydub==0.25.1 -pyfiglet==0.8.post1 -pyflakes==3.2.0 -Pygments==2.15.1 -PyJWT==2.8.0 -pylev==1.4.0 -PyNaCl==1.5.0 -pynvml==11.5.0 -pyparsing==2.4.7 -pyrsistent==0.14.11 -pytest==8.0.2 -pytest-asyncio==0.23.4 -python-dateutil==2.8.2 -python-dotenv==1.0.1 -python-Levenshtein==0.24.0 -python-multipart==0.0.9 -pytz==2023.3 -PyYAML==6.0.1 -querystring-parser==1.2.4 -rapidfuzz==3.6.1 -regex==2023.6.3 -requests==2.31.0 -requests-toolbelt==0.8.0 -resolvelib==0.8.1 -responses==0.18.0 -retry==0.9.2 -rich==13.7.0 -rsa==4.7.2 -ruff==0.6.3 -s3transfer==0.10.1 -safetensors==0.4.5 -sagemaker==2.148.0 -scalecodec==1.2.7 -schedulefree==1.2.1 -schema==0.7.5 -scikit-learn==1.4.0 -scipy==1.9.3 -seaborn==0.13.2 -semantic-version==2.10.0 -sentencepiece==0.2.0 -sentry-sdk==1.19.1 -setproctitle==1.3.2 -shellingham==1.5.4 -shortuuid==1.0.11 -shtab==1.6.5 -sigtools==4.0.1 -six==1.16.0 -skypilot==0.4.1 -smdebug-rulesconfig==1.0.1 -smmap==5.0.0 -sniffio==1.3.0 -SQLAlchemy==1.4.47 -sqlparse==0.4.4 -starlette==0.36.3 -substrate-interface==1.5.2 -svgwrite==1.4.3 -sympy==1.11.1 -synchronicity==0.6.7 -tabulate==0.9.0 -tblib==1.7.0 -tenacity==8.2.2 -tensor-parallel==2.0.0 -termcolor==2.2.0 -text2art==0.2.0 -threadpoolctl==3.2.0 -tiktoken==0.6.0 -time-machine==2.14.1 -timm==0.9.16 -tokenizers==0.19.1 -tokenmonster==1.1.12 -toml==0.9.6 -tomli==2.0.1 -tomlkit==0.12.0 -toolz==0.12.1 -torch==2.2.0 -torchdata==0.6.1 -torchdiffeq==0.2.3 -TorchFix==0.4.0 -torchtext==0.15.2 -torchvision==0.17.0 -tqdm==4.66.2 -transformers==4.44.2 -trl==0.9.6 -typer==0.12.5 -types-certifi==2021.10.8.3 -types-requests==2.31.0.20240125 -types-setuptools==69.0.0.20240125 -types-toml==0.10.8.7 -typing==3.7.4.3 -typing-inspect==0.8.0 -typing_extensions==4.9.0 -tyro==0.5.18 -tzdata==2023.3 -unique-names-generator==1.0.2 -urllib3==2.2.2 -uvicorn==0.22.0 -vector_quantize_pytorch==1.14.1 -virtualenv==20.23.0 -voyager==2.0.2 -wandb==0.16.2 -watchfiles==0.21.0 -wavedrom==2.0.3.post3 -wcwidth==0.2.6 -websocket-client==1.7.0 -websockets==12.0 -Werkzeug==3.0.1 -wonderwords==2.2.0 -xxhash==3.2.0 -yarl==1.8.2 -zetascale==2.2.7 -zipp==3.15.0 diff --git a/setup.py b/setup.py index 8b2f1b2a5..4c3024b85 100644 --- a/setup.py +++ b/setup.py @@ -16,9 +16,7 @@ def parse_requirements(): with open("./requirements.txt", encoding="utf-8") as requirements_file: lines = [r.strip() for r in requirements_file.readlines()] for line in lines: - is_extras = ( - "deepspeed" in line or "mamba-ssm" in line or "lion-pytorch" in line - ) + is_extras = "deepspeed" in line or "mamba-ssm" in line if line.startswith("--extra-index-url"): # Handle custom index URLs _, url = line.split() @@ -135,15 +133,15 @@ setup( "mlflow": [ "mlflow", ], - "lion-pytorch": [ - "lion-pytorch==0.1.2", - ], "galore": [ "galore_torch", ], + "apollo": [ + "apollo-torch", + ], "optimizers": [ "galore_torch", - "lion-pytorch==0.1.2", + "apollo-torch", "lomo-optim==0.1.1", "torch-optimi==0.2.1", ],