Compare commits

...

26 Commits

Author SHA1 Message Date
Wing Lian
60763b2e61 fix missing return 2024-11-14 10:14:13 -05:00
Wing Lian
082a41af9d add check for broken fsdp+grad_accum 2024-11-14 10:12:57 -05:00
Wing Lian
2d7830fda6 upgrade to flash-attn 2.7.0 (#2048) 2024-11-14 06:59:25 -05:00
Wing Lian
5e98cdddac Grokfast support (#1917) 2024-11-13 17:10:36 -05:00
Sunny Liu
1d7aee0ad2 ADOPT optimizer integration (#2032) [skip ci]
* adopt integration

* stuff

* doc and test for ADOPT

* rearrangement

* fixed formatting

* hacking pre-commit

* chore: lint

* update module doc for adopt optimizer

* remove un-necessary example yaml for adopt optimizer

* skip test adopt if torch<2.5.1

* formatting

* use version.parse

* specifies required torch version for adopt_adamw

---------

Co-authored-by: sunny <sunnyliu19981005@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-11-13 17:10:17 -05:00
Wing Lian
659ee5d723 don't cancel the tests on main automatically for concurrency (#2055) [skip ci] 2024-11-13 17:07:41 -05:00
Sunny Liu
342935cff3 Update unsloth for torch.cuda.amp deprecation (#2042)
* update deprecated unsloth tirch cuda amp  decorator

* WIP fix torch.cuda.amp deprecation

* lint

* laxing torch version requirement

* remove use of partial

* remove use of partial

* lint

---------

Co-authored-by: sunny <sunnyliu19981005@gmail.com>
2024-11-13 15:17:34 -05:00
Wing Lian
c5eb9ea2c2 fix push to main and tag semver build for docker ci (#2054) 2024-11-13 14:04:28 -05:00
Wing Lian
f2145a3ccb add default torch version if not installed, and support for xformers new wheels (#2049) 2024-11-13 13:16:47 -05:00
Wing Lian
010d0e7ff3 retry flaky test_packing_stream_dataset test that timesout on read (#2052) [skip ci] 2024-11-13 13:16:16 -05:00
Wing Lian
01881c3113 make sure to tag images in docker for tagged releases (#2051) [skip ci]
* make sure to tag images in docker for tagged releases

* fix tag event
2024-11-13 13:15:49 -05:00
Wing Lian
0e8eb96e07 run pypi release action on tag create w version (#2047) 2024-11-13 10:21:48 -05:00
NanoCode012
4e1891b12b feat: upgrade to liger 0.4.1 (#2045) 2024-11-13 10:07:24 -05:00
NanoCode012
28924fc791 feat: cancel ongoing tests if new CI is triggered (#2046) [skip ci] 2024-11-13 10:06:59 -05:00
NanoCode012
8c480b2804 fix: inference not using chat_template (#2019) [skip ci] 2024-11-13 10:06:41 -05:00
Oliver Molenschot
a4b1cc6df0 Add example YAML file for training Mistral using DPO (#2029) [skip ci]
* Add example YAML file for training Mistral using DPO

* chore: lint

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update mistral-dpo.yml 

Adding qlora and removing role-related data (unecessary)

* Rename mistral-dpo.yml to mistral-dpo-qlora.yml

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2024-11-13 10:06:25 -05:00
NanoCode012
7b78a31593 feat: print out dataset length even if not preprocess (#2034) [skip ci] 2024-11-13 10:06:00 -05:00
Wing Lian
810ebc2c0e invert the string in string check for p2p device check (#2044) 2024-11-12 23:20:47 -05:00
Wing Lian
ad435a3b09 add P2P env when multi-gpu but not the full node (#2041)
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-11-12 17:58:26 -05:00
NanoCode012
9f1cf9b17c fix: handle sharegpt dataset missing (#2035)
* fix: handle sharegpt dataset missing

* fix: explanation

* feat: add test
2024-11-12 12:51:37 +07:00
Wing Lian
3931a42763 change deprecated modal Stub to App (#2038) 2024-11-11 15:10:34 -05:00
NanoCode012
dc8f9059f7 feat: add metharme chat_template (#2033) [skip ci]
* feat: add metharme chat_template

* fix: add eos token
2024-11-11 15:09:58 -05:00
Wing Lian
234e94e9dd replace references to personal docker hub to org docker hub (#2036) [skip ci] 2024-11-11 15:09:29 -05:00
Wing Lian
f68fb71005 update actions version for node16 deprecation (#2037) [skip ci]
* update actions version for node16 deprecation

* update pre-commit/action to use 3.0.1 for actions/cache@v4 dep

* update docker/setup-buildx-action too to v3
2024-11-11 15:09:11 -05:00
Wing Lian
9bc3ee6c75 add axolotlai docker hub org to publish list (#2031)
* add axolotlai docker hub org to publish list

* fix to use latest actions docker metadata version

* fix list in yaml for expected format for action

* missed a change
2024-11-11 09:48:19 -05:00
Wing Lian
d356740ffa move deprecated kwargs from trainer to trainingargs (#2028) 2024-11-10 12:45:47 -05:00
43 changed files with 1131 additions and 101 deletions

View File

@@ -44,19 +44,21 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v3
uses: docker/metadata-action@v5
with:
images: winglian/axolotl-base
images: |
winglian/axolotl-base
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v4
with:

View File

@@ -17,7 +17,7 @@ jobs:
- name: Set up Quarto
uses: quarto-dev/quarto-actions/setup@v2
- name: Setup Python
uses: actions/setup-python@v3
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: install dependencies

View File

@@ -15,9 +15,9 @@ jobs:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
- uses: pre-commit/action@v3.0.1

View File

@@ -4,6 +4,8 @@ on:
push:
branches:
- "main"
tags:
- "v*"
workflow_dispatch:
jobs:
@@ -42,7 +44,12 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=ref,event=branch
type=semver,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
@@ -56,7 +63,7 @@ jobs:
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
@@ -104,20 +111,22 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl-cloud
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud
push: ${{ github.event_name != 'pull_request' }}
@@ -146,20 +155,25 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl-cloud-term
images: |
winglian/axolotl-cloud-term
axolotlai/axolotl-cloud-term
tags: |
type=ref,event=branch
type=semver,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-no-tmux
push: ${{ github.event_name != 'pull_request' }}

View File

@@ -41,7 +41,9 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
- name: Set up Docker Buildx
@@ -103,7 +105,9 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl-cloud
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
- name: Login to Docker Hub
@@ -112,7 +116,7 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:

View File

@@ -3,13 +3,31 @@ name: publish pypi
on:
push:
tags:
- '*'
- 'v*'
workflow_dispatch:
jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [setup_release]
environment:
name: pypi
url: https://pypi.org/p/axolotl
@@ -17,10 +35,10 @@ jobs:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Check out repository code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.10"

View File

@@ -9,12 +9,12 @@ jobs:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -30,10 +30,10 @@ jobs:
steps:
- name: Check out repository code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies

View File

@@ -15,17 +15,22 @@ on:
- '.github/workflows/*.yml'
workflow_dispatch:
# Cancel jobs on the same ref if a new one is triggered
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
pre-commit:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -41,10 +46,10 @@ jobs:
steps:
- name: Check out repository code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies

View File

@@ -159,7 +159,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
#### Docker
```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
Or run on the current files for development:
@@ -178,7 +178,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
A more powerful Docker command to run would be this:
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
```
It additionally:
@@ -210,7 +210,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
#### Cloud GPU
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
@@ -319,7 +319,7 @@ Write a job description in YAML as below:
# dstack.yaml
type: task
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
image: axolotlai/axolotl-cloud:main-latest
env:
- HUGGING_FACE_HUB_TOKEN

View File

@@ -1,4 +1,4 @@
FROM winglian/axolotl-base:{{ BASE_TAG }}
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"

View File

@@ -10,7 +10,7 @@ import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import Image, Stub
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -46,7 +46,7 @@ cicd_image = (
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
stub = Stub("Axolotl CI/CD", secrets=[])
app = App("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 2))
@@ -61,7 +61,7 @@ def run_cmd(cmd: str, run_folder: str):
exit(exit_code) # pylint: disable=consider-using-sys-exit
@stub.function(
@app.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
@@ -72,6 +72,6 @@ def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
@stub.local_entrypoint()
@app.local_entrypoint()
def main():
cicd_pytest.remote()

View File

@@ -10,7 +10,7 @@ import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import Image, Stub
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -47,7 +47,7 @@ cicd_image = (
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
stub = Stub("Axolotl CI/CD", secrets=[])
app = App("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 1))
@@ -62,7 +62,7 @@ def run_cmd(cmd: str, run_folder: str):
exit(exit_code) # pylint: disable=consider-using-sys-exit
@stub.function(
@app.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
@@ -73,6 +73,6 @@ def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
@stub.local_entrypoint()
@app.local_entrypoint()
def main():
cicd_pytest.remote()

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main-base
FROM winglian/axolotl-base:$BASE_TAG
FROM axolotlai/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""

View File

@@ -35,7 +35,3 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTHON_VERSION" != "2.5.1" ] ; then \
pip3 install flash-attn==2.6.3; \
fi

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main
FROM winglian/axolotl:$BASE_TAG
FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main
FROM winglian/axolotl:$BASE_TAG
FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main-base
FROM winglian/axolotl-base:$BASE_TAG
FROM axolotlai/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""

View File

@@ -405,6 +405,7 @@ lr_div_factor: # Learning rate div factor
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_apex_fused
# - adopt_adamw (only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - sgd

View File

@@ -185,7 +185,7 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
## Debugging With Docker
Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
Using [official Axolotl Docker images](https://hub.docker.com/r/axolotlai/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
### Setup
@@ -202,11 +202,11 @@ cd axolotl
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
```
>[!Tip]
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/winglian/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
You will now be in the container. Next, perform an editable install of Axolotl:

View File

@@ -44,7 +44,7 @@
"outputs": [],
"source": [
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install flash-attn==\"2.7.0.post2\"\n",
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
]
},

View File

@@ -0,0 +1,93 @@
#Note that we are switching from the regular chat template to chatml.
#If you experience problems with the special tokens, training for more epochs can help.
#After training, merge the model before inference otherwise you might
#face problems with the special tokens.
base_model: mistralai/Mistral-7B-Instruct-v0.2
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
chat_template: chatml
rl: dpo
datasets:
- path: olivermolenschot/alpaca_messages_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-qlora
sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true
adapter: qlora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.2
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 16
num_epochs: 6
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<|im_start|>"
eos_token: "<|im_end|>"

View File

@@ -1,2 +1,3 @@
pytest
pytest-xdist
pytest-retry

View File

@@ -12,7 +12,7 @@ addict
fire
PyYAML>=6.0
requests
flash-attn==2.6.3
flash-attn==2.7.0.post2
sentencepiece
wandb
einops
@@ -33,7 +33,7 @@ tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
triton>=2.3.0
liger-kernel==0.4.0
liger-kernel==0.4.1
mamba-ssm==1.2.0.post1

View File

@@ -39,7 +39,10 @@ def parse_requirements():
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
torch_version = version("torch")
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -54,6 +57,10 @@ def parse_requirements():
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
@@ -98,7 +105,7 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.6.3",
"flash-attn==2.7.0.post2",
],
"deepspeed": [
"deepspeed==0.14.4",

View File

@@ -190,18 +190,15 @@ def do_inference(
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -211,13 +208,31 @@ def do_inference(
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
@@ -257,13 +272,6 @@ def do_inference_gradio(
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
default_tokens: Dict[str, str] = {}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None

View File

@@ -436,7 +436,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()
@@ -505,6 +511,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -1273,6 +1287,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks
def _get_trainer_cls(self):
@@ -1625,11 +1651,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len
# pylint: disable=duplicate-code
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
@@ -1933,6 +1961,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
@@ -1956,7 +1990,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset:
@@ -1970,12 +2003,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]

View File

@@ -140,7 +140,7 @@ class BasePlugin:
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer before training.
setup callbacks before creating the trainer.
Parameters:
cfg (dict): The configuration for the plugin.
@@ -155,14 +155,15 @@ class BasePlugin:
self, cfg, trainer
): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer after training.
Adds callbacks to the trainer after creating the trainer.
This is useful for callbacks that require access to the model or trainer.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
List[callable]: A list of callback functions to be added
"""
return []
@@ -393,7 +394,9 @@ class PluginManager:
"""
callbacks = []
for plugin in self.plugins.values():
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
if plugin_callbacks: # if the plugin returned a list of callbacks
callbacks.extend(plugin_callbacks)
return callbacks
def add_callbacks_post_trainer(self, cfg, trainer):
@@ -409,7 +412,9 @@ class PluginManager:
"""
callbacks = []
for plugin in self.plugins.values():
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
if plugin_callbacks:
callbacks.extend(plugin_callbacks)
return callbacks
def post_train_unload(self, cfg):

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,13 @@
# Grokfast Optimizer
See https://github.com/ironjr/grokfast
### Usage
```yaml
plugins:
- axolotl.integrations.grokfast.GrokfastPlugin
grokfast_alpha: 2.0
grokfast_lamb: 0.98
```

View File

@@ -0,0 +1,50 @@
"""
Grokfast plugin for Axolotl
"""
import logging
from transformers.trainer_callback import TrainerCallback
from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
from .optimizer import gradfilter_ema
LOG = logging.getLogger("axolotl.integrations.grokfast")
class GrokfastCallbackHandler(TrainerCallback):
"""
Transformer trainer callbacks for Grokfast
"""
def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
super().__init__(*args_, **kwargs)
self.grads = None
self.alpha = alpha
self.lamb = lamb
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
self.grads = None
def on_pre_optimizer_step(
self, args_, state, control, **kwargs
): # pylint: disable=unused-argument
model = kwargs.pop("model")
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
return control
class GrokfastPlugin(BasePlugin):
"""
Plugin for Grokfast optimizer integraton with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.grokfast.GrokfastArgs"
def add_callbacks_post_trainer(self, cfg, trainer):
LOG.info("Adding Grokfast callback to the trainer")
callback = GrokfastCallbackHandler(
alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
)
return [callback]

View File

@@ -0,0 +1,15 @@
"""
config args for grokfast plugin
"""
from typing import Optional
from pydantic import BaseModel
class GrokfastArgs(BaseModel):
"""
Input args for Grokfast optimizer.
"""
grokfast_alpha: Optional[float] = 0.98
grokfast_lamb: Optional[float] = 2.0

View File

@@ -0,0 +1,63 @@
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
# Reference: https://github.com/ironjr/grokfast
# pylint: skip-file
from collections import deque
from typing import Dict, Literal, Optional
import torch
import torch.nn as nn
def gradfilter_ma(
m: nn.Module,
grads: Optional[Dict[str, deque]] = None,
window_size: int = 100,
lamb: float = 5.0,
filter_type: Literal["mean", "sum"] = "mean",
warmup: bool = True,
trigger: bool = False, # For ablation study.
) -> Dict[str, deque]:
if grads is None:
grads = {
n: deque(maxlen=window_size)
for n, p in m.named_parameters()
if p.requires_grad and p.grad is not None
}
for n, p in m.named_parameters():
if p.requires_grad and p.grad is not None:
grads[n].append(p.grad.data.detach()) # .cpu())
# Modify the gradients.
if not warmup or len(grads[n]) == window_size and not trigger:
if filter_type == "mean":
avg = sum(grads[n]) / len(grads[n])
elif filter_type == "sum":
avg = sum(grads[n])
else:
raise ValueError(f"Unrecognized filter_type {filter_type}")
p.grad.data = p.grad.data + avg * lamb
return grads
def gradfilter_ema(
m: nn.Module,
grads: Optional[Dict[str, torch.Tensor]] = None,
alpha: float = 0.98,
lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
if grads is None:
grads = {
n: p.grad.data.detach()
for n, p in m.named_parameters()
if p.requires_grad and p.grad is not None
}
for n, p in m.named_parameters():
if p.requires_grad and p.grad is not None:
grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
p.grad.data = p.grad.data + grads[n] * lamb
return grads

View File

@@ -23,6 +23,7 @@ import logging
import sys
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
@@ -82,7 +83,9 @@ class LigerPlugin(BasePlugin):
if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "deepseek_v2":
@@ -106,6 +109,8 @@ class LigerPlugin(BasePlugin):
if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward

File diff suppressed because one or more lines are too long

View File

@@ -58,6 +58,7 @@ class ChatTemplate(str, Enum):
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
metharme = "metharme" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
@@ -427,6 +428,7 @@ class HyperparametersConfig(BaseModel):
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
],
]
] = OptimizerNames.ADAMW_HF.value
@@ -781,6 +783,8 @@ class AxolotlInputConfig(
is_mistral_derived_model: Optional[bool] = Field(default=None)
is_qwen_derived_model: Optional[bool] = Field(default=None)
plugins: Optional[List[str]] = Field(default=None)
@field_validator("datasets", mode="before")
@classmethod
def deprecate_sharegpt_datasets(cls, datasets):
@@ -788,7 +792,12 @@ class AxolotlInputConfig(
if not ds_cfg.get("type"):
continue
if ds_cfg["type"].startswith("sharegpt"):
ds_type = ds_cfg["type"]
# skip if it's a dict (for custom user instruction prompt)
if isinstance(ds_type, dict):
continue
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
raise ValueError(
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
)
@@ -1393,6 +1402,17 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_grad_accum_4_46_2(cls, data):
if data.get("fsdp") and data.get("gradient_accumulation_steps") > 1:
if version("transformers") == "4.46.2":
raise ValueError(
"FSDP w/ gradient_accumulation_steps > 1 is broken with transformers==4.46.2. "
"Please use a lower value or switch to an older version of transformers."
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""

View File

@@ -0,0 +1,25 @@
"""
utils to get GPU info for the current environment
"""
from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
)
from accelerate.utils.environment import get_gpu_info
def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
unsupported_devices = {"RTX 6000 Ada"}
try:
device_names, device_count = get_gpu_info()
if 1 < device_count < 8:
if any(
unsupported_device in device_name
for device_name in device_names
for unsupported_device in unsupported_devices
):
return False
except Exception: # pylint: disable=broad-except # nosec
pass
return True

View File

@@ -14,6 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from packaging import version
torch_version = version.parse(torch.__version__)
if torch_version < version.parse("2.4.0"):
torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
"""
@staticmethod
@torch.cuda.amp.custom_fwd
@torch_cuda_amp_custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
@@ -36,7 +46,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
return output
@staticmethod
@torch.cuda.amp.custom_bwd
@torch_cuda_amp_custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()

View File

@@ -0,0 +1,508 @@
"""
Copied from https://github.com/iShohei220/adopt
ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
"""
# mypy: ignore-errors
# pylint: skip-file
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union, cast
import torch
from torch import Tensor
from torch.optim.optimizer import (
Optimizer,
ParamsT,
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_disable_dynamo_if_unsupported,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_use_grad_for_differentiable,
_view_as_real,
)
__all__ = ["ADOPT", "adopt"]
class ADOPT(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6,
weight_decay: float = 0.0,
decoupled: bool = False,
*,
foreach: Optional[bool] = None,
maximize: bool = False,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
):
if isinstance(lr, Tensor):
if foreach and not capturable:
raise ValueError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
if lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
decoupled=decoupled,
maximize=maximize,
foreach=foreach,
capturable=capturable,
differentiable=differentiable,
fused=fused,
)
super().__init__(params, defaults)
if fused:
# TODO: support fused
raise RuntimeError("`fused` is not currently supported")
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
self._step_supports_amp_scaling = True
# TODO(crcrpar): [low prec params & their higher prec copy]
# Support AMP with FP16/BF16 model params which would need
# higher prec copy of params to do update math in higher prec to
# alleviate the loss of information.
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("maximize", False)
group.setdefault("foreach", None)
group.setdefault("capturable", False)
group.setdefault("differentiable", False)
fused = group.setdefault("fused", None)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = (
torch.tensor(
step_val,
dtype=_get_scalar_dtype(is_fused=fused),
device=p.device,
)
if group["capturable"] or group["fused"]
else torch.tensor(step_val, dtype=_get_scalar_dtype())
)
def _init_group(
self,
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
):
has_complex = False
for p in group["params"]:
if p.grad is not None:
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("ADOPT does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
if group["fused"]:
_device_dtype_check_for_fused(p)
# note(crcrpar): [special device hosting for step]
# Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = (
torch.zeros(
(),
dtype=_get_scalar_dtype(is_fused=group["fused"]),
device=p.device,
)
if group["capturable"] or group["fused"]
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if group["differentiable"] and state["step"].requires_grad:
raise RuntimeError(
"`requires_grad` is not supported for `step` in differentiable mode"
)
# Foreach without capturable does not support a tensor lr
if (
group["foreach"]
and torch.is_tensor(group["lr"])
and not group["capturable"]
):
raise RuntimeError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
beta1, beta2 = group["betas"]
has_complex = self._init_group(
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
)
adopt(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
has_complex=has_complex,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
decoupled=group["decoupled"],
eps=group["eps"],
maximize=group["maximize"],
foreach=group["foreach"],
capturable=group["capturable"],
differentiable=group["differentiable"],
fused=group["fused"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
)
return loss
def _single_tensor_adopt(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
has_complex: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
decoupled: bool,
eps: float,
maximize: bool,
capturable: bool,
differentiable: bool,
):
assert grad_scale is None and found_inf is None
if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
# have overloads to handle both float and Tensor lrs, so we just assert it's
# a float since most people using JIT are using floats
assert isinstance(lr, float)
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
# update step
step_t += 1
if weight_decay != 0:
if decoupled:
param.add_(param, alpha=-lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param):
grad = torch.view_as_real(grad)
if exp_avg is not None:
exp_avg = torch.view_as_real(exp_avg)
if exp_avg_sq is not None:
exp_avg_sq = torch.view_as_real(exp_avg_sq)
param = torch.view_as_real(param)
step = step_t if capturable or differentiable else _get_value(step_t)
if step == 1:
exp_avg_sq.addcmul_(grad, grad.conj())
continue
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
if step == 2:
exp_avg.addcdiv_(grad, denom)
else:
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
param.add_(exp_avg, alpha=-lr)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
def _multi_tensor_adopt(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
has_complex: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
decoupled: bool,
eps: float,
maximize: bool,
capturable: bool,
differentiable: bool,
):
if len(params) == 0:
return
if isinstance(lr, Tensor) and not capturable:
raise RuntimeError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
assert all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
assert grad_scale is None and found_inf is None
assert not differentiable, "_foreach ops don't support autograd"
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item]
)
for (
device_params_,
device_grads_,
device_exp_avgs_,
device_exp_avg_sqs_,
device_state_steps_,
), _ in grouped_tensors.values():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
device_state_steps = cast(List[Tensor], device_state_steps_)
# Handle complex parameters
if has_complex:
_view_as_real(
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
)
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(device_state_steps, 1)
if weight_decay != 0:
if decoupled:
torch._foreach_add_(
device_params, device_params, alpha=-lr * weight_decay
)
else:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if device_state_steps[0] == 1:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
continue
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
if device_state_steps[0] == 2:
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
else:
torch._foreach_mul_(device_exp_avgs, beta1)
torch._foreach_addcdiv_(
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
def adopt(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
has_complex: bool = False,
*,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
decoupled: bool,
eps: float,
maximize: bool,
):
r"""Functional API that performs ADOPT algorithm computation."""
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if fused is None and foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
if foreach and isinstance(lr, Tensor) and not capturable:
foreach = False
if fused is None:
fused = False
if foreach is None:
foreach = False
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if fused and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with fused optimizers")
# if fused and not torch.jit.is_scripting():
# func = _fused_adopt
# elif foreach and not torch.jit.is_scripting():
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_adopt
else:
func = _single_tensor_adopt
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
has_complex=has_complex,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
decoupled=decoupled,
eps=eps,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
grad_scale=grad_scale,
found_inf=found_inf,
)

View File

@@ -17,6 +17,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger("axolotl")
@@ -184,11 +185,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
min_sequence_len=cfg.min_sample_len or 2,
)
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
@@ -461,6 +461,9 @@ def setup_fsdp_envs(cfg):
def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None:
os.environ["NCCL_P2P_DISABLE"] = "1"
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import require_torch_2_5_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -65,3 +65,46 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
@require_torch_2_5_1
def test_adopt_adamw(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adopt_adamw",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -6,11 +6,13 @@ import shutil
import tempfile
import unittest
from functools import wraps
from importlib.metadata import version
from pathlib import Path
import torch
# from importlib.metadata import version
from packaging import version
def with_temp_dir(test_func):
@wraps(test_func)
@@ -43,12 +45,24 @@ def require_torch_2_3_1(test_case):
"""
def is_min_2_3_1():
torch_version = version("torch")
return torch_version >= "2.3.1"
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.3.1")
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
def require_torch_2_5_1(test_case):
"""
Decorator marking a test that requires torch >= 2.3.1
"""
def is_min_2_5_1():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.5.1")
return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
def is_hopper():
compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0)

View File

@@ -2,6 +2,7 @@
import functools
import unittest
import pytest
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
@@ -21,6 +22,7 @@ class TestPretrainingPacking(unittest.TestCase):
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.pad_token = "</s>"
@pytest.mark.flaky(retries=3, delay=5)
def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
dataset = load_dataset(

View File

@@ -234,3 +234,59 @@ class TestValidationCheckDatasetConfig(BaseValidation):
)
_check_config()
def test_dataset_sharegpt_deprecation(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "sharegpt",
"conversation": "chatml",
}
],
}
)
# Check sharegpt deprecation is raised
with pytest.raises(ValueError, match=r".*type: sharegpt.*` is deprecated.*"):
validate_config(cfg)
# Check that deprecation is not thrown for non-str type
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": {
"field_instruction": "instruction",
"field_output": "output",
"field_system": "system",
"format": "<|user|> {instruction} {input} <|model|>",
"no_input_format": "<|user|> {instruction} <|model|>",
"system_prompt": "",
},
}
],
}
)
validate_config(cfg)
# Check that deprecation is not thrown for non-sharegpt type
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
}
)
validate_config(cfg)