Compare commits
42 Commits
quartodoc
...
lora-kerne
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
700409be6f | ||
|
|
64d8035f50 | ||
|
|
5249e98058 | ||
|
|
3877c5c69d | ||
|
|
adb593abac | ||
|
|
a0117c9bce | ||
|
|
e6cfb093d2 | ||
|
|
7abc71dc0b | ||
|
|
45bf634d17 | ||
|
|
80ba4b69f1 | ||
|
|
0bfa180f7d | ||
|
|
9e22c4ca6a | ||
|
|
990b5896bc | ||
|
|
7d0eb66b54 | ||
|
|
df119e3724 | ||
|
|
f4ae8816bb | ||
|
|
9b95e06cbb | ||
|
|
e0aba74dd0 | ||
|
|
328d598114 | ||
|
|
4d36ecc724 | ||
|
|
7acf93b59f | ||
|
|
b6fc46ada8 | ||
|
|
b35992262e | ||
|
|
ef6eb77cc8 | ||
|
|
5410195e0b | ||
|
|
cf0c79d52e | ||
|
|
4ba80a0e5a | ||
|
|
c49682132b | ||
|
|
e46239f8d3 | ||
|
|
05f03b541a | ||
|
|
a4e430e7c4 | ||
|
|
6cdcb8ddd5 | ||
|
|
a7811ad4a0 | ||
|
|
e2da821e67 | ||
|
|
2c34a4634e | ||
|
|
a9b0733f2c | ||
|
|
9f00465a5c | ||
|
|
86bac48d14 | ||
|
|
e44953d50c | ||
|
|
23f0c51d88 | ||
|
|
113e9cd193 | ||
|
|
61825a464a |
14
.github/workflows/base.yml
vendored
14
.github/workflows/base.yml
vendored
@@ -40,12 +40,24 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: nightly
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: next
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -67,7 +79,7 @@ jobs:
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
|
||||
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
2
.github/workflows/docs.yml
vendored
2
.github/workflows/docs.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python3 -m pip install jupyter quartodoc
|
||||
python3 -m pip install -e .
|
||||
python3 -m pip install -e . --no-deps
|
||||
- name: Build autodoc
|
||||
run: quartodoc build
|
||||
- name: Publish to GitHub Pages (and render)
|
||||
|
||||
4
.github/workflows/main.yml
vendored
4
.github/workflows/main.yml
vendored
@@ -25,12 +25,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -87,12 +87,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
3
.github/workflows/multi-gpu-e2e.yml
vendored
3
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -42,8 +42,7 @@ jobs:
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
# awaiting vllm#12721
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
runs-on: [self-hosted, modal]
|
||||
|
||||
25
.github/workflows/tests-nightly.yml
vendored
25
.github/workflows/tests-nightly.yml
vendored
@@ -33,6 +33,15 @@ jobs:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -46,7 +55,7 @@ jobs:
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Update requirements.txt
|
||||
run: |
|
||||
@@ -58,8 +67,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2
|
||||
pip3 show torch
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
@@ -73,10 +81,15 @@ jobs:
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||
pytest tests/patched/
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -v tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
@@ -136,4 +149,4 @@ jobs:
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.tests
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
23
.github/workflows/tests.yml
vendored
23
.github/workflows/tests.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -96,10 +96,15 @@ jobs:
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -v tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
@@ -136,7 +141,7 @@ jobs:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -170,10 +175,14 @@ jobs:
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Show HF cache
|
||||
run: huggingface-cli scan-cache
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -v tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
@@ -227,7 +236,7 @@ jobs:
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.tests
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
docker-e2e-tests:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
@@ -251,7 +260,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -274,4 +283,4 @@ jobs:
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.tests
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
[settings]
|
||||
profile=black
|
||||
known_third_party=wandb,comet_ml
|
||||
known_local_folder=src,tests
|
||||
|
||||
@@ -40,6 +40,7 @@ quartodoc:
|
||||
- cli.preprocess
|
||||
- cli.sweeps
|
||||
- cli.utils
|
||||
- cli.vllm_serve
|
||||
- cli.cloud.base
|
||||
- cli.cloud.modal_
|
||||
- title: Trainers
|
||||
@@ -133,6 +134,7 @@ quartodoc:
|
||||
- utils.schemas.datasets
|
||||
- utils.schemas.peft
|
||||
- utils.schemas.trl
|
||||
- utils.schemas.multimodal
|
||||
- utils.schemas.integrations
|
||||
- utils.schemas.enums
|
||||
- utils.schemas.utils
|
||||
@@ -242,6 +244,7 @@ website:
|
||||
- docs/unsloth.qmd
|
||||
- docs/torchao.qmd
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
|
||||
@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
|
||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
|
||||
@@ -3,9 +3,10 @@ set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/cli
|
||||
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -2,4 +2,5 @@
|
||||
set -e
|
||||
|
||||
# only run one test at a time so as not to OOM the GPU
|
||||
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
|
||||
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
|
||||
pytest -v -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
|
||||
|
||||
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
|
||||
38
docker/Dockerfile-base-next
Normal file
38
docker/Dockerfile-base-next
Normal file
@@ -0,0 +1,38 @@
|
||||
ARG CUDA_VERSION="12.8.1"
|
||||
ARG CUDNN_VERSION="8"
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="next"
|
||||
ARG CUDA="128"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
pip3 install -U --no-cache-dir pydantic==2.10.6
|
||||
40
docs/cli.qmd
40
docs/cli.qmd
@@ -170,7 +170,7 @@ axolotl merge-sharded-fsdp-weights config.yml
|
||||
|
||||
### evaluate
|
||||
|
||||
Evaluates a model's performance using metrics specified in the config.
|
||||
Evaluates a model's performance (loss etc) on the train and eval datasets.
|
||||
|
||||
```bash
|
||||
# Basic evaluation
|
||||
@@ -197,6 +197,8 @@ lm_eval_batch_size: # Batch size for evaluation
|
||||
output_dir: # Directory to save evaluation results
|
||||
```
|
||||
|
||||
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
|
||||
|
||||
## Legacy CLI Usage
|
||||
|
||||
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
|
||||
@@ -235,7 +237,7 @@ Create a cloud config YAML with your Modal settings:
|
||||
```yaml
|
||||
# cloud_config.yml
|
||||
provider: modal
|
||||
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
|
||||
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
|
||||
gpu_count: 1 # Number of GPUs to use
|
||||
timeout: 86400 # Maximum runtime in seconds (24 hours)
|
||||
branch: main # Git branch to use (optional)
|
||||
@@ -248,7 +250,7 @@ volumes: # Persistent storage volumes
|
||||
- name: axolotl-artifacts
|
||||
mount: /workspace/artifacts
|
||||
|
||||
env: # Environment variables
|
||||
secrets: # Secrets to inject
|
||||
- WANDB_API_KEY
|
||||
- HF_TOKEN
|
||||
```
|
||||
@@ -274,15 +276,27 @@ axolotl lm-eval config.yml --cloud cloud_config.yml
|
||||
### Cloud Configuration Options
|
||||
|
||||
```yaml
|
||||
provider: # compute provider, currently only `modal` is supported
|
||||
gpu: # GPU type to use
|
||||
gpu_count: # Number of GPUs (default: 1)
|
||||
memory: # RAM in GB (default: 128)
|
||||
timeout: # Maximum runtime in seconds
|
||||
provider: # compute provider, currently only `modal` is supported
|
||||
gpu: # GPU type to use
|
||||
gpu_count: # Number of GPUs (default: 1)
|
||||
memory: # RAM in GB (default: 128)
|
||||
timeout: # Maximum runtime in seconds
|
||||
timeout_preprocess: # Preprocessing timeout
|
||||
branch: # Git branch to use
|
||||
docker_tag: # Custom Docker image tag
|
||||
volumes: # List of persistent storage volumes
|
||||
env: # Environment variables to pass
|
||||
secrets: # Secrets to inject
|
||||
branch: # Git branch to use
|
||||
docker_tag: # Custom Docker image tag
|
||||
volumes: # List of persistent storage volumes
|
||||
|
||||
# Environment variables to pass. Can be specified in two ways:
|
||||
# 1. As a string: Will load the value from the host computer's environment variables
|
||||
# 2. As a key-value pair: Will use the specified value directly
|
||||
# Example:
|
||||
# env:
|
||||
# - CUSTOM_VAR # Loads from host's $CUSTOM_VAR
|
||||
# - {CUSTOM_VAR: "value"} # Uses "value" directly
|
||||
env:
|
||||
|
||||
# Secrets to inject. Same input format as `env` but for sensitive data.
|
||||
secrets:
|
||||
# - HF_TOKEN
|
||||
# - WANDB_API_KEY
|
||||
```
|
||||
|
||||
115
docs/config.qmd
115
docs/config.qmd
@@ -32,6 +32,9 @@ tokenizer_legacy:
|
||||
resize_token_embeddings_to_32x:
|
||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||
shrink_embeddings:
|
||||
# Whether to load the model with randomly initialized weights. Useful for
|
||||
# pre-training a model from scratch or debugging purposes.
|
||||
random_init_weights:
|
||||
|
||||
# (Internal use only)
|
||||
# Used to identify which the model is based on
|
||||
@@ -235,10 +238,10 @@ simpo_gamma: 0.5 # Target reward margin for the SimPO loss
|
||||
# grpo
|
||||
trl:
|
||||
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
|
||||
vllm_device: # Optional[str]. Device to use for VLLM.
|
||||
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
|
||||
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
|
||||
vllm_dtype: # Optional[str]. Data type for VLLM.
|
||||
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
|
||||
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
|
||||
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
|
||||
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
|
||||
|
||||
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
|
||||
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
|
||||
@@ -317,9 +320,13 @@ total_num_tokens:
|
||||
sample_packing_group_size: 100000
|
||||
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
|
||||
sample_packing_bin_size: 200
|
||||
sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially.
|
||||
|
||||
# whether to concatenate samples during pretraining
|
||||
pretraining_sample_concatenation:
|
||||
|
||||
curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning
|
||||
|
||||
# Use batch flattening for speedups when not using sample_packing
|
||||
batch_flattening:
|
||||
|
||||
@@ -351,7 +358,27 @@ lora_target_modules:
|
||||
# - down_proj
|
||||
# - up_proj
|
||||
lora_target_linear: # If true, will target all linear modules
|
||||
peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers
|
||||
|
||||
# List[int] | int. # The layer indices to transform, otherwise, apply to all layers
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform
|
||||
peft_layers_to_transform:
|
||||
|
||||
# Optional[bool]. Whether to use DoRA.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora
|
||||
peft_use_dora:
|
||||
|
||||
# Optional[bool]. Whether to use RSLoRA.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora
|
||||
peft_use_rslora:
|
||||
|
||||
# Optional[list[tuple[int, int]]]. List of layer indices to replicate.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora
|
||||
peft_layer_replication:
|
||||
|
||||
# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
|
||||
# How to initialize LoRA weights. Default to True which is MS original implementation.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization
|
||||
peft_init_lora_weights:
|
||||
|
||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||
@@ -463,6 +490,7 @@ auto_find_batch_size: # Optional[bool]
|
||||
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
|
||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
||||
|
||||
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
|
||||
@@ -482,7 +510,8 @@ train_on_inputs: false
|
||||
# Note that training loss may have an oscillating pattern with this enabled.
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
|
||||
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||
# gradient_checkpointing_kwargs:
|
||||
@@ -503,36 +532,58 @@ lr_div_factor: # Learning rate div factor
|
||||
|
||||
# Specify optimizer
|
||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||
# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
|
||||
#
|
||||
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
||||
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
||||
# in the examples/ for your model and fine-tuning use case.
|
||||
#
|
||||
# Valid values for 'optimizer' include:
|
||||
# - adamw_hf
|
||||
# - adamw_torch
|
||||
# - adamw_torch_fused
|
||||
# - adamw_torch_xla
|
||||
# - adamw_torch_npu_fused
|
||||
# - adamw_apex_fused
|
||||
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||
# - adafactor
|
||||
# - adamw_anyprecision
|
||||
# - adamw_torch_4bit
|
||||
# - ademamix
|
||||
# - sgd
|
||||
# - adagrad
|
||||
# - adamw_bnb_8bit
|
||||
# - adamw_8bit # alias for adamw_bnb_8bit
|
||||
# - ademamix_8bit
|
||||
# - lion_8bit
|
||||
# - lion_32bit
|
||||
# - paged_adamw_32bit
|
||||
# - paged_adamw_8bit
|
||||
# - paged_ademamix_32bit
|
||||
# - paged_ademamix_8bit
|
||||
# - paged_lion_32bit
|
||||
# - paged_lion_8bit
|
||||
# - rmsprop
|
||||
# - rmsprop_bnb
|
||||
# - rmsprop_bnb_8bit
|
||||
# - rmsprop_bnb_32bit
|
||||
# - galore_adamw
|
||||
# - galore_adamw_8bit
|
||||
# - galore_adafactor
|
||||
# - galore_adamw_layerwise
|
||||
# - galore_adamw_8bit_layerwise
|
||||
# - galore_adafactor_layerwise
|
||||
# - lomo
|
||||
# - adalomo
|
||||
# - grokadamw
|
||||
# - schedule_free_adamw
|
||||
# - schedule_free_sgd
|
||||
# - apollo_adamw
|
||||
# - apollo_adamw_layerwise
|
||||
#
|
||||
# Additional custom optimizers include:
|
||||
# - optimi_adamw
|
||||
# - ao_adamw_8bit
|
||||
# - ao_adamw_fp8
|
||||
optimizer:
|
||||
# Dictionary of arguments to pass to the optimizer
|
||||
optim_args:
|
||||
@@ -561,29 +612,42 @@ max_grad_norm:
|
||||
# currently only supported on Llama and Mistral
|
||||
neftune_noise_alpha:
|
||||
|
||||
# Whether to bettertransformers
|
||||
# Optional[bool]. Whether to bettertransformers
|
||||
flash_optimum:
|
||||
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
|
||||
# Note: Only one of the following attention patches can be used at a time.
|
||||
# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`.
|
||||
|
||||
# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
flash_attention:
|
||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
||||
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||
# Whether to use scaled-dot-product attention
|
||||
flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only
|
||||
flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation
|
||||
flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation
|
||||
# Optional[bool]. Whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
||||
# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
||||
s2_attention:
|
||||
|
||||
# Optional[bool]. Whether to use low_cpu_mem_usage
|
||||
low_cpu_mem_usage:
|
||||
# Resume from a specific checkpoint dir
|
||||
# Optional[str]. Resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# Optional[bool]. If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# Be careful with this being turned on between different models.
|
||||
auto_resume_from_checkpoints: false
|
||||
|
||||
## Multimodal section
|
||||
# int | tuple[int, int] | None . Size to resize images to, width x height.
|
||||
# Will read from model/processor config if not set.
|
||||
image_size:
|
||||
# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
|
||||
image_resize_algorithm: 'bilinear'
|
||||
## End of multimodal section
|
||||
|
||||
# Don't mess with this, it's here for accelerate and torchrun
|
||||
local_rank:
|
||||
|
||||
@@ -617,6 +681,17 @@ ddp_timeout:
|
||||
ddp_bucket_cap_mb:
|
||||
ddp_broadcast_buffers:
|
||||
|
||||
# Sequence parallelism
|
||||
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
||||
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
||||
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
||||
sequence_parallel_degree:
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
# Must evenly divide the number of KV heads in your model.
|
||||
heads_k_stride: 1
|
||||
|
||||
# Path to torch distx for optim 'adamw_anyprecision'
|
||||
torchdistx_path:
|
||||
|
||||
|
||||
@@ -103,8 +103,7 @@ This uses the same tags as the [`main` image](#sec-main-tags).
|
||||
|
||||
- `JUPYTER_DISABLE`: Disable Jupyter lab.
|
||||
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
|
||||
- `PUBLIC_KEY`: Add a public key for the SSH service.
|
||||
- `SSH_KEY`: Add a private key for the SSH service.
|
||||
- `PUBLIC_KEY` / `SSH_KEY`: Add a public key for the SSH service.
|
||||
|
||||
#### Volume mounts
|
||||
|
||||
|
||||
16
docs/faq.qmd
16
docs/faq.qmd
@@ -35,7 +35,21 @@ description: Frequently asked questions
|
||||
|
||||
**Q: How to call Axolotl via custom python scripts?**
|
||||
|
||||
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
||||
> A: Since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
||||
|
||||
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
|
||||
|
||||
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
|
||||
|
||||
**Q: ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as pad_token**
|
||||
|
||||
> A: This is because the tokenizer does not have a padding token. Please add a padding token to the tokenizer via:
|
||||
|
||||
> ```yaml
|
||||
> special_tokens:
|
||||
> # str. If you're not sure, set to same as `eos_token`.
|
||||
> pad_token: "..."
|
||||
> ```
|
||||
|
||||
### Chat templates
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ We currently support several common model architectures, including (but not limi
|
||||
- `qwen2`
|
||||
- `gemma`
|
||||
- `gemma2`
|
||||
- `gemma3`
|
||||
|
||||
<details>
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ Axolotl supports several methods for multi-GPU training:
|
||||
|
||||
- DeepSpeed (recommended)
|
||||
- FSDP (Fully Sharded Data Parallel)
|
||||
- Sequence parallelism
|
||||
- FSDP + QLoRA
|
||||
|
||||
## DeepSpeed {#sec-deepspeed}
|
||||
@@ -66,6 +67,28 @@ fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
## Sequence parallelism {#sec-sequence-parallelism}
|
||||
|
||||
We support sequence parallelism (SP) via the
|
||||
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
|
||||
allows one to split up sequences across GPUs, which is useful in the event that a
|
||||
single sequence causes OOM errors during model training.
|
||||
|
||||
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
|
||||
or from source with `pip install .[ring-flash-attn]`.
|
||||
|
||||
Your Axolotl YAML config should contain the following lines:
|
||||
|
||||
```{.yaml}
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
flash_attention: true # Required with sequence parallelism
|
||||
|
||||
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
|
||||
heads_k_stride: 1
|
||||
```
|
||||
|
||||
See our [dedicated guide](sequence_parallelism.qmd) for more details.
|
||||
|
||||
### FSDP + QLoRA {#sec-fsdp-qlora}
|
||||
|
||||
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
|
||||
|
||||
@@ -1,28 +1,171 @@
|
||||
# MultiModal / Vision Language Models (BETA)
|
||||
---
|
||||
title: MultiModal / Vision Language Models (BETA)
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
---
|
||||
|
||||
### Supported Models
|
||||
## Supported Models
|
||||
|
||||
- Mllama, i.e. llama with vision models
|
||||
- [Mllama](#sec-mllama)
|
||||
- [Pixtral](#sec-pixtral)
|
||||
- [Llava-1.5](#sec-llava-15)
|
||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||
- [Gemma-3](#sec-gemma-3)
|
||||
- [Qwen2-VL](#sec-qwen2-vl)
|
||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||
|
||||
### Usage
|
||||
## Usage
|
||||
|
||||
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
|
||||
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
|
||||
Multimodal support is limited and doesn't have full feature parity.
|
||||
|
||||
Here are the hyperparams you'll need to use to finetune a multimodal model.
|
||||
|
||||
```yaml
|
||||
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
|
||||
processor_type: AutoProcessor
|
||||
skip_prepare_dataset: true
|
||||
|
||||
chat_template: llama3_2_vision
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
|
||||
sample_packing: false # not yet supported with multimodal
|
||||
|
||||
chat_template: # see in next section
|
||||
|
||||
# example dataset
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
# only finetune the Language model, leave the vision model and vision tower frozen
|
||||
# (optional) if doing lora, only finetune the Language model,
|
||||
# leave the vision model and vision tower frozen
|
||||
# load_in_8bit: true
|
||||
adapter: lora
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
# (optional) if you want to resize images to a set size
|
||||
image_size: 512
|
||||
image_resize_algorithm: bilinear
|
||||
```
|
||||
|
||||
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
|
||||
|
||||
::: {.callout-warning}
|
||||
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
|
||||
:::
|
||||
|
||||
### Mllama {#sec-mllama}
|
||||
|
||||
```yaml
|
||||
base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
|
||||
chat_template: llama3_2_vision
|
||||
```
|
||||
|
||||
### Pixtral {#sec-pixtral}
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Pixtral-12B-2409
|
||||
|
||||
chat_template: pixtral
|
||||
```
|
||||
|
||||
### Llava-1.5 {#sec-llava-15}
|
||||
|
||||
```yaml
|
||||
base_model: llava-hf/llava-1.5-7b-hf
|
||||
|
||||
chat_template: llava
|
||||
```
|
||||
|
||||
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
|
||||
chat_template: mistral_v7_tekken
|
||||
```
|
||||
|
||||
### Gemma-3 {#sec-gemma-3}
|
||||
|
||||
::: {.callout-tip}
|
||||
The Gemma3-1B model is a text-only model, so please train as regular text model.
|
||||
:::
|
||||
|
||||
For multi-modal 4B/12B/27B models, use the following config:
|
||||
|
||||
```yaml
|
||||
base_model: google/gemma-3-4b-it
|
||||
|
||||
chat_template: gemma3
|
||||
```
|
||||
|
||||
### Qwen2-VL {#sec-qwen2-vl}
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2-VL-7B-Instruct
|
||||
|
||||
chat_template: qwen2_vl
|
||||
```
|
||||
|
||||
### Qwen2.5-VL {#sec-qwen25-vl}
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
## Dataset Format
|
||||
|
||||
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
||||
|
||||
- A message is a list of `role` and `content`.
|
||||
- `role` can be `system`, `user`, `assistant`, etc.
|
||||
- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
|
||||
|
||||
::: {.callout-note}
|
||||
For backwards compatibility:
|
||||
|
||||
- If the dataset has a `images` or `image` column of `list[Image]`, it will be appended to the first `content` list as `{"type": "image", "image": ...}`. However, if the content already has a `{"type": "image"}` but no `image` key, it will be set the `image` key.
|
||||
- If `content` is a string, it will be converted to a list with `type` as `text`.
|
||||
:::
|
||||
|
||||
::: {.callout-tip}
|
||||
For image loading, you can use the following keys within `content` alongside `"type": "image"`:
|
||||
|
||||
- `"path": "/path/to/image.jpg"`
|
||||
- `"url": "https://example.com/image.jpg"`
|
||||
- `"base64": "..."`
|
||||
- `"image": PIL.Image`
|
||||
:::
|
||||
|
||||
Here is an example of a multi-modal dataset:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are a helpful assistant."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
|
||||
{"type": "text", "text": "Describe this image in detail."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "The image is a bee."}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
@@ -502,9 +502,48 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
|
||||
:::
|
||||
|
||||
If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
|
||||
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
|
||||
using 4 GPUs - 2 for training, and 2 for vLLM:
|
||||
|
||||
::: {.callout-important}
|
||||
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
||||
|
||||
vllm:
|
||||
host: 0.0.0.0
|
||||
port: 8000
|
||||
tensor_parallel_size: 2
|
||||
gpu_memory_utilization: 0.85
|
||||
dtype: auto
|
||||
# max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand
|
||||
|
||||
rl: grpo
|
||||
trl:
|
||||
use_vllm: true
|
||||
vllm_server_host: 0.0.0.0
|
||||
vllm_server_port: 8000
|
||||
vllm_server_timeout: 300
|
||||
```
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm_serve grpo.yaml
|
||||
```
|
||||
|
||||
Your `vLLM` instance will now attempt to spin up, and it's time to kick off training utilizing our remaining two GPUs. In another terminal, execute:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
|
||||
```
|
||||
|
||||
#### Reward functions
|
||||
|
||||
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
||||
|
||||
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
||||
For example, to load OpenAI's GSM8K and use a random reward for completions:
|
||||
|
||||
```python
|
||||
# rewards.py
|
||||
@@ -530,8 +569,6 @@ trl:
|
||||
beta: 0.001
|
||||
max_completion_length: 256
|
||||
use_vllm: True
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.15
|
||||
num_generations: 4
|
||||
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
||||
reward_weights: [1.0]
|
||||
|
||||
97
docs/sequence_parallelism.qmd
Normal file
97
docs/sequence_parallelism.qmd
Normal file
@@ -0,0 +1,97 @@
|
||||
---
|
||||
title: Sequence Parallelism
|
||||
description: Train with long sequences split across multiple GPUs.
|
||||
---
|
||||
|
||||
# Sequence Parallelism
|
||||
|
||||
Sequence parallelism is a technique that splits sequences across multiple GPUs,
|
||||
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
|
||||
GPU processes a different portion of the sequence, and the results are aggregated
|
||||
through a ring communication pattern.
|
||||
|
||||
## When to Use Sequence Parallelism
|
||||
|
||||
Use sequence parallelism when:
|
||||
|
||||
- You need to train with sequence lengths that don't fit into a single GPU's memory
|
||||
- You have multiple GPUs available
|
||||
- You're experiencing OOM (Out Of Memory) errors with long sequences
|
||||
|
||||
## Configuration
|
||||
|
||||
To enable sequence parallelism, add the following to your configuration file:
|
||||
|
||||
```yaml
|
||||
# Set to a divisor (> 1) of the number of GPUs available
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
```
|
||||
|
||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||
- With 4 GPUs, valid values would be 2 or 4
|
||||
|
||||
## Implementation Details
|
||||
|
||||
When sequence parallelism is enabled:
|
||||
|
||||
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
|
||||
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
|
||||
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
|
||||
4. The trainer uses special ring communication patterns for attention operations
|
||||
|
||||
## Requirements
|
||||
|
||||
To use sequence parallelism, you need:
|
||||
|
||||
- Multiple GPUs (at least 2)
|
||||
- The `ring-flash-attn` package. Install with:
|
||||
- `pip install axolotl[ring-flash-attn]` (preferred)
|
||||
- `pip install ring-flash-attn>=0.1.4`
|
||||
|
||||
## Limitations
|
||||
|
||||
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
|
||||
- May have a small performance overhead due to communication between GPUs
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
base_model: meta-llama/Llama-3-8B-Instruct
|
||||
sequence_len: 8192
|
||||
|
||||
...
|
||||
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
flash_attention: true # Required with sequence parallelism
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
This will train the Llama 3 8B model with 8K context length, with each sequence split
|
||||
into 2 subsequences of length 4096 across 2 GPUs.
|
||||
|
||||
## Sample Packing with Sequence Parallelism
|
||||
|
||||
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
|
||||
|
||||
1. Samples are first packed together
|
||||
2. The packed sequences are then divided across GPUs in the sequence parallel group
|
||||
3. Position IDs are automatically adjusted to maintain proper relative positions
|
||||
|
||||
## Effect on Batch Size
|
||||
|
||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||
|
||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||
- The number of batches processed per step decreases
|
||||
|
||||
For example:
|
||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||
71
examples/cohere/command-r-7b-qlora.yml
Normal file
71
examples/cohere/command-r-7b-qlora.yml
Normal file
@@ -0,0 +1,71 @@
|
||||
base_model: CohereForAI/c4ai-command-r7b-12-2024
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
# huggingface repo
|
||||
chat_template: cohere
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
79
examples/gemma3/gemma-3-1b-qlora.yml
Normal file
79
examples/gemma3/gemma-3-1b-qlora.yml
Normal file
@@ -0,0 +1,79 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
# huggingface repo
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
68
examples/gemma3/gemma-3-4b-qlora.yml
Normal file
68
examples/gemma3/gemma-3-4b-qlora.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: google/gemma-3-4b-it
|
||||
strict: false
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
70
examples/gemma3/gemma-3-4b-vision-qlora.yml
Normal file
70
examples/gemma3/gemma-3-4b-vision-qlora.yml
Normal file
@@ -0,0 +1,70 @@
|
||||
base_model: google/gemma-3-4b-it
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
@@ -19,7 +19,6 @@ val_set_size: 0.0
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
dataset_exact_deduplication: true
|
||||
test_value: true
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
80
examples/llama-3/lora-1b-sample-packing-sequentially.yml
Normal file
80
examples/llama-3/lora-1b-sample-packing-sequentially.yml
Normal file
@@ -0,0 +1,80 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
# optionally might have model_type or tokenizer_type
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
test_value: true
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
sample_packing_sequentially: true
|
||||
curriculum_sampling: true
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_modules_to_save:
|
||||
- embed_tokens
|
||||
- lm_head
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
63
examples/llava/lora-7b.yaml
Normal file
63
examples/llava/lora-7b.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
base_model: llava-hf/llava-1.5-7b-hf
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: llava
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
66
examples/mistral/mistral-small-3.1-24B-lora.yml
Normal file
66
examples/mistral/mistral-small-3.1-24B-lora.yml
Normal file
@@ -0,0 +1,66 @@
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
load_in_8bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: mistral_v7_tekken
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
65
examples/pixtral/lora-12b.yml
Normal file
65
examples/pixtral/lora-12b.yml
Normal file
@@ -0,0 +1,65 @@
|
||||
base_model: mistral-community/pixtral-12b
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: pixtral
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <pad>
|
||||
63
examples/qwen2-vl/lora-7b.yaml
Normal file
63
examples/qwen2-vl/lora-7b.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
base_model: Qwen/Qwen2-VL-7B-Instruct
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: qwen2_vl
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
@@ -1,24 +1,23 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.45.3
|
||||
bitsandbytes==0.45.4
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
flash-attn==2.7.4.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.5.3
|
||||
liger-kernel==0.5.5
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
peft==0.15.0
|
||||
transformers==4.49.0
|
||||
transformers==4.50.3
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.5.2
|
||||
datasets==3.4.1
|
||||
deepspeed==0.16.4
|
||||
trl==0.15.1
|
||||
datasets==3.5.0
|
||||
deepspeed==0.15.4
|
||||
trl==0.16.0
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
@@ -36,6 +35,7 @@ einops
|
||||
colorama
|
||||
numba
|
||||
numpy>=1.24.4,<=2.0.1
|
||||
|
||||
# qlora things
|
||||
evaluate==0.4.1
|
||||
scipy
|
||||
|
||||
@@ -1,315 +0,0 @@
|
||||
accelerate==0.34.1
|
||||
addict==2.4.0
|
||||
aiofiles==23.2.1
|
||||
aiohttp==3.9.0
|
||||
aiosignal==1.3.1
|
||||
aiostream==0.5.2
|
||||
alembic==1.13.1
|
||||
annotated-types==0.6.0
|
||||
annoy==1.17.3
|
||||
ansible==6.7.0
|
||||
ansible-core==2.13.13
|
||||
ansible-vault==2.1.0
|
||||
anyio==3.7.1
|
||||
appdirs==1.4.4
|
||||
art==6.0
|
||||
asgiref==3.7.2
|
||||
async-timeout==4.0.2
|
||||
attrdict==2.0.1
|
||||
attrs==22.2.0
|
||||
awscli==1.32.75
|
||||
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
|
||||
backoff==2.2.1
|
||||
base58==2.1.1
|
||||
beartype==0.17.2
|
||||
bitnet==0.2.1
|
||||
bitsandbytes==0.42.0
|
||||
bittensor==6.7.0
|
||||
black==23.7.0
|
||||
blinker==1.7.0
|
||||
boto3==1.34.75
|
||||
botocore==1.34.75
|
||||
cachetools==5.3.3
|
||||
cachy==0.1.1
|
||||
certifi==2023.7.22
|
||||
cffi==1.16.0
|
||||
cfgv==3.3.1
|
||||
chai-guanaco==1.2.4
|
||||
charset-normalizer==3.2.0
|
||||
cleo==0.6.8
|
||||
click==8.1.7
|
||||
cloudpickle==2.0.0
|
||||
cohere==4.11.2
|
||||
colorama==0.4.4
|
||||
coloredlogs==15.0.1
|
||||
CoLT5-attention==0.10.20
|
||||
contextlib2==21.6.0
|
||||
contourpy==1.2.0
|
||||
cryptography==41.0.3
|
||||
cycler==0.12.1
|
||||
cytoolz==0.12.3
|
||||
databricks-cli==0.18.0
|
||||
dataclasses-json==0.5.7
|
||||
datasets==2.11.0
|
||||
ddt==1.6.0
|
||||
decorator==5.1.1
|
||||
deepspeed==0.15.0
|
||||
# Editable Git install with no remote (dialogpt==0.1)
|
||||
-e /Users/wing/Projects/ml/dialogpt/src
|
||||
dill==0.3.6
|
||||
distlib==0.3.6
|
||||
docker==7.0.0
|
||||
docker-pycreds==0.4.0
|
||||
docstring-parser==0.15
|
||||
docutils==0.16
|
||||
ecdsa==0.18.0
|
||||
einops==0.7.0
|
||||
einops-exts==0.0.4
|
||||
einx==0.1.3
|
||||
entrypoints==0.4
|
||||
eth-hash==0.6.0
|
||||
eth-keys==0.5.0
|
||||
eth-typing==4.0.0
|
||||
eth-utils==2.3.1
|
||||
evaluate==0.4.0
|
||||
exceptiongroup==1.1.1
|
||||
fastapi==0.109.2
|
||||
fastcore==1.5.29
|
||||
ffmpy==0.4.0
|
||||
filelock==3.12.2
|
||||
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
|
||||
fire==0.5.0
|
||||
first==2.0.2
|
||||
flake8==7.0.0
|
||||
Flask==3.0.1
|
||||
fonttools==4.47.2
|
||||
frozendict==2.4.1
|
||||
frozenlist==1.3.3
|
||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||
fsspec==2023.6.0
|
||||
fuzzywuzzy==0.18.0
|
||||
gitdb==4.0.10
|
||||
GitPython==3.1.31
|
||||
google-pasta==0.2.0
|
||||
gradio==4.42.0
|
||||
gradio_client==1.3.0
|
||||
greenlet==2.0.2
|
||||
grpclib==0.4.7
|
||||
gunicorn==21.2.0
|
||||
h11==0.14.0
|
||||
h2==4.1.0
|
||||
hpack==4.0.0
|
||||
httpcore==0.17.3
|
||||
httpx==0.24.1
|
||||
huggingface-hub==0.23.4
|
||||
humanfriendly==10.0
|
||||
hyperframe==6.0.1
|
||||
identify==2.5.24
|
||||
idna==3.4
|
||||
immutables==0.20
|
||||
importlib-metadata==6.7.0
|
||||
importlib-resources==6.1.1
|
||||
inflection==0.5.1
|
||||
iniconfig==2.0.0
|
||||
itsdangerous==2.1.2
|
||||
Jinja2==3.1.2
|
||||
jmespath==1.0.1
|
||||
joblib==1.3.2
|
||||
jsonlines==3.1.0
|
||||
jsonschema==2.6.0
|
||||
kiwisolver==1.4.5
|
||||
langchain==0.0.144
|
||||
Levenshtein==0.24.0
|
||||
libcst==1.1.0
|
||||
liger-kernel==0.0.0
|
||||
lion-pytorch==0.1.2
|
||||
llama-cpp-python==0.1.36
|
||||
llvmlite==0.40.1
|
||||
local-attention==1.9.0
|
||||
loguru==0.7.0
|
||||
Mako==1.3.2
|
||||
Markdown==3.5.2
|
||||
markdown-it-py==3.0.0
|
||||
markdown2==2.4.10
|
||||
MarkupSafe==2.1.2
|
||||
marshmallow==3.19.0
|
||||
marshmallow-enum==1.5.1
|
||||
matplotlib==3.8.2
|
||||
mccabe==0.7.0
|
||||
mdurl==0.1.2
|
||||
MEGABYTE-pytorch==0.0.7
|
||||
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
|
||||
mlflow==2.10.0
|
||||
modal==0.62.77
|
||||
more-itertools==10.2.0
|
||||
mpmath==1.2.1
|
||||
msgpack==1.0.7
|
||||
msgpack-numpy-opentensor==0.5.0
|
||||
multidict==6.0.4
|
||||
multiprocess==0.70.14
|
||||
munch==2.5.0
|
||||
mypy==1.3.0
|
||||
mypy-extensions==1.0.0
|
||||
nest-asyncio==1.6.0
|
||||
netaddr==0.10.1
|
||||
networkx==3.0rc1
|
||||
nh3==0.2.14
|
||||
nodeenv==1.8.0
|
||||
nomic==2.0.2
|
||||
numba==0.57.1
|
||||
numexpr==2.8.4
|
||||
numpy==1.24.4
|
||||
oauthlib==3.2.2
|
||||
openai==0.27.4
|
||||
openapi==1.1.0
|
||||
openapi-schema-pydantic==1.2.4
|
||||
optimum==1.8.6
|
||||
orjson==3.10.7
|
||||
packaging==23.1
|
||||
pandas==2.0.0
|
||||
parameterized==0.9.0
|
||||
password-strength==0.0.3.post2
|
||||
pastel==0.1.1
|
||||
pathos==0.3.0
|
||||
pathspec==0.11.1
|
||||
pathtools==0.1.2
|
||||
peft==0.11.1
|
||||
pendulum==3.0.0
|
||||
Pillow==9.5.0
|
||||
pip-tools==1.11.0
|
||||
platformdirs==3.2.0
|
||||
pluggy==1.4.0
|
||||
poetry==0.7.1
|
||||
pox==0.3.2
|
||||
ppft==1.7.6.6
|
||||
pre-commit==3.3.2
|
||||
prettytable==3.10.0
|
||||
prompt-toolkit==3.0.39
|
||||
protobuf==3.20.2
|
||||
protobuf3-to-dict==0.1.5
|
||||
psutil==5.9.5
|
||||
psycopg==3.1.18
|
||||
PuLP==2.8.0
|
||||
py==1.11.0
|
||||
py-bip39-bindings==0.1.11
|
||||
py-cpuinfo==9.0.0
|
||||
py-ed25519-zebra-bindings==1.0.1
|
||||
py-sr25519-bindings==0.2.0
|
||||
pyarrow==11.0.0
|
||||
pyasn1==0.6.0
|
||||
pycodestyle==2.11.1
|
||||
pycparser==2.21
|
||||
pycryptodome==3.20.0
|
||||
pydantic==2.5.3
|
||||
pydantic_core==2.14.6
|
||||
pydub==0.25.1
|
||||
pyfiglet==0.8.post1
|
||||
pyflakes==3.2.0
|
||||
Pygments==2.15.1
|
||||
PyJWT==2.8.0
|
||||
pylev==1.4.0
|
||||
PyNaCl==1.5.0
|
||||
pynvml==11.5.0
|
||||
pyparsing==2.4.7
|
||||
pyrsistent==0.14.11
|
||||
pytest==8.0.2
|
||||
pytest-asyncio==0.23.4
|
||||
python-dateutil==2.8.2
|
||||
python-dotenv==1.0.1
|
||||
python-Levenshtein==0.24.0
|
||||
python-multipart==0.0.9
|
||||
pytz==2023.3
|
||||
PyYAML==6.0.1
|
||||
querystring-parser==1.2.4
|
||||
rapidfuzz==3.6.1
|
||||
regex==2023.6.3
|
||||
requests==2.31.0
|
||||
requests-toolbelt==0.8.0
|
||||
resolvelib==0.8.1
|
||||
responses==0.18.0
|
||||
retry==0.9.2
|
||||
rich==13.7.0
|
||||
rsa==4.7.2
|
||||
ruff==0.6.3
|
||||
s3transfer==0.10.1
|
||||
safetensors==0.4.5
|
||||
sagemaker==2.148.0
|
||||
scalecodec==1.2.7
|
||||
schedulefree==1.2.1
|
||||
schema==0.7.5
|
||||
scikit-learn==1.4.0
|
||||
scipy==1.9.3
|
||||
seaborn==0.13.2
|
||||
semantic-version==2.10.0
|
||||
sentencepiece==0.2.0
|
||||
sentry-sdk==1.19.1
|
||||
setproctitle==1.3.2
|
||||
shellingham==1.5.4
|
||||
shortuuid==1.0.11
|
||||
shtab==1.6.5
|
||||
sigtools==4.0.1
|
||||
six==1.16.0
|
||||
skypilot==0.4.1
|
||||
smdebug-rulesconfig==1.0.1
|
||||
smmap==5.0.0
|
||||
sniffio==1.3.0
|
||||
SQLAlchemy==1.4.47
|
||||
sqlparse==0.4.4
|
||||
starlette==0.36.3
|
||||
substrate-interface==1.5.2
|
||||
svgwrite==1.4.3
|
||||
sympy==1.11.1
|
||||
synchronicity==0.6.7
|
||||
tabulate==0.9.0
|
||||
tblib==1.7.0
|
||||
tenacity==8.2.2
|
||||
tensor-parallel==2.0.0
|
||||
termcolor==2.2.0
|
||||
text2art==0.2.0
|
||||
threadpoolctl==3.2.0
|
||||
tiktoken==0.6.0
|
||||
time-machine==2.14.1
|
||||
timm==0.9.16
|
||||
tokenizers==0.19.1
|
||||
tokenmonster==1.1.12
|
||||
toml==0.9.6
|
||||
tomli==2.0.1
|
||||
tomlkit==0.12.0
|
||||
toolz==0.12.1
|
||||
torch==2.2.0
|
||||
torchdata==0.6.1
|
||||
torchdiffeq==0.2.3
|
||||
TorchFix==0.4.0
|
||||
torchtext==0.15.2
|
||||
torchvision==0.17.0
|
||||
tqdm==4.66.2
|
||||
transformers==4.44.2
|
||||
trl==0.9.6
|
||||
typer==0.12.5
|
||||
types-certifi==2021.10.8.3
|
||||
types-requests==2.31.0.20240125
|
||||
types-setuptools==69.0.0.20240125
|
||||
types-toml==0.10.8.7
|
||||
typing==3.7.4.3
|
||||
typing-inspect==0.8.0
|
||||
typing_extensions==4.9.0
|
||||
tyro==0.5.18
|
||||
tzdata==2023.3
|
||||
unique-names-generator==1.0.2
|
||||
urllib3==2.2.2
|
||||
uvicorn==0.22.0
|
||||
vector_quantize_pytorch==1.14.1
|
||||
virtualenv==20.23.0
|
||||
voyager==2.0.2
|
||||
wandb==0.16.2
|
||||
watchfiles==0.21.0
|
||||
wavedrom==2.0.3.post3
|
||||
wcwidth==0.2.6
|
||||
websocket-client==1.7.0
|
||||
websockets==12.0
|
||||
Werkzeug==3.0.1
|
||||
wonderwords==2.2.0
|
||||
xxhash==3.2.0
|
||||
yarl==1.8.2
|
||||
zetascale==2.2.7
|
||||
zipp==3.15.0
|
||||
97
setup.py
97
setup.py
@@ -10,19 +10,13 @@ from pathlib import Path
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def parse_requirements():
|
||||
def parse_requirements(extras_require_map):
|
||||
_install_requires = []
|
||||
_dependency_links = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
lines = [r.strip() for r in requirements_file.readlines()]
|
||||
for line in lines:
|
||||
is_extras = (
|
||||
"flash-attn" in line
|
||||
or "flash-attention" in line
|
||||
or "deepspeed" in line
|
||||
or "mamba-ssm" in line
|
||||
or "lion-pytorch" in line
|
||||
)
|
||||
is_extras = "deepspeed" in line or "mamba-ssm" in line
|
||||
if line.startswith("--extra-index-url"):
|
||||
# Handle custom index URLs
|
||||
_, url = line.split()
|
||||
@@ -39,7 +33,6 @@ def parse_requirements():
|
||||
"bitsandbytes",
|
||||
"triton",
|
||||
"mamba-ssm",
|
||||
"flash-attn",
|
||||
"xformers",
|
||||
"autoawq",
|
||||
"liger-kernel",
|
||||
@@ -74,6 +67,7 @@ def parse_requirements():
|
||||
if (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.29.post2")
|
||||
extras_require_map["vllm"] = ["vllm==0.8.1"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
@@ -93,7 +87,7 @@ def parse_requirements():
|
||||
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
return _install_requires, _dependency_links
|
||||
return _install_requires, _dependency_links, extras_require_map
|
||||
|
||||
|
||||
def get_package_version():
|
||||
@@ -110,7 +104,50 @@ def get_package_version():
|
||||
return version_
|
||||
|
||||
|
||||
install_requires, dependency_links = parse_requirements()
|
||||
extras_require = {
|
||||
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||
"ring-flash-attn": [
|
||||
"flash-attn==2.7.4.post1",
|
||||
"ring-flash-attn>=0.1.4",
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.15.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
],
|
||||
"auto-gptq": [
|
||||
"auto-gptq==0.5.1",
|
||||
],
|
||||
"mlflow": [
|
||||
"mlflow",
|
||||
],
|
||||
"galore": [
|
||||
"galore_torch",
|
||||
],
|
||||
"apollo": [
|
||||
"apollo-torch",
|
||||
],
|
||||
"optimizers": [
|
||||
"galore_torch",
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.7.2",
|
||||
],
|
||||
}
|
||||
|
||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||
extras_require
|
||||
)
|
||||
|
||||
setup(
|
||||
version=get_package_version(),
|
||||
@@ -123,41 +160,5 @@ setup(
|
||||
"axolotl=axolotl.cli.main:main",
|
||||
],
|
||||
},
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn==2.7.4.post1",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.16.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
],
|
||||
"auto-gptq": [
|
||||
"auto-gptq==0.5.1",
|
||||
],
|
||||
"mlflow": [
|
||||
"mlflow",
|
||||
],
|
||||
"lion-pytorch": [
|
||||
"lion-pytorch==0.1.2",
|
||||
],
|
||||
"galore": [
|
||||
"galore_torch",
|
||||
],
|
||||
"optimizers": [
|
||||
"galore_torch",
|
||||
"lion-pytorch==0.1.2",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.7.2",
|
||||
],
|
||||
},
|
||||
extras_require=extras_require_build,
|
||||
)
|
||||
|
||||
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.8.0.dev0"
|
||||
__version__ = "0.8.0"
|
||||
|
||||
@@ -35,6 +35,55 @@ class TrainerCliArgs:
|
||||
num_processes: Optional[int] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VllmServeCliArgs:
|
||||
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
|
||||
|
||||
tensor_parallel_size: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of tensor parallel workers to use."},
|
||||
)
|
||||
host: str = field(
|
||||
default="0.0.0.0", # nosec B104
|
||||
metadata={"help": "Host address to run the server on."},
|
||||
)
|
||||
port: int = field(
|
||||
default=8000,
|
||||
metadata={"help": "Port to run the server on."},
|
||||
)
|
||||
gpu_memory_utilization: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
|
||||
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
|
||||
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
|
||||
"out-of-memory (OOM) errors during initialization."
|
||||
},
|
||||
)
|
||||
dtype: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
|
||||
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
||||
},
|
||||
)
|
||||
max_model_len: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
|
||||
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
|
||||
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
||||
},
|
||||
)
|
||||
enable_prefix_caching: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
|
||||
"hardware support this feature."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluateCliArgs:
|
||||
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
|
||||
|
||||
@@ -56,7 +56,7 @@ def do_inference(
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Inference-specific CLI arguments.
|
||||
"""
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||
prompter = cli_args.prompter
|
||||
|
||||
prompter_module = None
|
||||
@@ -151,7 +151,7 @@ def do_inference_gradio(
|
||||
"""
|
||||
import gradio as gr
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||
prompter = cli_args.prompter
|
||||
|
||||
prompter_module = None
|
||||
@@ -256,7 +256,7 @@ def do_cli(
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, inference=True, **kwargs)
|
||||
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
|
||||
@@ -14,7 +14,12 @@ import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import axolotl
|
||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.cli.args import (
|
||||
EvaluateCliArgs,
|
||||
PreprocessCliArgs,
|
||||
TrainerCliArgs,
|
||||
VllmServeCliArgs,
|
||||
)
|
||||
from axolotl.cli.sweeps import generate_sweep_configs
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
@@ -23,6 +28,7 @@ from axolotl.cli.utils import (
|
||||
fetch_from_github,
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from axolotl.cli.vllm_serve import do_vllm_serve
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
@@ -316,6 +322,14 @@ def fetch(directory: str, dest: Optional[str]) -> None:
|
||||
fetch_from_github(f"{directory}/", dest)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@add_options_from_dataclass(VllmServeCliArgs)
|
||||
@filter_none_kwargs
|
||||
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
|
||||
do_vllm_serve(config, cli_args)
|
||||
|
||||
|
||||
cli.add_command(lm_eval)
|
||||
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
||||
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
LOG.info("Running merge of LoRA with base model...")
|
||||
@@ -44,6 +44,9 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
|
||||
if processor:
|
||||
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
"""
|
||||
@@ -71,8 +74,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
flash_attention=False,
|
||||
sequence_parallel_degree=None,
|
||||
deepspeed=None,
|
||||
fsdp=None,
|
||||
fsdp_config=None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -83,13 +88,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
|
||||
)
|
||||
|
||||
parsed_cfg.load_in_4bit = False
|
||||
parsed_cfg.load_in_8bit = False
|
||||
parsed_cfg.flash_attention = False
|
||||
parsed_cfg.deepspeed = None
|
||||
parsed_cfg.fsdp = None
|
||||
parsed_cfg.fsdp_config = None
|
||||
|
||||
do_merge_lora(cfg=parsed_cfg)
|
||||
|
||||
|
||||
|
||||
@@ -17,13 +17,14 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
"""
|
||||
Trains a `transformers` model by first loading the dataset(s) specified in the
|
||||
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
||||
@@ -33,6 +34,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Training-specific CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
@@ -44,16 +48,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
del model, tokenizer, trainer
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
|
||||
del model
|
||||
del tokenizer
|
||||
del trainer
|
||||
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
"""
|
||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||
|
||||
|
||||
@@ -13,11 +13,16 @@ from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
import click
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||
|
||||
configure_logging()
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -295,9 +300,13 @@ def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
inference: bool = False,
|
||||
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
|
||||
) -> tuple[
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model and tokenizer specified in the given `axolotl`
|
||||
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
||||
config.
|
||||
|
||||
Args:
|
||||
@@ -305,7 +314,7 @@ def load_model_and_tokenizer(
|
||||
inference: Boolean denoting inference mode.
|
||||
|
||||
Returns:
|
||||
`transformers` model and tokenizer.
|
||||
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
||||
"""
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
@@ -313,4 +322,9 @@ def load_model_and_tokenizer(
|
||||
LOG.info("loading model...")
|
||||
model, _ = load_model(cfg, tokenizer, inference=inference)
|
||||
|
||||
return model, tokenizer
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
LOG.info("loading processor...")
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
return model, tokenizer, processor
|
||||
|
||||
55
src/axolotl/cli/vllm_serve.py
Normal file
55
src/axolotl/cli/vllm_serve.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
CLI to start the vllm server for online RL
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from trl.scripts.vllm_serve import ScriptArguments
|
||||
from trl.scripts.vllm_serve import main as vllm_serve_main
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
|
||||
|
||||
def do_vllm_serve(
|
||||
config: Union[Path, str],
|
||||
cli_args: dict,
|
||||
):
|
||||
"""
|
||||
Starts the VLLM server for serving LLM models used for online RL
|
||||
|
||||
Args
|
||||
:param cfg: Parsed doct of the YAML config
|
||||
:param cli_args: dict of additional command-line arguments of type VllmServeCliArgs
|
||||
|
||||
Returns:
|
||||
process_id: the process id of the started VLLM server
|
||||
"""
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
host = cli_args.get("host") or cfg.vllm.host
|
||||
port = cli_args.get("port") or cfg.vllm.port
|
||||
gpu_memory_utilization = (
|
||||
cli_args.get("gpu_memory_utilization") or cfg.vllm.gpu_memory_utilization
|
||||
)
|
||||
dtype = cli_args.get("dtype") or cfg.vllm.dtype
|
||||
max_model_len = cli_args.get("max_model_len") or cfg.vllm.max_model_len
|
||||
enable_prefix_caching = (
|
||||
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
|
||||
)
|
||||
|
||||
vllm_script_args = ScriptArguments(
|
||||
model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
host=host,
|
||||
port=port,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
)
|
||||
vllm_serve_main(vllm_script_args)
|
||||
@@ -36,7 +36,7 @@ from transformers import (
|
||||
from transformers.training_args import OptimizerNames
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
from axolotl.core.trainers.base import (
|
||||
from axolotl.core.trainers import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlMambaTrainer,
|
||||
@@ -60,6 +60,7 @@ from axolotl.core.training_args import (
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||
from axolotl.processing_strategies import get_processing_strategy
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
@@ -68,7 +69,6 @@ from axolotl.utils.callbacks import (
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
SaveModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
@@ -248,7 +248,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
callbacks.append(SaveModelCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
@@ -525,9 +524,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
and self.cfg.eval_steps
|
||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||
) or False
|
||||
|
||||
# handle ddp
|
||||
ddp_find_unused_parameters = None
|
||||
if self.cfg.ddp:
|
||||
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
|
||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||
False if self.cfg.ddp else None
|
||||
ddp_find_unused_parameters
|
||||
)
|
||||
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
report_to = []
|
||||
@@ -747,6 +752,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.accelerator_config
|
||||
)
|
||||
|
||||
if self.cfg.image_size:
|
||||
training_arguments_kwargs["image_size"] = self.cfg.image_size
|
||||
if self.cfg.image_resize_algorithm:
|
||||
training_arguments_kwargs["image_resize_algorithm"] = (
|
||||
self.cfg.image_resize_algorithm
|
||||
)
|
||||
if self.cfg.kd_ce_alpha is not None:
|
||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||
if self.cfg.kd_alpha is not None:
|
||||
@@ -762,6 +773,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kd_top_k_before_softmax
|
||||
)
|
||||
|
||||
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||
self.cfg.sequence_parallel_degree
|
||||
)
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
elif self.cfg.process_reward_model:
|
||||
@@ -845,9 +860,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||
):
|
||||
if training_args.pretraining:
|
||||
if self.cfg.pretraining_sample_concatenation is False:
|
||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||
if self.cfg.micro_batch_size > 1:
|
||||
if (
|
||||
self.cfg.pretraining_sample_concatenation is False
|
||||
or self.cfg.micro_batch_size > 1
|
||||
):
|
||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||
return None
|
||||
|
||||
@@ -875,9 +891,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if "max_length" in kwargs:
|
||||
kwargs.pop("max_length")
|
||||
elif use_batch_sampler_collator:
|
||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif (
|
||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
|
||||
self.cfg.model_config_type in ["llama"]
|
||||
and self.cfg.flash_attention is not True
|
||||
):
|
||||
@@ -887,8 +901,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
if self.cfg.processor_type and self.processor:
|
||||
collator = MultiModalChatDataCollator
|
||||
kwargs["processor"] = self.processor
|
||||
kwargs["chat_template"] = training_args.chat_template
|
||||
kwargs["processing_strategy"] = get_processing_strategy(
|
||||
self.processor,
|
||||
training_args.chat_template,
|
||||
self.cfg.chat_template,
|
||||
image_size=training_args.image_size,
|
||||
image_resize_algorithm=training_args.image_resize_algorithm,
|
||||
)
|
||||
elif self.cfg.batch_flattening:
|
||||
collator = DataCollatorWithFlattening
|
||||
collator_args.pop(0)
|
||||
@@ -908,6 +927,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
collator = DataCollatorForSeq2Seq
|
||||
|
||||
kwargs["return_tensors"] = "pt"
|
||||
if issubclass(collator, DataCollatorForSeq2Seq):
|
||||
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
||||
|
||||
return collator(
|
||||
*collator_args,
|
||||
@@ -920,7 +941,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
callbacks.append(SaveModelCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Init for axolotl.core.trainers"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .relora import ReLoRATrainer
|
||||
from .trl import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
TRLPPOTrainer,
|
||||
)
|
||||
|
||||
@@ -1,365 +1,49 @@
|
||||
"""
|
||||
module for customized trainers
|
||||
"""
|
||||
"""Module for customized trainers"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
from typing import Dict, Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
DataLoader,
|
||||
RandomSampler,
|
||||
Sampler,
|
||||
SequentialSampler,
|
||||
)
|
||||
from transformers import Trainer
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.integrations.base import BaseOptimizerFactory
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.schedulers import (
|
||||
RexLR,
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
get_cosine_schedule_with_warmup_decay_constant,
|
||||
from axolotl.core.trainers.mixins import (
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
SequenceParallelMixin,
|
||||
)
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
)
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||
if isinstance(tag_names, str):
|
||||
tag_names = [tag_names]
|
||||
|
||||
if kwargs is not None:
|
||||
if "tags" not in kwargs:
|
||||
kwargs["tags"] = tag_names
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||
kwargs["tags"].extend(tag_names)
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||
tag_names.append(kwargs["tags"])
|
||||
kwargs["tags"] = tag_names
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||
if isinstance(dataset_tags, str):
|
||||
dataset_tags = [dataset_tags]
|
||||
|
||||
if (dataset_tags is not None) and (kwargs is not None):
|
||||
if "dataset_tags" not in kwargs:
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||
kwargs["dataset_tags"].extend(dataset_tags)
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||
dataset_tags.append(kwargs["dataset_tags"])
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
class SchedulerMixin(Trainer):
|
||||
"""
|
||||
Mixin class for scheduler setup in CausalTrainer.
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
use_cosine_quadratic = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
)
|
||||
|
||||
use_cosine_min_lr = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.cosine_min_lr_ratio is not None
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
extra_lr_kwargs = {}
|
||||
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["pct_start"] = pct_start
|
||||
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
**extra_lr_kwargs,
|
||||
**self.args.lr_scheduler_kwargs,
|
||||
)
|
||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||
if use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
|
||||
self.lr_scheduler = RexLR(
|
||||
optimizer=optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||
total_steps=num_training_steps,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
)
|
||||
elif use_cosine_quadratic:
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class OptimizerMixin(Trainer):
|
||||
"""
|
||||
Mixin class for shared handling of building custom optimizers
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_optimizer_grouped_parameters(
|
||||
self, opt_model, optimizer_kwargs
|
||||
) -> list[dict]:
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
params: dict = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
}
|
||||
lr_groups_lookup = {}
|
||||
lr_groups_learning_rates = {}
|
||||
if self.args.lr_groups:
|
||||
for lr_group in self.args.lr_groups:
|
||||
group_name = lr_group["name"]
|
||||
group_modules = lr_group["modules"]
|
||||
for module in group_modules:
|
||||
lr_groups_lookup[module] = group_name
|
||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||
params[f"to_weight_decay_{group_name}"] = {}
|
||||
|
||||
for name, param in opt_model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
if name.endswith("modules_to_save.default.weight") or any(
|
||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||
):
|
||||
params["embeddings"][name] = param
|
||||
elif name in decay_parameters:
|
||||
lr_group_modules = [
|
||||
group_modules
|
||||
for group_modules in lr_groups_lookup
|
||||
if group_modules in name
|
||||
]
|
||||
if lr_groups_lookup and any(lr_group_modules):
|
||||
lr_group_module = lr_group_modules[0]
|
||||
group_name = lr_groups_lookup[lr_group_module]
|
||||
params[f"to_weight_decay_{group_name}"][name] = param
|
||||
else:
|
||||
params["to_weight_decay"][name] = param
|
||||
else:
|
||||
params["no_weight_decay"][name] = param
|
||||
optimizer_grouped_parameters = []
|
||||
if params["to_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["to_weight_decay"].values()),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
if params["embeddings"]:
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
if self.args.embedding_lr_scale:
|
||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||
elif self.args.embedding_lr:
|
||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["embeddings"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": lr,
|
||||
}
|
||||
)
|
||||
if params["no_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["no_weight_decay"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||
if params[f"to_weight_decay_{group_name}"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(
|
||||
params[f"to_weight_decay_{group_name}"].values()
|
||||
),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": group_lr,
|
||||
}
|
||||
)
|
||||
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def create_optimizer(self):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.embedding_lr_scale is None
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.lr_groups is None
|
||||
and self.optimizer_cls_and_kwargs is None
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
|
||||
if (
|
||||
not self.optimizer
|
||||
and self.optimizer_cls_and_kwargs is not None
|
||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||
):
|
||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
self.optimizer = optimizer_factory_cls()(
|
||||
opt_model, self.args, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if not self.optimizer:
|
||||
if self.optimizer_cls_and_kwargs is not None:
|
||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
else:
|
||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||
self.args, opt_model
|
||||
)
|
||||
|
||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||
opt_model, optimizer_kwargs
|
||||
)
|
||||
|
||||
if self.args.loraplus_lr_ratio is not None:
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
loraplus_lr_embedding = getattr(
|
||||
self.args, "loraplus_lr_embedding", 1e-6
|
||||
)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
else:
|
||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for GaLore optimizer.
|
||||
if "params" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||
|
||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for LOMO optimizer.
|
||||
if "model" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||
|
||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||
# to avoid arguments conflicts.
|
||||
if "optimizer_dict" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||
"optimizer_dict"
|
||||
)
|
||||
|
||||
self.optimizer = optimizer_cls(
|
||||
optimizer_grouped_parameters, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if optimizer_cls.__name__ == "Adam8bit":
|
||||
import bitsandbytes
|
||||
|
||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||
|
||||
skipped = 0
|
||||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum(
|
||||
{
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlTrainer(
|
||||
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
|
||||
):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
tag_names = ["axolotl"]
|
||||
@@ -376,12 +60,18 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
self.eval_data_collator = eval_data_collator
|
||||
self.dataset_tags = dataset_tags
|
||||
self._signature_columns = None # workaround for pylint
|
||||
|
||||
super().__init__(*_args, **kwargs)
|
||||
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
# Initialize sequence parallelism if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._setup_sequence_parallel()
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
@@ -394,142 +84,248 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_train_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
train_batch_size = (
|
||||
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||
)
|
||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||
def _create_multipack_sampler(
|
||||
self, base_sampler: Sampler, dataset: Dataset
|
||||
) -> MultipackBatchSampler:
|
||||
"""
|
||||
Helper method to create a `MultipackBatchSampler` for multipacking sequences
|
||||
for training.
|
||||
|
||||
if self.args.curriculum_sampling:
|
||||
sampler = SequentialSampler(self.train_dataset)
|
||||
else:
|
||||
sampler = RandomSampler(self.train_dataset)
|
||||
Args:
|
||||
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
|
||||
dataset: Dataset to sample from.
|
||||
|
||||
return MultipackBatchSampler(
|
||||
sampler,
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_size=batch_size,
|
||||
group_size=self.args.sample_packing_group_size,
|
||||
bin_size=self.args.sample_packing_bin_size,
|
||||
drop_last=True,
|
||||
Returns:
|
||||
Multipack (sample packing) batch sampler.
|
||||
"""
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_train_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
train_batch_size = (
|
||||
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||
)
|
||||
if self.args.curriculum_sampling:
|
||||
return SequentialSampler(self.train_dataset)
|
||||
return super()._get_train_sampler()
|
||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_eval_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
||||
)
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
lengths=get_dataset_lengths(self.eval_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_size=batch_size,
|
||||
group_size=self.args.sample_packing_group_size,
|
||||
bin_size=self.args.sample_packing_bin_size,
|
||||
drop_last=True,
|
||||
return MultipackBatchSampler(
|
||||
base_sampler,
|
||||
lengths=get_dataset_lengths(dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_size=batch_size,
|
||||
sequential=self.args.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
def _get_train_sampler(self) -> Sampler | None:
|
||||
"""
|
||||
Helper method to get the sampler for training. Handles cases for sequence
|
||||
parallelism, sample packing, and curriculum sampling (sequential).
|
||||
|
||||
Returns:
|
||||
If the dataset is non-empty, a sampler is returned, the type of which
|
||||
depends on the passed training args.
|
||||
"""
|
||||
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||
|
||||
# Determine the base sampler first
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
||||
elif self.args.curriculum_sampling:
|
||||
base_sampler = SequentialSampler(self.train_dataset)
|
||||
elif use_sample_packing:
|
||||
base_sampler = RandomSampler(self.train_dataset)
|
||||
else:
|
||||
# Default to parent class implementation for standard random sampling
|
||||
return super()._get_train_sampler()
|
||||
|
||||
# Apply multipack wrapper if needed
|
||||
if use_sample_packing:
|
||||
return self._create_multipack_sampler(
|
||||
base_sampler=base_sampler,
|
||||
dataset=self.train_dataset,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
train_dataset = self.train_dataset
|
||||
if "length" in train_dataset.features.keys():
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = (
|
||||
self.args.dataloader_prefetch_factor
|
||||
)
|
||||
return base_sampler
|
||||
|
||||
sampler = self._get_train_sampler()
|
||||
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
||||
"""
|
||||
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
||||
and sample packing cases.
|
||||
|
||||
Returns:
|
||||
If the dataset is non-empty, a sampler is returned, the type of which
|
||||
depends on the passed training args.
|
||||
"""
|
||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
|
||||
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
||||
use_multipack = (
|
||||
self.args.sample_packing and self.args.eval_sample_packing is not False
|
||||
)
|
||||
|
||||
# Determine the base sampler
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
||||
elif use_multipack:
|
||||
base_sampler = SequentialSampler(eval_dataset)
|
||||
else:
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
# Apply multipack wrapper if needed
|
||||
if use_multipack:
|
||||
return self._create_multipack_sampler(
|
||||
base_sampler=base_sampler,
|
||||
dataset=eval_dataset,
|
||||
)
|
||||
|
||||
return base_sampler
|
||||
|
||||
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
|
||||
"""Create common dataloader parameters for train or eval."""
|
||||
batch_size = custom_batch_size or (
|
||||
self.args.eval_batch_size if is_eval else self._train_batch_size
|
||||
)
|
||||
|
||||
params = {
|
||||
"batch_size": batch_size,
|
||||
"collate_fn": self.data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
|
||||
# Add persistent workers only for training
|
||||
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
|
||||
params["persistent_workers"] = self.args.dataloader_persistent_workers
|
||||
|
||||
# Add prefetch factor if specified
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
return params
|
||||
|
||||
def _prepare_dataloader(
|
||||
self, dataset, sampler, is_eval=False, custom_batch_size=None
|
||||
):
|
||||
"""Prepare a dataloader with the given dataset and sampler."""
|
||||
# Get base parameters
|
||||
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
|
||||
|
||||
# Add sampler configuration
|
||||
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||
if isinstance(sampler, BatchSampler):
|
||||
# batch_size and batch_sampler are mutually exclusive
|
||||
dataloader_params["batch_sampler"] = sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
if not is_eval:
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
# Create the dataloader
|
||||
dataloader = DataLoader(dataset, **dataloader_params)
|
||||
|
||||
if self.args.sample_packing and (
|
||||
(not is_eval and not self.args.pretraining)
|
||||
or (is_eval and self.args.eval_sample_packing is not False)
|
||||
):
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
# Return unprepared dataloader if using sequence parallelism
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
return self.accelerator.prepare_data_loader(dataloader)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
"""Get dataloader for training"""
|
||||
train_dataset = self.train_dataset
|
||||
data_collator = self.data_collator # type: ignore
|
||||
|
||||
# Handle dataset preprocessing
|
||||
if isinstance(train_dataset, datasets.Dataset):
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
if not self.args.sample_packing or self.args.pretraining:
|
||||
train_dataset = self._remove_unused_columns(
|
||||
train_dataset, description="training"
|
||||
)
|
||||
else:
|
||||
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||
data_collator,
|
||||
description="training",
|
||||
)
|
||||
|
||||
# Get sampler and create dataloader
|
||||
sampler = self._get_train_sampler()
|
||||
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
||||
"""Get dataloader for evaluation"""
|
||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
|
||||
# Handle special case: sample packing is enabled but eval_sample_packing is False
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.eval_data_collator
|
||||
)
|
||||
if eval_dataset:
|
||||
if "length" in eval_dataset.column_names:
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.train_data_collator
|
||||
)
|
||||
|
||||
return dataloader
|
||||
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
# Handle sample packing or sequence parallelism
|
||||
if (
|
||||
self.args.sample_packing
|
||||
and self.args.eval_sample_packing is not False
|
||||
or self.args.sequence_parallel_degree > 1
|
||||
):
|
||||
# Get appropriate data collator
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.eval_data_collator
|
||||
if hasattr(self, "eval_data_collator") and self.eval_data_collator
|
||||
else self.data_collator
|
||||
)
|
||||
if "length" in eval_dataset.column_names:
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
|
||||
# Handle dataset preprocessing for SP
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if isinstance(eval_dataset, datasets.Dataset):
|
||||
eval_dataset = self._remove_unused_columns(
|
||||
eval_dataset, description="evaluation"
|
||||
)
|
||||
else:
|
||||
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||
self.data_collator, description="evaluation"
|
||||
)
|
||||
|
||||
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
|
||||
batch_size = (
|
||||
self.args.eval_batch_size
|
||||
if self.args.sample_packing
|
||||
else self.args.per_device_eval_batch_size
|
||||
)
|
||||
sampler = self._get_eval_sampler(eval_dataset)
|
||||
dataloader = self._prepare_dataloader(
|
||||
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = (
|
||||
self.args.dataloader_prefetch_factor
|
||||
)
|
||||
|
||||
if isinstance(eval_sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = eval_sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = eval_sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(eval_dataset, **dataloader_params)
|
||||
)
|
||||
return dataloader
|
||||
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def _get_bench_sampler(
|
||||
self, bench_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
) -> torch.utils.data.Sampler | None:
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(bench_dataset)
|
||||
return None
|
||||
@@ -554,6 +350,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
@@ -570,6 +367,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
return_outputs=return_outputs,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
return super().compute_loss(
|
||||
model,
|
||||
inputs,
|
||||
@@ -744,10 +542,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@@ -764,15 +562,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
|
||||
return res
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
||||
"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
|
||||
Args:
|
||||
logs (`Dict[str, float]`):
|
||||
The values to log.
|
||||
start_time (`Optional[float]`):
|
||||
The start of training.
|
||||
logs: The values to log.
|
||||
start_time: The start of training.
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
@@ -784,7 +580,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
return super().log(logs, start_time)
|
||||
|
||||
def store_metrics(
|
||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
) -> None:
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
@@ -796,111 +592,3 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Mamba specific trainer to handle loss calculation
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "mamba"]
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
|
||||
labels = input_ids.to(lm_logits.device)
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
lm_loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||
)
|
||||
|
||||
return lm_loss
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "relora"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
anneal_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "prm"]
|
||||
|
||||
@@ -13,17 +13,17 @@ from transformers import Trainer
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.trainers.base import (
|
||||
SchedulerMixin,
|
||||
_sanitize_kwargs_for_ds_tagging,
|
||||
_sanitize_kwargs_for_tagging,
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
@@ -74,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -40,18 +40,15 @@ class GRPOStrategy:
|
||||
|
||||
if trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||
grpo_args_kwargs["vllm_device"] = (
|
||||
trl.vllm_device if trl.vllm_device else "auto"
|
||||
)
|
||||
|
||||
if trl.vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
trl.vllm_gpu_memory_utilization
|
||||
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
|
||||
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
|
||||
if trl.vllm_server_timeout:
|
||||
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
|
||||
if trl.vllm_guided_decoding_regex:
|
||||
grpo_args_kwargs["vllm_guided_decoding_regex"] = (
|
||||
trl.vllm_guided_decoding_regex
|
||||
)
|
||||
|
||||
if trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
|
||||
|
||||
if trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = trl.num_generations
|
||||
|
||||
@@ -70,6 +67,25 @@ class GRPOStrategy:
|
||||
if trl.reward_weights:
|
||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||
|
||||
if trl.scale_rewards is not None:
|
||||
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
|
||||
|
||||
if trl.temperature is not None:
|
||||
grpo_args_kwargs["temperature"] = trl.temperature
|
||||
if trl.top_p is not None:
|
||||
grpo_args_kwargs["top_p"] = trl.top_p
|
||||
if trl.top_k is not None:
|
||||
grpo_args_kwargs["top_k"] = trl.top_k
|
||||
if trl.min_p is not None:
|
||||
grpo_args_kwargs["min_p"] = trl.min_p
|
||||
if trl.repetition_penalty is not None:
|
||||
grpo_args_kwargs["repetition_penalty"] = trl.repetition_penalty
|
||||
|
||||
if trl.num_iterations is not None:
|
||||
grpo_args_kwargs["num_iterations"] = trl.num_iterations
|
||||
if trl.epsilon is not None:
|
||||
grpo_args_kwargs["epsilon"] = trl.epsilon
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -2,108 +2,68 @@
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
|
||||
from accelerate.utils import is_peft_model
|
||||
from accelerate.utils.other import is_compiled_module
|
||||
from transformers import PreTrainedModel
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
from trl.models import unwrap_model_for_generation
|
||||
from contextlib import nullcontext
|
||||
|
||||
from axolotl.core.trainers.base import SchedulerMixin
|
||||
from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||
from trl import GRPOTrainer
|
||||
from trl.extras.profiling import profiling_decorator
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
# mypy: ignore-errors
|
||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
"""
|
||||
Extend the base GRPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# pylint: disable=access-member-before-definition
|
||||
# Enable gradient checkpointing if requested
|
||||
if kwargs["args"].gradient_checkpointing:
|
||||
# Ensure use_cache is disabled
|
||||
if hasattr(self.model, "config"):
|
||||
self.model.config.use_cache = False
|
||||
|
||||
# Enable gradient checkpointing on the base model for PEFT
|
||||
if is_peft_model(self.model) and hasattr(
|
||||
self.model.base_model, "gradient_checkpointing_enable"
|
||||
):
|
||||
self.model.base_model.gradient_checkpointing_enable()
|
||||
# Enable gradient checkpointing for non-PEFT models
|
||||
elif hasattr(self.model, "gradient_checkpointing_enable"):
|
||||
self.model.gradient_checkpointing_enable()
|
||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||
# pylint: enable=access-member-before-definition
|
||||
|
||||
def _enable_gradient_checkpointing(
|
||||
self, model: PreTrainedModel, args: GRPOConfig
|
||||
) -> PreTrainedModel:
|
||||
"""Enables gradient checkpointing for the model."""
|
||||
# pylint: disable=unused-argument,redefined-builtin
|
||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
||||
use_reentrant = (
|
||||
"use_reentrant" not in gradient_checkpointing_kwargs
|
||||
or gradient_checkpointing_kwargs["use_reentrant"]
|
||||
@profiling_decorator
|
||||
def _move_model_to_vllm(self):
|
||||
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
||||
gather_if_zero3 = (
|
||||
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
|
||||
)
|
||||
|
||||
if use_reentrant:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
if is_peft_model(self.model):
|
||||
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
|
||||
# adapters in a sharded manner is not supported.
|
||||
with gather_if_zero3(list(self.model.parameters())):
|
||||
self.model.merge_adapter()
|
||||
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
# Update vLLM weights while parameters are gathered
|
||||
for name, param in self.model.named_parameters():
|
||||
# When using PEFT, we need to recover the original parameter name and discard some parameters
|
||||
name = (
|
||||
name.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", "")
|
||||
)
|
||||
if self.model.prefix in name:
|
||||
continue
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
if "original_module" in name:
|
||||
continue
|
||||
name = name.replace("modules_to_save.default.", "")
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(
|
||||
make_inputs_require_grad
|
||||
)
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
|
||||
return model
|
||||
# pylint: enable=unused-argument,redefined-builtin
|
||||
# Unmerge adapters while parameters are still gathered
|
||||
self.model.unmerge_adapter()
|
||||
# Parameters will automatically be repartitioned when exiting the context
|
||||
else:
|
||||
# For non-PEFT models, simply gather and update each parameter individually.
|
||||
for name, param in self.model.named_parameters():
|
||||
with gather_if_zero3([param]):
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
|
||||
def _move_model_to_vllm(self):
|
||||
with unwrap_model_for_generation(
|
||||
self.model,
|
||||
self.accelerator,
|
||||
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
||||
) as unwrapped_model:
|
||||
if is_compiled_module(unwrapped_model):
|
||||
unwrapped_model = (
|
||||
unwrapped_model._orig_mod # pylint: disable=protected-access
|
||||
)
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.merge_adapter()
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
# Remove base_model and base_layer prefixes
|
||||
state_dict = {
|
||||
k.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", ""): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
# Remove values with adapter prefix (example: "_lora")
|
||||
state_dict = {
|
||||
k: v
|
||||
for k, v in state_dict.items()
|
||||
if unwrapped_model.prefix not in k
|
||||
}
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
state_dict = {
|
||||
k.replace("modules_to_save.default.", ""): v
|
||||
for k, v in state_dict.items()
|
||||
if "original_module" not in k
|
||||
}
|
||||
else:
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if self.accelerator.is_main_process:
|
||||
llm_model = (
|
||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
)
|
||||
llm_model.load_weights(state_dict.items())
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.unmerge_adapter()
|
||||
# Reset cache on main process
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.reset_prefix_cache()
|
||||
|
||||
32
src/axolotl/core/trainers/mamba.py
Normal file
32
src/axolotl/core/trainers/mamba.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Module for mamba trainer"""
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""Mamba specific trainer to handle loss calculation"""
|
||||
|
||||
tag_names = ["axolotl", "mamba"]
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
|
||||
labels = input_ids.to(lm_logits.device)
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
lm_loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||
)
|
||||
|
||||
return lm_loss
|
||||
9
src/axolotl/core/trainers/mixins/__init__.py
Normal file
9
src/axolotl/core/trainers/mixins/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Init for axolotl.core.trainers.mixins"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .optimizer import OptimizerMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
from .sequence_parallel import SequenceParallelMixin
|
||||
201
src/axolotl/core/trainers/mixins/optimizer.py
Normal file
201
src/axolotl/core/trainers/mixins/optimizer.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Module for Axolotl trainer optimizer mixin"""
|
||||
|
||||
import logging
|
||||
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from transformers.trainer import Trainer
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
|
||||
from axolotl.integrations.base import BaseOptimizerFactory
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OptimizerMixin(Trainer):
|
||||
"""Mixin class for shared handling of building custom optimizers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_optimizer_grouped_parameters(
|
||||
self, opt_model, optimizer_kwargs
|
||||
) -> list[dict]:
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
params: dict = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
}
|
||||
lr_groups_lookup = {}
|
||||
lr_groups_learning_rates = {}
|
||||
if self.args.lr_groups:
|
||||
for lr_group in self.args.lr_groups:
|
||||
group_name = lr_group["name"]
|
||||
group_modules = lr_group["modules"]
|
||||
for module in group_modules:
|
||||
lr_groups_lookup[module] = group_name
|
||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||
params[f"to_weight_decay_{group_name}"] = {}
|
||||
|
||||
for name, param in opt_model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
if name.endswith("modules_to_save.default.weight") or any(
|
||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||
):
|
||||
params["embeddings"][name] = param
|
||||
elif name in decay_parameters:
|
||||
lr_group_modules = [
|
||||
group_modules
|
||||
for group_modules in lr_groups_lookup
|
||||
if group_modules in name
|
||||
]
|
||||
if lr_groups_lookup and any(lr_group_modules):
|
||||
lr_group_module = lr_group_modules[0]
|
||||
group_name = lr_groups_lookup[lr_group_module]
|
||||
params[f"to_weight_decay_{group_name}"][name] = param
|
||||
else:
|
||||
params["to_weight_decay"][name] = param
|
||||
else:
|
||||
params["no_weight_decay"][name] = param
|
||||
optimizer_grouped_parameters = []
|
||||
if params["to_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["to_weight_decay"].values()),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
if params["embeddings"]:
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
if self.args.embedding_lr_scale:
|
||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||
elif self.args.embedding_lr:
|
||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["embeddings"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": lr,
|
||||
}
|
||||
)
|
||||
if params["no_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["no_weight_decay"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||
if params[f"to_weight_decay_{group_name}"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(
|
||||
params[f"to_weight_decay_{group_name}"].values()
|
||||
),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": group_lr,
|
||||
}
|
||||
)
|
||||
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def create_optimizer(self):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.embedding_lr_scale is None
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.lr_groups is None
|
||||
and self.optimizer_cls_and_kwargs is None
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
|
||||
if (
|
||||
not self.optimizer
|
||||
and self.optimizer_cls_and_kwargs is not None
|
||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||
):
|
||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
self.optimizer = optimizer_factory_cls()(
|
||||
opt_model, self.args, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if not self.optimizer:
|
||||
if self.optimizer_cls_and_kwargs is not None:
|
||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
else:
|
||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||
self.args, opt_model
|
||||
)
|
||||
|
||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||
opt_model, optimizer_kwargs
|
||||
)
|
||||
|
||||
if self.args.loraplus_lr_ratio is not None:
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
loraplus_lr_embedding = getattr(
|
||||
self.args, "loraplus_lr_embedding", 1e-6
|
||||
)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
else:
|
||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for GaLore optimizer.
|
||||
if "params" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||
|
||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for LOMO optimizer.
|
||||
if "model" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||
|
||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||
# to avoid arguments conflicts.
|
||||
if "optimizer_dict" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||
"optimizer_dict"
|
||||
)
|
||||
|
||||
self.optimizer = optimizer_cls(
|
||||
optimizer_grouped_parameters, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if optimizer_cls.__name__ == "Adam8bit":
|
||||
import bitsandbytes
|
||||
|
||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||
|
||||
skipped = 0
|
||||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum(
|
||||
{
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
67
src/axolotl/core/trainers/mixins/rng_state_loader.py
Normal file
67
src/axolotl/core/trainers/mixins/rng_state_loader.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
Temporary fix/override for bug in resume from checkpoint
|
||||
|
||||
See https://github.com/huggingface/transformers/pull/37162
|
||||
|
||||
TODO: Remove when upstream added PR to release
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Trainer, is_torch_npu_available
|
||||
from transformers.trainer import safe_globals
|
||||
from transformers.trainer_pt_utils import set_rng_state_for_device
|
||||
from transformers.training_args import ParallelMode
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RngLoaderMixin(Trainer):
|
||||
"""
|
||||
mixin for method override to load RNG states from a checkpoint
|
||||
"""
|
||||
|
||||
def _load_rng_state(self, checkpoint):
|
||||
# Load RNG states from `checkpoint`
|
||||
if checkpoint is None:
|
||||
return
|
||||
|
||||
if self.args.world_size > 1:
|
||||
process_index = self.args.process_index
|
||||
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
|
||||
if not os.path.isfile(rng_file):
|
||||
LOG.info(
|
||||
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
|
||||
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
|
||||
)
|
||||
return
|
||||
else:
|
||||
rng_file = os.path.join(checkpoint, "rng_state.pth")
|
||||
if not os.path.isfile(rng_file):
|
||||
LOG.info(
|
||||
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
|
||||
"fashion, reproducibility is not guaranteed."
|
||||
)
|
||||
return
|
||||
|
||||
# Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,
|
||||
# which requires allowlisted classes when loading with weights_only=True.
|
||||
with safe_globals():
|
||||
checkpoint_rng_state = torch.load(rng_file) # nosec B614
|
||||
random.setstate(checkpoint_rng_state["python"])
|
||||
np.random.set_state(checkpoint_rng_state["numpy"])
|
||||
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
|
||||
|
||||
is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
if torch.cuda.is_available():
|
||||
set_rng_state_for_device(
|
||||
"CUDA", torch.cuda, checkpoint_rng_state, is_distributed
|
||||
)
|
||||
if is_torch_npu_available():
|
||||
set_rng_state_for_device(
|
||||
"NPU", torch.npu, checkpoint_rng_state, is_distributed
|
||||
)
|
||||
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Module for Axolotl trainer scheduler mixin"""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.utils.schedulers import (
|
||||
RexLR,
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
get_cosine_schedule_with_warmup_decay_constant,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SchedulerMixin(Trainer):
|
||||
"""
|
||||
Mixin class for scheduler setup in CausalTrainer.
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
use_cosine_quadratic = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
)
|
||||
|
||||
use_cosine_min_lr = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.cosine_min_lr_ratio is not None
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
extra_lr_kwargs = {}
|
||||
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["pct_start"] = pct_start
|
||||
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
**extra_lr_kwargs,
|
||||
**self.args.lr_scheduler_kwargs,
|
||||
)
|
||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||
if use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
|
||||
self.lr_scheduler = RexLR(
|
||||
optimizer=optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||
total_steps=num_training_steps,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
)
|
||||
elif use_cosine_quadratic:
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||
|
||||
return self.lr_scheduler
|
||||
182
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
182
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from torch import nn
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
except ImportError:
|
||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||
pass
|
||||
|
||||
|
||||
class SequenceParallelMixin:
|
||||
"""
|
||||
Mixin class for sequence parallelism support in trainers.
|
||||
|
||||
This mixin provides functionality for handling sequence parallelism,
|
||||
including creating appropriate samplers, managing data partitioning,
|
||||
and updating ring flash attention parameters during training.
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def _setup_sequence_parallel(self):
|
||||
"""Set up sequence parallelism environment."""
|
||||
self.ring_attn_group = get_ring_attn_group()
|
||||
|
||||
def _create_sequence_parallel_sampler(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
shuffle: bool = True,
|
||||
is_eval: bool = False,
|
||||
) -> DistributedSampler:
|
||||
"""
|
||||
Helper method to create sampler for sequence parallelism (SP).
|
||||
|
||||
We create a distributed sampler with rank equal to the SP group ID, which
|
||||
means that all ranks in the SP group receive the same sample / set of samples
|
||||
per training step. We also set the number of replicas equal to the number of
|
||||
SP groups, which is a bit of a hack / unintended use, but works!
|
||||
|
||||
Args:
|
||||
dataset: Dataset to sample from.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||
|
||||
Returns:
|
||||
Distributed sampler.
|
||||
"""
|
||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||
|
||||
return DistributedSampler(
|
||||
dataset,
|
||||
num_replicas=num_sp_groups,
|
||||
rank=sp_group_id,
|
||||
seed=self.args.seed if shuffle else None,
|
||||
shuffle=shuffle,
|
||||
drop_last=not is_eval,
|
||||
)
|
||||
|
||||
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
||||
"""
|
||||
Get a training sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
dataset: The training dataset
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self._create_sequence_parallel_sampler(
|
||||
dataset,
|
||||
shuffle=not self.args.curriculum_sampling,
|
||||
)
|
||||
|
||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||
"""
|
||||
Get an evaluation sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
eval_dataset: The evaluation dataset.
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self._create_sequence_parallel_sampler(
|
||||
eval_dataset, shuffle=False, is_eval=True
|
||||
)
|
||||
|
||||
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
|
||||
"""
|
||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||
the substituted ring_flash_attn. This is accomplished by using the passed
|
||||
`input_ids`.
|
||||
|
||||
Args:
|
||||
inputs: Current batch of inputs.
|
||||
"""
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = inputs["input_ids"].shape[0]
|
||||
seq_len = inputs["input_ids"].shape[1]
|
||||
packed_seq_lens = [seq_len] * batch_size
|
||||
|
||||
# Calculate the full sequence length across all GPUs in this SP group
|
||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||
|
||||
cu_seqlens = torch.cumsum(
|
||||
torch.tensor(
|
||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||
),
|
||||
dim=-1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(
|
||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||
)
|
||||
|
||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
num_items_in_batch: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform training step for.
|
||||
inputs: Dictionary mapping.
|
||||
"""
|
||||
# Set up sequence parallelism for this step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal training step
|
||||
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: list[str] | None = None,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Perform a prediction step on a batch of inputs. Overrides the
|
||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||
enabled.
|
||||
|
||||
Args:
|
||||
model: Model to perform prediction step for.
|
||||
inputs: Dictionary mapping of inputs.
|
||||
prediction_loss_only: Whether to return only the loss.
|
||||
ignore_keys: Keys to ignore in the inputs.
|
||||
|
||||
Returns:
|
||||
Tuple of (loss, logits, labels).
|
||||
"""
|
||||
# Set up sequence parallelism for this prediction step if enabled
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal prediction step
|
||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
||||
43
src/axolotl/core/trainers/relora.py
Normal file
43
src/axolotl/core/trainers/relora.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Module for ReLoRA trainer"""
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
|
||||
|
||||
tag_names = ["axolotl", "relora"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: torch.optim.Optimizer | None = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
anneal_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
@@ -1,16 +1,26 @@
|
||||
"""
|
||||
module for TRL PPO training
|
||||
"""
|
||||
"""Module for TRL PPO trainer"""
|
||||
|
||||
from typing import Literal, Union
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from trl import PPOTrainer
|
||||
from trl import (
|
||||
CPOTrainer,
|
||||
KTOTrainer,
|
||||
ORPOTrainer,
|
||||
PPOTrainer,
|
||||
PRMTrainer,
|
||||
RewardTrainer,
|
||||
)
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
class TRLPPOTrainer(PPOTrainer):
|
||||
"""
|
||||
wrapper for ppo trainer to handle customizations
|
||||
"""
|
||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||
|
||||
tag_names = ["axolotl", "ppo"]
|
||||
|
||||
def train(
|
||||
self,
|
||||
@@ -31,9 +41,7 @@ class TRLPPOTrainer(PPOTrainer):
|
||||
"batch_size": 16,
|
||||
}
|
||||
|
||||
for epoch, batch in tqdm( # pylint: disable=unused-variable
|
||||
enumerate(self.dataloader)
|
||||
):
|
||||
for _, batch in tqdm(enumerate(self.dataloader)):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
# generate model response
|
||||
@@ -65,3 +73,189 @@ class TRLPPOTrainer(PPOTrainer):
|
||||
rewards,
|
||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||
)
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: dict[str, Union[list, torch.LongTensor]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
|
||||
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
|
||||
|
||||
metrics = {}
|
||||
|
||||
forward_output = self.concatenated_forward(model, batch)
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
policy_nll_loss,
|
||||
) = forward_output[:5]
|
||||
if self.aux_loss_enabled:
|
||||
aux_loss = forward_output[5]
|
||||
|
||||
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
|
||||
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
|
||||
)
|
||||
# full ORPO loss
|
||||
loss = policy_nll_loss - losses.mean()
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
|
||||
chosen_rewards
|
||||
).mean()
|
||||
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
|
||||
rejected_rewards
|
||||
).mean()
|
||||
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
|
||||
reward_accuracies
|
||||
).mean()
|
||||
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
|
||||
chosen_rewards - rejected_rewards
|
||||
).mean()
|
||||
metrics[f"{prefix}logps/rejected"] = (
|
||||
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
|
||||
)
|
||||
metrics[f"{prefix}logps/chosen"] = (
|
||||
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
|
||||
)
|
||||
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
|
||||
policy_rejected_logits.detach().mean()
|
||||
).mean()
|
||||
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
|
||||
policy_chosen_logits.detach().mean()
|
||||
).mean()
|
||||
metrics[f"{prefix}nll_loss"] = (
|
||||
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
|
||||
)
|
||||
metrics[f"{prefix}log_odds_ratio"] = (
|
||||
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
|
||||
)
|
||||
metrics[f"{prefix}log_odds_chosen"] = (
|
||||
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
|
||||
)
|
||||
for k, v in metrics.items():
|
||||
metrics[k] = v.item()
|
||||
if self.aux_loss_enabled:
|
||||
loss += self.aux_loss_coef * aux_loss
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: dict[str, Union[list, torch.LongTensor]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
metrics = {}
|
||||
|
||||
forward_output = self.concatenated_forward(model, batch)
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
policy_nll_loss,
|
||||
) = forward_output[:5]
|
||||
if self.aux_loss_enabled:
|
||||
aux_loss = forward_output[5]
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
)
|
||||
|
||||
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = (
|
||||
self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
||||
)
|
||||
metrics[f"{prefix}rewards/rejected"] = (
|
||||
self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
||||
)
|
||||
metrics[f"{prefix}rewards/accuracies"] = (
|
||||
self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
||||
)
|
||||
metrics[f"{prefix}rewards/margins"] = (
|
||||
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
|
||||
.mean()
|
||||
.item()
|
||||
)
|
||||
metrics[f"{prefix}logps/rejected"] = (
|
||||
self.accelerator.gather_for_metrics(policy_rejected_logps)
|
||||
.detach()
|
||||
.mean()
|
||||
.item()
|
||||
)
|
||||
metrics[f"{prefix}logps/chosen"] = (
|
||||
self.accelerator.gather_for_metrics(policy_chosen_logps)
|
||||
.detach()
|
||||
.mean()
|
||||
.item()
|
||||
)
|
||||
metrics[f"{prefix}logits/rejected"] = (
|
||||
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
|
||||
.mean()
|
||||
.item()
|
||||
)
|
||||
metrics[f"{prefix}logits/chosen"] = (
|
||||
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
|
||||
.mean()
|
||||
.item()
|
||||
)
|
||||
metrics[f"{prefix}nll_loss"] = (
|
||||
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
|
||||
)
|
||||
|
||||
if self.aux_loss_enabled:
|
||||
loss += self.aux_loss_coef * aux_loss
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "prm"]
|
||||
|
||||
33
src/axolotl/core/trainers/utils.py
Normal file
33
src/axolotl/core/trainers/utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Utils for Axolotl trainers"""
|
||||
|
||||
|
||||
def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||
if isinstance(tag_names, str):
|
||||
tag_names = [tag_names]
|
||||
|
||||
if kwargs is not None:
|
||||
if "tags" not in kwargs:
|
||||
kwargs["tags"] = tag_names
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||
kwargs["tags"].extend(tag_names)
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||
tag_names.append(kwargs["tags"])
|
||||
kwargs["tags"] = tag_names
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||
if isinstance(dataset_tags, str):
|
||||
dataset_tags = [dataset_tags]
|
||||
|
||||
if (dataset_tags is not None) and (kwargs is not None):
|
||||
if "dataset_tags" not in kwargs:
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||
kwargs["dataset_tags"].extend(dataset_tags)
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||
dataset_tags.append(kwargs["dataset_tags"])
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
|
||||
return kwargs
|
||||
@@ -5,6 +5,7 @@ extra axolotl specific training args
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from PIL.Image import Resampling
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
@@ -33,6 +34,12 @@ class AxolotlTrainingMixins:
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
sample_packing_sequentially: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||
},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
@@ -207,14 +214,33 @@ class AxolotlTrainingMixins:
|
||||
},
|
||||
)
|
||||
|
||||
sequence_parallel_degree: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The size of the image to resize to"},
|
||||
)
|
||||
|
||||
image_resize_algorithm: Resampling | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The algorithm to use for image resizing"},
|
||||
)
|
||||
|
||||
# end of multi-modal section
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||
"""
|
||||
Training arguments for Causal trainer
|
||||
|
||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
||||
so it can't be used as a mixin.
|
||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a
|
||||
default value so it can't be used as a mixin.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
@@ -159,4 +160,6 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
|
||||
del model
|
||||
del tokenizer
|
||||
|
||||
cleanup_distributed()
|
||||
|
||||
return all_metrics
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Cut Cross Entropy
|
||||
|
||||
Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
|
||||
Cut Cross Entropy (CCE) reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
|
||||
|
||||
See https://github.com/apple/ml-cross-entropy
|
||||
|
||||
@@ -29,6 +29,20 @@ plugins:
|
||||
cut_cross_entropy: true
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
- llama
|
||||
- phi3
|
||||
- gemma
|
||||
- gemma2
|
||||
- gemma3
|
||||
- gemma3_text
|
||||
- mistral
|
||||
- mistral3
|
||||
- qwen2
|
||||
- cohere
|
||||
- cohere2
|
||||
|
||||
## Citation
|
||||
|
||||
```bib
|
||||
|
||||
@@ -25,8 +25,8 @@ import torch
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils import get_pytorch_version
|
||||
from axolotl.utils.distributed import zero_only
|
||||
|
||||
from ...utils.distributed import zero_only
|
||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
||||
@@ -72,7 +72,9 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
if cfg.cut_cross_entropy:
|
||||
self._check_requirements()
|
||||
|
||||
from cut_cross_entropy.transformers import cce_patch
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
|
||||
cce_patch,
|
||||
)
|
||||
|
||||
with zero_only():
|
||||
LOG.info(
|
||||
|
||||
201
src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
Normal file
201
src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Cohere and Cohere2 CCE patch."""
|
||||
|
||||
# This patch is based off transformers 4.50.0.
|
||||
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
|
||||
# It scales the hidden states by the logit scale in advance instead of the logits as the
|
||||
# operation is done internally and should be mathematically equivalent.
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.cohere.modeling_cohere import (
|
||||
_CONFIG_FOR_DOC,
|
||||
COHERE_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>> from transformers import AutoTokenizer, CohereForCausalLM
|
||||
|
||||
>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
|
||||
>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>> # Generate
|
||||
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
# scale weight by logit_scale in-place of logits
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight * self.logit_scale,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
logits = logits * self.logit_scale # main diff from Llama
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def patch_cohere(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.cohere import modeling_cohere
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_cohere.CohereForCausalLM
|
||||
), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_cohere.CohereForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_cohere2(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.cohere2 import modeling_cohere2
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_cohere2.Cohere2ForCausalLM
|
||||
), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
175
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py
Normal file
175
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Gemma CCE patch"""
|
||||
|
||||
# This patch is based off transformers 4.50.0.
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
_CONFIG_FOR_DOC,
|
||||
GEMMA_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||
|
||||
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
|
||||
|
||||
>>> prompt = "What is your favorite condiment?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def patch_gemma(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma.GemmaForCausalLM
|
||||
), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma.GemmaForCausalLM.forward = cce_forward
|
||||
return None
|
||||
459
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py
Normal file
459
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py
Normal file
@@ -0,0 +1,459 @@
|
||||
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""
|
||||
|
||||
# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
|
||||
# and updated for transformers 4.50.0.
|
||||
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
|
||||
# with both gemma3 (text and multimodal) models.
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache, HybridCache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
_CONFIG_FOR_DOC,
|
||||
GEMMA3_INPUTS_DOCSTRING,
|
||||
Gemma3CausalLMOutputWithPast,
|
||||
logger,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
|
||||
|
||||
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
||||
|
||||
>>> prompt = "What is your favorite condiment?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**loss_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||
**loss_kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if self.config.final_logit_softcapping is not None:
|
||||
logits = logits / self.config.final_logit_softcapping
|
||||
logits = torch.tanh(logits)
|
||||
logits = logits * self.config.final_logit_softcapping
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||
|
||||
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
|
||||
>>> prompt = "answer en Where is the cow standing?"
|
||||
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_length=30)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"answer en Where is the cow standing?\nbeach"
|
||||
```"""
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_index
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids # type: ignore
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = (
|
||||
past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore
|
||||
)
|
||||
cache_position = torch.arange( # type: ignore
|
||||
past_seen_tokens,
|
||||
past_seen_tokens + inputs_embeds.shape[1],
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(
|
||||
self.config.image_token_index,
|
||||
dtype=torch.long,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
|
||||
-1
|
||||
)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
"tokens from image embeddings."
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
||||
|
||||
# mask out pad-token-ids in labels for BC
|
||||
if labels is not None and self.pad_token_id in labels:
|
||||
logger.warning_once(
|
||||
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
||||
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
||||
)
|
||||
labels = torch.where( # type: ignore
|
||||
input_ids == self.pad_token_id, self.config.ignore_index, labels
|
||||
)
|
||||
|
||||
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
past_key_values,
|
||||
cache_position,
|
||||
inputs_embeds,
|
||||
is_training,
|
||||
)
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
defer_logits_calculation=True, # enable deferred logits calculation
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||
**lm_kwargs,
|
||||
)
|
||||
else:
|
||||
logits = hidden_states
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
shift_logits = logits[..., :-1, :]
|
||||
shift_labels = labels[..., 1:]
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
|
||||
logits.device
|
||||
)
|
||||
shift_logits = shift_logits[
|
||||
shift_attention_mask.to(logits.device) != 0
|
||||
].contiguous()
|
||||
shift_labels = shift_labels[
|
||||
shift_attention_mask.to(shift_labels.device) != 0
|
||||
].contiguous()
|
||||
else:
|
||||
shift_logits = shift_logits.contiguous()
|
||||
shift_labels = shift_labels.contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
|
||||
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
||||
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
||||
loss = loss_fct(flat_logits, flat_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Gemma3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def patch_gemma2(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma2 import modeling_gemma2
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma2.Gemma2ForCausalLM
|
||||
), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_gemma3_text(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma3 import modeling_gemma3
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma3.Gemma3ForCausalLM
|
||||
), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_gemma3(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.gemma3 import modeling_gemma3
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
|
||||
), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -0,0 +1,392 @@
|
||||
"""Mistral and Mistral3 CCE patch."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.mistral3.modeling_mistral3 import (
|
||||
Mistral3CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
_CONFIG_FOR_DOC,
|
||||
MISTRAL_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] | None = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
||||
|
||||
>>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor | None = None,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
|
||||
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||
|
||||
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is the image?The image depicts two cats lying on a pink blanket."
|
||||
```"""
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer
|
||||
if vision_feature_layer is not None
|
||||
else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
defer_logits_calculation=True, # enable deferred logits calculation
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**lm_kwargs,
|
||||
)
|
||||
else:
|
||||
logits = hidden_states
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
||||
logits.device
|
||||
)
|
||||
shift_logits = logits[..., :-1, :][
|
||||
shift_attention_mask.to(logits.device) != 0
|
||||
].contiguous()
|
||||
shift_labels = labels[..., 1:][
|
||||
shift_attention_mask.to(labels.device) != 0
|
||||
].contiguous()
|
||||
else:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1).to(shift_logits.device),
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Mistral3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def patch_mistral(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_mistral.MistralForCausalLM
|
||||
), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_mistral3(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
from transformers.models.mistral3 import modeling_mistral3
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
|
||||
), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||
return None
|
||||
379
src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py
Normal file
379
src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""Mllama CCE patch."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.mllama.modeling_mllama import (
|
||||
MLLAMA_INPUTS_DOCSTRING,
|
||||
_prepare_cross_attention_mask,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cross_attention_states: Optional[torch.LongTensor] = None,
|
||||
cross_attention_mask: Optional[torch.LongTensor] = None,
|
||||
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MllamaForCausalLM
|
||||
|
||||
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
|
||||
|
||||
>>> prompt = "If I had to write a haiku, it would be:"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
|
||||
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
>>> print(result)
|
||||
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
|
||||
I love the idea of snowflakes gently falling, each one
|
||||
```
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
cross_attention_states=cross_attention_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**loss_kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class="MllamaConfig"
|
||||
)
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
aspect_ratio_mask: Optional[torch.Tensor] = None,
|
||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
|
||||
|
||||
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
|
||||
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
|
||||
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
||||
|
||||
>>> prompt = "<|image|>If I had to write a haiku for this one"
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> output = model.generate(**inputs, max_new_tokens=15)
|
||||
|
||||
>>> prompt_len = inputs.input_ids.shape[-1]
|
||||
>>> generated_ids = output[:, prompt_len:]
|
||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
>>> print(generated_text)
|
||||
[', it would be:.\\nA stop sign in Chinatown.\\n']
|
||||
```
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if pixel_values is not None and cross_attention_states is not None:
|
||||
raise ValueError(
|
||||
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
if aspect_ratio_ids is None:
|
||||
raise ValueError(
|
||||
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
||||
)
|
||||
# get vision tokens from vision model
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
aspect_ratio_ids=aspect_ratio_ids,
|
||||
aspect_ratio_mask=aspect_ratio_mask,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
cross_attention_states = vision_outputs[0]
|
||||
cross_attention_states = self.multi_modal_projector(
|
||||
cross_attention_states
|
||||
).reshape(
|
||||
-1, cross_attention_states.shape[-2], self.hidden_size # type: ignore
|
||||
)
|
||||
|
||||
if cross_attention_mask is not None:
|
||||
cross_attention_mask, full_text_row_masked_out_mask = (
|
||||
_prepare_cross_attention_mask(
|
||||
cross_attention_mask,
|
||||
num_vision_tokens=self.vision_model.num_patches,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
)
|
||||
else:
|
||||
full_text_row_masked_out_mask = None
|
||||
|
||||
if cross_attention_mask is not None and cache_position is not None:
|
||||
cross_attention_mask = cross_attention_mask[:, :, cache_position]
|
||||
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
|
||||
:, :, cache_position
|
||||
]
|
||||
|
||||
outputs = self.language_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cross_attention_states=cross_attention_states,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
defer_logits_calculation=True, # enable deferred logits calculation
|
||||
**loss_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**loss_kwargs,
|
||||
)
|
||||
else:
|
||||
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
|
||||
logits = hidden_states
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (loss,) + outputs if loss is not None else outputs
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=outputs.logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def patch_mllama(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.mllama import modeling_mllama
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_mllama.MllamaForConditionalGeneration
|
||||
), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the language model
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal
|
||||
|
||||
# patch the causal language model
|
||||
modeling_mllama.MllamaForCausalLM.forward = cce_forward
|
||||
return None
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
"""Cut Cross Entropy patcher"""
|
||||
|
||||
import transformers
|
||||
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
|
||||
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
|
||||
from cut_cross_entropy.transformers.llama import patch_llama
|
||||
from cut_cross_entropy.transformers.phi3 import patch_phi3
|
||||
from cut_cross_entropy.transformers.qwen2 import patch_qwen2
|
||||
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
|
||||
patch_cohere,
|
||||
patch_cohere2,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
|
||||
patch_gemma2,
|
||||
patch_gemma3,
|
||||
patch_gemma3_text,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
|
||||
patch_mistral,
|
||||
patch_mistral3,
|
||||
)
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
|
||||
|
||||
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
|
||||
"llama": patch_llama,
|
||||
"mllama": patch_mllama,
|
||||
"phi3": patch_phi3,
|
||||
"gemma": patch_gemma,
|
||||
"gemma2": patch_gemma2,
|
||||
"gemma3": patch_gemma3,
|
||||
"gemma3_text": patch_gemma3_text,
|
||||
"mistral": patch_mistral,
|
||||
"mistral3": patch_mistral3,
|
||||
"qwen2": patch_qwen2,
|
||||
"cohere": patch_cohere,
|
||||
"cohere2": patch_cohere2,
|
||||
}
|
||||
|
||||
|
||||
def cce_patch(
|
||||
model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
|
||||
impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
|
||||
reduction: str = "mean",
|
||||
filter_eps: float | str | None = "auto",
|
||||
accum_e_fp32: bool = False,
|
||||
accum_c_fp32: bool = False,
|
||||
filter_e_grad: bool = True,
|
||||
filter_c_grad: bool = True,
|
||||
train_only: bool = False,
|
||||
) -> TransformersModelT | None:
|
||||
if isinstance(impl, LinearCrossEntropyImpl):
|
||||
impl = impl.name.lower()
|
||||
|
||||
if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
|
||||
raise ValueError(f"Unknown {impl=}")
|
||||
|
||||
if isinstance(model_type_or_model, transformers.PreTrainedModel):
|
||||
model_type = model_type_or_model.config.model_type
|
||||
elif isinstance(model_type_or_model, transformers.PretrainedConfig):
|
||||
model_type = model_type_or_model.model_type
|
||||
else:
|
||||
model_type = model_type_or_model
|
||||
|
||||
patch_options = PatchOptions(
|
||||
impl=impl,
|
||||
reduction=reduction,
|
||||
filter_eps=filter_eps,
|
||||
accum_e_fp32=accum_e_fp32,
|
||||
accum_c_fp32=accum_c_fp32,
|
||||
filter_e_grad=filter_e_grad,
|
||||
filter_c_grad=filter_c_grad,
|
||||
train_only=train_only,
|
||||
)
|
||||
|
||||
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
|
||||
return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
|
||||
model_type_or_model, patch_options
|
||||
)
|
||||
|
||||
raise RuntimeError(f"Unknown model type {model_type}")
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
"""Monkeypatch for apply_lce to add softcap."""
|
||||
|
||||
import torch
|
||||
from cut_cross_entropy import linear_cross_entropy
|
||||
from cut_cross_entropy.transformers.utils import PatchOptions
|
||||
|
||||
|
||||
def apply_lce(
|
||||
e: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
opts: PatchOptions,
|
||||
bias: torch.Tensor | None = None,
|
||||
softcap: float | None = None,
|
||||
**loss_kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Monkey patch for apply_lce to support softcap kwarg."""
|
||||
num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
|
||||
cce_kwargs = opts.to_kwargs()
|
||||
if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
|
||||
cce_kwargs["reduction"] = "sum"
|
||||
else:
|
||||
num_items_in_batch = None
|
||||
|
||||
loss = linear_cross_entropy(
|
||||
e,
|
||||
c,
|
||||
labels.to(e.device),
|
||||
bias=bias,
|
||||
shift=True,
|
||||
softcap=softcap,
|
||||
**cce_kwargs,
|
||||
)
|
||||
|
||||
if num_items_in_batch is not None:
|
||||
loss = loss / num_items_in_batch
|
||||
|
||||
return loss
|
||||
@@ -20,6 +20,26 @@ liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
- deepseek_v2
|
||||
- gemma
|
||||
- gemma2
|
||||
- gemma3 (partial support, no support for FLCE yet)
|
||||
- granite
|
||||
- jamba
|
||||
- llama
|
||||
- mistral
|
||||
- mixtral
|
||||
- mllama
|
||||
- mllama_text_model
|
||||
- olmo2
|
||||
- paligemma
|
||||
- phi3
|
||||
- qwen2
|
||||
- qwen2_5_vl
|
||||
- qwen2_vl
|
||||
|
||||
## Citation
|
||||
|
||||
```bib
|
||||
|
||||
@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
@@ -41,11 +42,18 @@ class LigerPlugin(BasePlugin):
|
||||
def pre_model_load(self, cfg):
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
|
||||
raise ValueError(
|
||||
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
|
||||
)
|
||||
|
||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||
@@ -82,6 +90,8 @@ class LigerPlugin(BasePlugin):
|
||||
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_layer_norm:
|
||||
modeling_jamba.nn.LayerNorm = LigerLayerNorm
|
||||
if cfg.liger_cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
@@ -104,13 +114,51 @@ class LigerPlugin(BasePlugin):
|
||||
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
||||
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
||||
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||
if cfg.liger_glu_activation:
|
||||
logging.warning("liger_glu_activation is not supported for DeepseekV2.")
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||
if cfg.liger_layer_norm:
|
||||
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
|
||||
if cfg.liger_cross_entropy:
|
||||
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
||||
# nn.CrossEntropyLoss in the forward method.
|
||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
|
||||
from transformers.models.gemma3 import modeling_gemma3
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
|
||||
def _liger_rms_norm_wrapper(dim, **kwargs):
|
||||
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
|
||||
return LigerRMSNorm(hidden_size=dim, **kwargs)
|
||||
|
||||
modeling_gemma3.Gemma3RMSNorm = partial(
|
||||
_liger_rms_norm_wrapper,
|
||||
offset=1.0,
|
||||
casting_mode="gemma",
|
||||
init_fn="zeros",
|
||||
in_place=False,
|
||||
)
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
|
||||
if cfg.liger_layer_norm:
|
||||
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cfg.liger_cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
raise NotImplementedError(
|
||||
"Fused linear cross entropy is not yet supported for Gemma3."
|
||||
)
|
||||
elif cfg.model_config_type in ["deepseek_v3"]:
|
||||
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
|
||||
|
||||
100
src/axolotl/monkeypatch/attention/ring_attn.py
Normal file
100
src/axolotl/monkeypatch/attention/ring_attn.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Ring attention group registration and flash attention patching.
|
||||
|
||||
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
||||
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
"""
|
||||
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
configure_logging()
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
RING_ATTN_GROUP = None
|
||||
|
||||
|
||||
def get_ring_attn_group() -> dist.ProcessGroup:
|
||||
"""
|
||||
Getter for ring attention group on this rank.
|
||||
|
||||
Returns:
|
||||
The process group for ring attention for this rank.
|
||||
"""
|
||||
return RING_ATTN_GROUP
|
||||
|
||||
|
||||
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
"""
|
||||
Setter for ring attention group on this rank.
|
||||
|
||||
Args:
|
||||
Process group for ring attention.
|
||||
"""
|
||||
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
||||
"""
|
||||
Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed
|
||||
through to `ring_flash_attn.substitute_hf_flash_attn`.
|
||||
"""
|
||||
if get_ring_attn_group() is not None:
|
||||
LOG.info("Ring attention already registered, exiting early...")
|
||||
return
|
||||
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
)
|
||||
assert world_size % sequence_parallel_degree == 0, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
)
|
||||
|
||||
# Detailed logging of group formation
|
||||
rank = dist.get_rank()
|
||||
group_assignments = {}
|
||||
|
||||
for i in range(world_size // sequence_parallel_degree):
|
||||
ring_attn_ranks = list(
|
||||
range(
|
||||
i * sequence_parallel_degree,
|
||||
(i + 1) * sequence_parallel_degree,
|
||||
)
|
||||
)
|
||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||
|
||||
# Track which GPUs are in which groups
|
||||
for r in ring_attn_ranks:
|
||||
group_assignments[r] = i
|
||||
|
||||
if rank in ring_attn_ranks:
|
||||
set_ring_attn_group(group)
|
||||
|
||||
# Log the GPU group assignments
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
if heads_k_stride is None:
|
||||
heads_k_stride = 1
|
||||
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||
)
|
||||
238
src/axolotl/monkeypatch/gemma3.py
Normal file
238
src/axolotl/monkeypatch/gemma3.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Monkeypatch for gemma3 conditional generation forward to fix loss exploding"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
_CONFIG_FOR_DOC,
|
||||
GEMMA3_INPUTS_DOCSTRING,
|
||||
Gemma3CausalLMOutputWithPast,
|
||||
logger,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def new_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||
|
||||
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
|
||||
>>> prompt = "answer en Where is the cow standing?"
|
||||
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_length=30)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"answer en Where is the cow standing?\nbeach"
|
||||
```"""
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id with PAD if the image token is OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_index
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = (
|
||||
past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
)
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens,
|
||||
past_seen_tokens + inputs_embeds.shape[1],
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(
|
||||
self.config.image_token_index,
|
||||
dtype=torch.long,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
|
||||
-1
|
||||
)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
"tokens from image embeddings."
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
# mask out pad-token-ids in labels for BC
|
||||
if labels is not None and self.pad_token_id in labels:
|
||||
logger.warning_once(
|
||||
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
||||
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
||||
)
|
||||
labels = torch.where(
|
||||
input_ids == self.pad_token_id, self.config.ignore_index, labels
|
||||
)
|
||||
|
||||
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
past_key_values,
|
||||
cache_position,
|
||||
inputs_embeds,
|
||||
is_training,
|
||||
)
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if attention_mask is not None:
|
||||
# Get the shifted attention mask
|
||||
shift_attention_mask = attention_mask[:, -logits.shape[1] + 1 :].to(
|
||||
logits.device
|
||||
) # +1 for shift
|
||||
|
||||
# Filter logits and labels based on attention mask
|
||||
valid_indices = shift_attention_mask != 0
|
||||
filtered_logits = logits[..., :-1, :][valid_indices]
|
||||
filtered_labels = labels[..., 1:][valid_indices.to(labels.device)]
|
||||
|
||||
# TODO: do we need to handle num_items_in_batch given we filter the logits and labels?
|
||||
|
||||
loss = self.loss_function(
|
||||
logits=filtered_logits,
|
||||
labels=None, # we pass shift_labels
|
||||
shift_labels=filtered_labels,
|
||||
vocab_size=self.config.text_config.vocab_size,
|
||||
**lm_kwargs,
|
||||
)
|
||||
else:
|
||||
# Standard case without filtering
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.text_config.vocab_size,
|
||||
**lm_kwargs,
|
||||
)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return Gemma3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def patch_gemma3conditionalgeneration_forward():
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
Gemma3ForConditionalGeneration,
|
||||
)
|
||||
|
||||
Gemma3ForConditionalGeneration.forward = new_forward
|
||||
@@ -252,12 +252,38 @@ def apply_lora_kernel_patches(
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
# Choose activation based on model type
|
||||
activation = model.config.hidden_act
|
||||
activation = None
|
||||
text_config = (
|
||||
model.config.get_text_config()
|
||||
if hasattr(model.config, "get_text_config")
|
||||
else model.config
|
||||
)
|
||||
if hasattr(text_config, "hidden_act"):
|
||||
activation = text_config.hidden_act
|
||||
elif hasattr(text_config, "hidden_activation"):
|
||||
activation = text_config.hidden_activation
|
||||
|
||||
# map activation to supported activation
|
||||
if "gelu" in activation:
|
||||
# gemma3 uses gelu_pytorch_tanh
|
||||
activation = "gelu"
|
||||
|
||||
if activation not in SUPPORTED_ACTIVATIONS:
|
||||
raise NotImplementedError(f"Activation {activation} is not supported")
|
||||
|
||||
layers = []
|
||||
# check for multimodal models first
|
||||
if hasattr(model, "language_model"):
|
||||
layers = model.language_model.model.layers
|
||||
elif hasattr(model, "model"):
|
||||
layers = model.model.model.layers
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
|
||||
)
|
||||
|
||||
# Patch each layer
|
||||
for layer in model.model.model.layers:
|
||||
for layer in layers:
|
||||
# Add QKV, O fallback implementations to start
|
||||
# These will be overwritten later (if some conditions apply)
|
||||
layer.self_attn.apply_qkv = types.MethodType(
|
||||
|
||||
@@ -22,6 +22,10 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"phi3",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"gemma3",
|
||||
"gemma3_text",
|
||||
"cohere",
|
||||
"cohere2",
|
||||
"gemmoe",
|
||||
"starcoder2",
|
||||
"deepseek_v2",
|
||||
|
||||
278
src/axolotl/processing_strategies.py
Normal file
278
src/axolotl/processing_strategies.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.Image import Resampling
|
||||
from torch import Tensor
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.image_utils import load_image
|
||||
|
||||
|
||||
class ProcessingStrategy:
|
||||
"""Base Processing Strategy class"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processor: ProcessorMixin,
|
||||
chat_template: Optional[str] = None,
|
||||
image_size: int | tuple[int, int] | None = None,
|
||||
image_resize_algorithm: Resampling | None = None,
|
||||
):
|
||||
self.processor = processor
|
||||
self.chat_template = chat_template
|
||||
self.image_token = None
|
||||
self.image_token_id = None
|
||||
|
||||
self.image_size = image_size
|
||||
self.image_resize_algorithm = (
|
||||
image_resize_algorithm or Image.Resampling.BILINEAR
|
||||
)
|
||||
|
||||
if hasattr(processor, "image_token"):
|
||||
self.image_token = processor.image_token
|
||||
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
self.image_token
|
||||
)
|
||||
|
||||
def __call__(self, examples: list[dict]) -> list[dict]:
|
||||
"""
|
||||
Preprocess conversation examples to ensure consistent format.
|
||||
Converts different conversation formats to OpenAI format with 'messages'.
|
||||
Supports two formats:
|
||||
1. OpenAI format with 'messages'
|
||||
2. Legacy format with 'conversations'
|
||||
|
||||
Args:
|
||||
examples: list of conversation dictionaries
|
||||
|
||||
Returns:
|
||||
list of dicts in OpenAI format with 'messages' key
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation format is not supported
|
||||
"""
|
||||
role_mapping = {
|
||||
"human": "user",
|
||||
"gpt": "assistant",
|
||||
}
|
||||
|
||||
def normalize_role(role: str) -> str:
|
||||
"""Normalize role names to OpenAI format. Default to original role if not found."""
|
||||
return role_mapping.get(role, role)
|
||||
|
||||
def convert_legacy_format(example: dict) -> dict:
|
||||
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
|
||||
messages = [
|
||||
{"role": normalize_role(convo["from"]), "content": convo["value"]}
|
||||
for convo in example["conversations"]
|
||||
]
|
||||
|
||||
# Create new dict without 'conversations' key
|
||||
result = deepcopy(example)
|
||||
result.pop("conversations")
|
||||
result["messages"] = messages
|
||||
return result
|
||||
|
||||
def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
|
||||
"""Convert regular messages format to Messages format with content type"""
|
||||
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message["content"], str):
|
||||
new_messages.append(
|
||||
{
|
||||
"role": message["role"],
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": message["content"],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
elif isinstance(message["content"], list):
|
||||
content = message["content"]
|
||||
|
||||
new_messages.append(
|
||||
{
|
||||
"role": message["role"],
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
return new_messages
|
||||
|
||||
processed_examples = []
|
||||
for example in examples:
|
||||
if not ("messages" in example or "conversations" in example):
|
||||
raise ValueError(
|
||||
"Only `messages` and `conversations` message keys are currently supported."
|
||||
)
|
||||
|
||||
processed_example = None
|
||||
if "messages" in example: # OpenAI format
|
||||
processed_example = example
|
||||
else: # Legacy format
|
||||
processed_example = convert_legacy_format(example)
|
||||
|
||||
# convert regular messages format to Messages format with content type
|
||||
# for compatibility with apply_chat_template
|
||||
processed_example["messages"] = convert_messages_to_multimedia_messages(
|
||||
processed_example["messages"]
|
||||
)
|
||||
|
||||
# find the image key if it exists
|
||||
possible_image_keys = ["images", "image"]
|
||||
image_key = None
|
||||
for key in possible_image_keys:
|
||||
if key in processed_example:
|
||||
image_key = key
|
||||
break
|
||||
|
||||
# if the image key exists, add the image to the first message
|
||||
if image_key is not None:
|
||||
# TODO: check if it's normal to be single image only for common datasets
|
||||
# From observation, it's usually a list of single image but some datasets may have several columns for images
|
||||
# Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
|
||||
image_value = processed_example[image_key][0]
|
||||
|
||||
# Handle image loading (Image, url, path, base64)
|
||||
image_value = load_image(image_value)
|
||||
|
||||
if self.image_size is not None:
|
||||
assert hasattr(
|
||||
image_value, "resize"
|
||||
), "Image does not have a resize method"
|
||||
|
||||
if isinstance(self.image_size, tuple):
|
||||
image_value = image_value.resize(
|
||||
self.image_size, self.image_resize_algorithm
|
||||
)
|
||||
else:
|
||||
# Set the padding value; here we use black (0, 0, 0) for RGB images
|
||||
padding_color = (0, 0, 0)
|
||||
|
||||
# When image_size is an int (square target), preserve aspect ratio then pad
|
||||
# This is to prevent aspect ratio distortion when resizing to square
|
||||
image_value = ImageOps.pad(
|
||||
image_value,
|
||||
(self.image_size, self.image_size),
|
||||
method=self.image_resize_algorithm,
|
||||
color=padding_color,
|
||||
)
|
||||
|
||||
# Look for any image type in the first message
|
||||
# some dataset have an {type: "image"} in the first message
|
||||
ind_to_add = None
|
||||
|
||||
for i, content in enumerate(
|
||||
processed_example["messages"][0]["content"]
|
||||
):
|
||||
# Usually datasets created with image columns, don't have it in the messages itself
|
||||
if content["type"] == "image" and all(
|
||||
k not in content for k in ["image", "url", "path", "base64"]
|
||||
):
|
||||
ind_to_add = i
|
||||
break
|
||||
|
||||
# If an image type is found, add the image to that index
|
||||
if ind_to_add is not None:
|
||||
processed_example["messages"][0]["content"][ind_to_add][
|
||||
"image"
|
||||
] = image_value
|
||||
else:
|
||||
# if no image type is found, add it to end of the first message
|
||||
processed_example["messages"][0]["content"].append(
|
||||
{
|
||||
"type": "image",
|
||||
"image": image_value,
|
||||
}
|
||||
)
|
||||
|
||||
processed_examples.append(processed_example)
|
||||
|
||||
return processed_examples
|
||||
|
||||
def process_labels(self, input_ids: Tensor) -> Tensor:
|
||||
labels = input_ids.clone()
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||
|
||||
# Ignore the image token index in the loss computation (model specific)
|
||||
labels[labels == self.image_token_id] = -100
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
class Qwen2VLProcessingStrategy(ProcessingStrategy):
|
||||
"""Processing Strategy class for Qwen2-VL"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processor: ProcessorMixin,
|
||||
chat_template: Optional[str] = None,
|
||||
image_size: int | tuple[int, int] | None = None,
|
||||
image_resize_algorithm: Resampling | None = None,
|
||||
):
|
||||
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
||||
self.image_token = "<|image_pad|>" # nosec
|
||||
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
self.image_token
|
||||
)
|
||||
|
||||
|
||||
class Gemma3ProcessingStrategy(ProcessingStrategy):
|
||||
"""Processing Strategy class for Gemma3"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processor: ProcessorMixin,
|
||||
chat_template: Optional[str] = None,
|
||||
image_size: int | tuple[int, int] | None = None,
|
||||
image_resize_algorithm: Resampling | None = None,
|
||||
):
|
||||
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
||||
self.image_token = processor.tokenizer.special_tokens_map["boi_token"]
|
||||
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
self.image_token
|
||||
)
|
||||
|
||||
def process_labels(self, input_ids):
|
||||
labels = input_ids.clone()
|
||||
|
||||
# Follows https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora
|
||||
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == self.image_token_id] = -100
|
||||
labels[labels == 262144] = -100 # corresponds to <image_soft_token>
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
def get_processing_strategy(
|
||||
processor: ProcessorMixin,
|
||||
chat_template,
|
||||
chat_template_type,
|
||||
image_size: int | tuple[int, int] | None = None,
|
||||
image_resize_algorithm: Resampling | None = None,
|
||||
):
|
||||
if chat_template_type == "qwen2_vl":
|
||||
return Qwen2VLProcessingStrategy(
|
||||
processor, chat_template, image_size, image_resize_algorithm
|
||||
)
|
||||
if chat_template_type == "gemma3":
|
||||
return Gemma3ProcessingStrategy(
|
||||
processor, chat_template, image_size, image_resize_algorithm
|
||||
)
|
||||
if chat_template_type in [
|
||||
"llama3_2_vision",
|
||||
"llava",
|
||||
"mistral_v7_tekken",
|
||||
"pixtral",
|
||||
]:
|
||||
return ProcessingStrategy(
|
||||
processor, chat_template, image_size, image_resize_algorithm
|
||||
)
|
||||
raise ValueError(f"Unsupported chat template type: {chat_template_type}")
|
||||
@@ -411,11 +411,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if turn_idx >= len(turns):
|
||||
raise ValueError(f"Turn index {turn_idx} out of range")
|
||||
|
||||
# mistral does not output message if it contains only system message
|
||||
# mistral/gemma3 does not output message if it contains only system message
|
||||
if (
|
||||
turn_idx == 0
|
||||
and turns[0].get("role") == "system"
|
||||
and "mistral" in self.tokenizer.name_or_path.lower()
|
||||
and (
|
||||
"mistral" in self.tokenizer.name_or_path.lower()
|
||||
# gemma3 uses gemma tokenizer
|
||||
or "gemma" in self.tokenizer.name_or_path.lower()
|
||||
)
|
||||
):
|
||||
return -1, -1
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import save_fsdp_model
|
||||
from datasets import Dataset
|
||||
from huggingface_hub.errors import OfflineModeIsEnabled
|
||||
from peft import PeftConfig, PeftModel
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
@@ -26,6 +27,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
@@ -156,6 +158,8 @@ def setup_signal_handler(
|
||||
_model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
cleanup_distributed()
|
||||
sys.exit(0)
|
||||
|
||||
_model_weakref = weakref.ref(model)
|
||||
@@ -169,7 +173,7 @@ def execute_training(
|
||||
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||
):
|
||||
"""
|
||||
Execute the training process with appropriate backend configurations.
|
||||
Execute the training process with appropriate SDP kernel configurations.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
@@ -177,9 +181,6 @@ def execute_training(
|
||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||
"""
|
||||
LOG.info("Starting trainer...")
|
||||
if cfg.group_by_length:
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
@@ -305,7 +306,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
|
||||
model_card_kwarg["dataset_tags"] = dataset_tags
|
||||
|
||||
trainer.create_model_card(**model_card_kwarg)
|
||||
except (AttributeError, UnicodeDecodeError):
|
||||
except (AttributeError, UnicodeDecodeError, OfflineModeIsEnabled):
|
||||
pass
|
||||
elif cfg.hub_model_id:
|
||||
# Defensively push to the hub to ensure the model card is updated
|
||||
@@ -317,6 +318,7 @@ def save_initial_configs(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
model: PreTrainedModel,
|
||||
peft_config: PeftConfig | None,
|
||||
processor: ProcessorMixin | None,
|
||||
):
|
||||
"""
|
||||
Save initial configurations before training.
|
||||
@@ -344,6 +346,10 @@ def save_initial_configs(
|
||||
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||
model.config.save_pretrained(str(output_dir))
|
||||
|
||||
if processor:
|
||||
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
|
||||
processor.save_pretrained(str(output_dir))
|
||||
|
||||
|
||||
def setup_model_card(cfg: DictDefault):
|
||||
"""
|
||||
@@ -411,6 +417,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
||||
PeftModel | PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PeftConfig | None,
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
|
||||
@@ -426,6 +433,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
||||
- Model
|
||||
- Tokenizer
|
||||
- PEFT config
|
||||
- Processor
|
||||
"""
|
||||
# Load tokenizer, processor and model
|
||||
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
|
||||
@@ -456,6 +464,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
||||
model,
|
||||
tokenizer,
|
||||
peft_config,
|
||||
processor,
|
||||
)
|
||||
|
||||
|
||||
@@ -472,42 +481,35 @@ def train(
|
||||
Returns:
|
||||
Tuple of (model, tokenizer) after training
|
||||
"""
|
||||
# Setup model, tokenizer, (causal or RLHF) trainer etc.
|
||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||
(
|
||||
trainer,
|
||||
model,
|
||||
tokenizer,
|
||||
peft_config,
|
||||
processor,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
# Determine if we need to resume from a checkpoint
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
|
||||
# Configuration for saving
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
handle_untrained_tokens_fix(
|
||||
cfg, model, tokenizer, train_dataset, safe_serialization
|
||||
)
|
||||
|
||||
# Save initial configs
|
||||
save_initial_configs(cfg, tokenizer, model, peft_config)
|
||||
|
||||
# Set up signal handler for graceful termination
|
||||
# Additional setup
|
||||
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
|
||||
setup_signal_handler(cfg, model, safe_serialization)
|
||||
|
||||
# Set up badges and config info for model card
|
||||
setup_model_card(cfg)
|
||||
|
||||
# Execute the training
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||
|
||||
# Save the trained model
|
||||
# Save the trained model and cleanup
|
||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||
|
||||
# Create model card
|
||||
create_model_card(cfg, trainer)
|
||||
if not cfg.use_ray:
|
||||
cleanup_distributed()
|
||||
|
||||
return model, tokenizer, trainer
|
||||
|
||||
@@ -816,27 +816,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
return control
|
||||
|
||||
|
||||
class SaveModelCallback(TrainerCallback):
|
||||
"""Callback to save model on train end"""
|
||||
|
||||
def on_step_end( # pylint: disable=unused-argument
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
# Save
|
||||
if state.global_step >= state.max_steps:
|
||||
control.should_save = True
|
||||
|
||||
def on_train_end( # pylint: disable=unused-argument
|
||||
self, args, state, control, **kwargs
|
||||
):
|
||||
control.should_save = True
|
||||
return control
|
||||
|
||||
|
||||
class GCCallback(TrainerCallback):
|
||||
"""Callback to garbage collect torch cache"""
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,14 +1,59 @@
|
||||
"""
|
||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
|
||||
includes logic for handling sequence parallelism collation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def adjust_position_ids_for_slice(
|
||||
position_ids: torch.Tensor, start_idx: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
||||
This handles the case where position IDs might not be contiguous due to sample
|
||||
packing.
|
||||
"""
|
||||
# Convert to tensor if not already
|
||||
# Find the boundaries between samples (where position_ids reset)
|
||||
adjusted_pos_ids = position_ids.clone()
|
||||
|
||||
# Process each sequence in the batch
|
||||
for i in range(position_ids.shape[0]):
|
||||
seq = position_ids[i]
|
||||
|
||||
# Find sample boundaries
|
||||
boundaries = []
|
||||
for j in range(1, len(seq)):
|
||||
if seq[j] < seq[j - 1]:
|
||||
boundaries.append(j)
|
||||
|
||||
# No need to adjust if there are no boundaries or this is a single sample
|
||||
if not boundaries:
|
||||
adjusted_pos_ids[i] = seq - start_idx
|
||||
continue
|
||||
|
||||
# Adjust each segment separately
|
||||
prev_boundary = 0
|
||||
for boundary in boundaries:
|
||||
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
||||
prev_boundary = boundary
|
||||
|
||||
# Last segment
|
||||
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
||||
|
||||
return adjusted_pos_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSeq2Seq:
|
||||
@@ -43,6 +88,8 @@ class DataCollatorForSeq2Seq:
|
||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||
return_tensors (`str`):
|
||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||
sequence_parallel_degree (`int`):
|
||||
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
@@ -53,8 +100,19 @@ class DataCollatorForSeq2Seq:
|
||||
label_pad_token_id: int = -100
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
sequence_parallel_degree: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
if self.sequence_parallel_degree > 1:
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
# Get information about our position in the SP group
|
||||
sp_group = get_ring_attn_group()
|
||||
self.local_rank = dist.get_rank(group=sp_group)
|
||||
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
has_attn_mask = "attention_mask" in features[0].keys()
|
||||
labels = None
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
@@ -107,6 +165,8 @@ class DataCollatorForSeq2Seq:
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
if not has_attn_mask:
|
||||
del features["attention_mask"]
|
||||
|
||||
# prepare decoder_input_ids
|
||||
if (
|
||||
@@ -119,8 +179,43 @@ class DataCollatorForSeq2Seq:
|
||||
)
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
if self.sequence_parallel_degree > 1:
|
||||
features = self.apply_sequence_parallelism(features)
|
||||
|
||||
return features
|
||||
|
||||
def apply_sequence_parallelism(
|
||||
self, batch: dict[str, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply sequence parallelism slicing to a batch.
|
||||
|
||||
Args:
|
||||
batch: Batch dictionary from parent collator.
|
||||
|
||||
Returns:
|
||||
Sliced batch dictionary.
|
||||
"""
|
||||
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||
|
||||
for key in keys_to_slice:
|
||||
if key in batch:
|
||||
seq_len = batch[key].shape[1]
|
||||
slice_size = seq_len // self.local_world_size
|
||||
start_idx = self.local_rank * slice_size
|
||||
end_idx = (
|
||||
start_idx + slice_size
|
||||
if self.local_rank < self.local_world_size - 1
|
||||
else seq_len
|
||||
)
|
||||
batch[key] = batch[key][:, start_idx:end_idx]
|
||||
|
||||
# Special handling for position_ids
|
||||
if key == "position_ids" and self.local_rank > 0:
|
||||
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
@@ -148,6 +243,7 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@@ -177,6 +273,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
|
||||
@@ -2,15 +2,17 @@
|
||||
Collators for multi-modal chat messages and packing
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.data.data_collator import DataCollatorMixin
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.processing_strategies import ProcessingStrategy
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
@@ -19,11 +21,9 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
processor: ProcessorMixin
|
||||
return_tensors: str = "pt"
|
||||
chat_template: Optional[str] = None
|
||||
processing_strategy: ProcessingStrategy
|
||||
packing: bool = False
|
||||
max_images: int = -1
|
||||
return_tensors: str = "pt"
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
@@ -31,162 +31,62 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
if self.packing:
|
||||
raise ValueError("Packing is currently not supported.")
|
||||
|
||||
def torch_call(
|
||||
self, examples: list[Union[list[int], Any, dict[str, Any]]]
|
||||
) -> dict[str, Any]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
|
||||
return self.__class__.process_rows(
|
||||
examples, self.processor, self.chat_template, self.max_images
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
||||
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
||||
# see also DataCollatorWithFlattening and DefaultDataCollator
|
||||
|
||||
# *** This is COPIED from the trl example sft_vlm.py code ***
|
||||
# use this as a starting point
|
||||
|
||||
def _preprocess(examples: list[dict]) -> list[dict]:
|
||||
"""
|
||||
Preprocess conversation examples to ensure consistent format.
|
||||
|
||||
Converts different conversation formats to OpenAI format with 'messages'.
|
||||
Supports two formats:
|
||||
1. OpenAI format with 'messages'
|
||||
2. Legacy format with 'conversations'
|
||||
|
||||
Args:
|
||||
examples: list of conversation dictionaries
|
||||
|
||||
Returns:
|
||||
dict in OpenAI format with 'messages' key
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation format is not supported
|
||||
"""
|
||||
role_mapping = {
|
||||
"human": "user",
|
||||
"gpt": "assistant",
|
||||
}
|
||||
|
||||
def normalize_role(role: str) -> str:
|
||||
"""Normalize role names to OpenAI format. Default to original role if not found."""
|
||||
return role_mapping.get(role, role)
|
||||
|
||||
def convert_legacy_format(example: dict) -> dict:
|
||||
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
|
||||
messages = [
|
||||
{
|
||||
"role": normalize_role(convo["from"]),
|
||||
"content": convo["value"],
|
||||
}
|
||||
for convo in example["conversations"]
|
||||
]
|
||||
|
||||
# Create new dict without 'conversations' key
|
||||
result = deepcopy(example)
|
||||
result.pop("conversations")
|
||||
return {"messages": messages, **result}
|
||||
|
||||
processed_examples = []
|
||||
for example in examples:
|
||||
# OpenAI format
|
||||
if "messages" in example:
|
||||
processed_examples.append(example)
|
||||
|
||||
# Legacy format
|
||||
elif "conversations" in example:
|
||||
processed_examples.append(convert_legacy_format(example))
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only `messages` and `conversations` message keys are currently supported."
|
||||
)
|
||||
|
||||
return processed_examples
|
||||
|
||||
def _process_images(examples, max_images):
|
||||
"""
|
||||
Process images from examples, ensuring consistency in image presence and applying max_images limit.
|
||||
|
||||
Args:
|
||||
examples: List of dictionaries that may contain 'images' key
|
||||
max_images: Maximum number of images to keep per example (0 means no limit)
|
||||
|
||||
Returns:
|
||||
Either None (if no images) or List[Image objects] (if all examples have images)
|
||||
|
||||
Raises:
|
||||
ValueError: If there's a mix of None and non-None images
|
||||
"""
|
||||
|
||||
def get_image(example):
|
||||
if "images" not in example:
|
||||
return None
|
||||
images = example["images"]
|
||||
if isinstance(images, str):
|
||||
return Image.open(images)
|
||||
return images
|
||||
|
||||
images = [get_image(example) for example in examples]
|
||||
|
||||
# Count None and non-None images
|
||||
none_count = sum(1 for img in images if img is None)
|
||||
|
||||
# All images are None
|
||||
if none_count == len(images):
|
||||
return None
|
||||
|
||||
# Mix of None and non-None images
|
||||
if none_count > 0:
|
||||
raise ValueError(
|
||||
"All images should be either None or not None. "
|
||||
"Please provide images for all examples or None."
|
||||
)
|
||||
|
||||
# Apply max_images limit if specified
|
||||
if max_images > 0:
|
||||
images = [
|
||||
(
|
||||
img_batch[:max_images]
|
||||
if isinstance(img_batch, (list, tuple))
|
||||
else img_batch
|
||||
)
|
||||
for img_batch in images
|
||||
]
|
||||
|
||||
return images
|
||||
def torch_call(self, examples: list[dict]) -> dict[str, Any]:
|
||||
return self.process_rows(examples)
|
||||
|
||||
def process_rows(
|
||||
self,
|
||||
examples: list[dict],
|
||||
) -> dict[str, Tensor]:
|
||||
# Preprocess the examples
|
||||
examples = _preprocess(examples)
|
||||
examples = self.processing_strategy(examples)
|
||||
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
example["messages"], chat_template=chat_template, tokenize=False
|
||||
# Initialize batch
|
||||
batch: dict[str, Any] = {}
|
||||
|
||||
# Process each example
|
||||
for example in examples:
|
||||
# Apply chat template to process the example
|
||||
# This method requires transformers>=4.49.0
|
||||
result = self.processing_strategy.processor.apply_chat_template(
|
||||
example["messages"],
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
return_dict=True,
|
||||
chat_template=self.processing_strategy.chat_template,
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
|
||||
images = _process_images(examples, max_images=max_images)
|
||||
# TODO: Check if need handling for len(input_ids) > sequence_len
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||
# Add the processed tensors to our batch
|
||||
for key in result.keys():
|
||||
if key not in batch:
|
||||
batch[key] = []
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
labels = batch["input_ids"].clone()
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
||||
# Ignore the image token index in the loss computation (model specific)
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.image_token
|
||||
batch[key].append(result[key].squeeze(0))
|
||||
|
||||
# Pad sequences to the same length
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
batch["input_ids"],
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id,
|
||||
)
|
||||
labels[labels == image_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
|
||||
if length_only:
|
||||
return {
|
||||
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
||||
}
|
||||
return batch
|
||||
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
||||
batch["attention_mask"], batch_first=True, padding_value=0
|
||||
)
|
||||
|
||||
# Create the final batch
|
||||
final_batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
# Process the labels
|
||||
final_batch["labels"] = self.processing_strategy.process_labels(
|
||||
final_batch["input_ids"]
|
||||
)
|
||||
|
||||
return final_batch
|
||||
|
||||
@@ -13,7 +13,7 @@ from axolotl.integrations.base import PluginManager
|
||||
from axolotl.integrations.config import merge_input_args
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model_config
|
||||
from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config
|
||||
from axolotl.utils.schemas.config import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
)
|
||||
@@ -78,6 +78,7 @@ def resolve_dtype(cfg):
|
||||
cfg.bf16 = False
|
||||
else:
|
||||
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
||||
torch.backends.cudnn.allow_tf32 = cfg.tf32 or False
|
||||
if cfg.bf16:
|
||||
cfg.fp16 = False
|
||||
|
||||
@@ -125,6 +126,9 @@ def normalize_config(cfg):
|
||||
with open(ds_config_path, encoding="utf-8") as f:
|
||||
cfg.deepspeed = json.load(f)
|
||||
|
||||
if cfg.sequence_parallel_degree is None:
|
||||
cfg.sequence_parallel_degree = 1
|
||||
|
||||
if cfg.saves_per_epoch:
|
||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||
if save_steps < 1.0: # prevent saves on every step
|
||||
@@ -155,7 +159,7 @@ def normalize_config(cfg):
|
||||
|
||||
cfg.is_multimodal = (
|
||||
hasattr(model_config, "model_type")
|
||||
and model_config.model_type in ["llava", "mllama"]
|
||||
and model_config.model_type in MULTIMODAL_AUTO_MODEL_MAPPING
|
||||
or any(
|
||||
multimodal_name in cfg.base_model.lower()
|
||||
for multimodal_name in [
|
||||
@@ -168,7 +172,6 @@ def normalize_config(cfg):
|
||||
cfg.processor_config = (
|
||||
cfg.processor_config or cfg.base_model_config or cfg.base_model
|
||||
)
|
||||
model_config = model_config.text_config
|
||||
|
||||
cfg.model_config_type = model_config.model_type
|
||||
|
||||
|
||||
@@ -6,8 +6,12 @@ from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.errors import HFValidationError
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from huggingface_hub.errors import (
|
||||
HFValidationError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
)
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -70,20 +74,25 @@ def load_dataset_w_config(
|
||||
# pylint: disable=invalid-name
|
||||
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
||||
ds_from_hub = False
|
||||
ds_trust_remote_code = config_dataset.trust_remote_code
|
||||
try:
|
||||
# this is just a basic check to see if the path is a
|
||||
# valid HF dataset that's loadable
|
||||
load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=True,
|
||||
snapshot_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=ds_trust_remote_code,
|
||||
ignore_patterns=["*"],
|
||||
)
|
||||
ds_from_hub = True
|
||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||
except (
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
FileNotFoundError,
|
||||
ConnectionError,
|
||||
HFValidationError,
|
||||
ValueError,
|
||||
):
|
||||
pass
|
||||
|
||||
ds_from_cloud = False
|
||||
@@ -229,7 +238,8 @@ def load_dataset_w_config(
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
else:
|
||||
elif config_dataset.data_files:
|
||||
fp: str | list[str] | None = None
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
|
||||
@@ -71,8 +71,8 @@ def barrier():
|
||||
|
||||
def is_main_process():
|
||||
"""
|
||||
Check if the current process is the main process.
|
||||
If not in distributed mode, always return True.
|
||||
Check if the current process is the main process. If not in distributed mode,
|
||||
always return `True`.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return True
|
||||
@@ -87,6 +87,18 @@ def get_world_size():
|
||||
return int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
|
||||
def cleanup_distributed():
|
||||
"""
|
||||
Destroy process group if torch distributed is initialized. Called in training early
|
||||
termination or when training successfully completes.
|
||||
"""
|
||||
# Ensure that all operations are completed before destroying the process group
|
||||
torch.cuda.synchronize()
|
||||
# Destroy the process group
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_only():
|
||||
"""
|
||||
|
||||
@@ -8,7 +8,7 @@ import math
|
||||
import os
|
||||
import types
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import addict
|
||||
import bitsandbytes as bnb
|
||||
@@ -25,7 +25,7 @@ from peft import (
|
||||
prepare_model_for_kbit_training,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers import ( # noqa: F401
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -34,12 +34,17 @@ from transformers import ( # noqa: F401
|
||||
AutoTokenizer,
|
||||
AwqConfig,
|
||||
BitsAndBytesConfig,
|
||||
Gemma3ForConditionalGeneration,
|
||||
GPTQConfig,
|
||||
LlavaForConditionalGeneration,
|
||||
Mistral3ForConditionalGeneration,
|
||||
MllamaForConditionalGeneration,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
ProcessorMixin,
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
)
|
||||
from transformers.integrations.deepspeed import (
|
||||
HfTrainerDeepSpeedConfig,
|
||||
@@ -67,7 +72,16 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap
|
||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
MULTIMODAL_AUTO_MODEL_MAPPING = {
|
||||
"mllama": MllamaForConditionalGeneration,
|
||||
"llava": LlavaForConditionalGeneration,
|
||||
"qwen2_vl": Qwen2VLForConditionalGeneration,
|
||||
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
|
||||
"mistral3": Mistral3ForConditionalGeneration,
|
||||
"gemma3": Gemma3ForConditionalGeneration,
|
||||
}
|
||||
|
||||
|
||||
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||
@@ -94,9 +108,30 @@ def get_module_class_from_name(module, name):
|
||||
return None
|
||||
|
||||
|
||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||
def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
||||
# Set use_cache to False
|
||||
if hasattr(model_config, "use_cache"):
|
||||
model_config.use_cache = False
|
||||
|
||||
if cfg.is_multimodal:
|
||||
model_config = model_config.text_config
|
||||
# For multimodal configs, use_cache is set in the text_config
|
||||
if hasattr(model_config, "get_text_config"):
|
||||
text_config = model_config.get_text_config()
|
||||
if hasattr(text_config, "use_cache"):
|
||||
text_config.use_cache = False
|
||||
else:
|
||||
raise ValueError(
|
||||
"No text config found for multimodal model. Please raise an Issue with model details."
|
||||
)
|
||||
|
||||
# check if image_size is not set and load image size from model config if available
|
||||
if (
|
||||
cfg.image_size is None
|
||||
and hasattr(model_config, "vision_config")
|
||||
and hasattr(model_config.vision_config, "image_size")
|
||||
):
|
||||
cfg.image_size = model_config.vision_config.image_size
|
||||
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
|
||||
|
||||
quant_config_exists = (
|
||||
hasattr(model_config, "quantization_config")
|
||||
@@ -435,6 +470,31 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
# Attempt to load image size from processor if available
|
||||
if (
|
||||
cfg.image_size is None
|
||||
and hasattr(processor, "size")
|
||||
and any(dim in processor.size for dim in ["width", "height"])
|
||||
):
|
||||
im_width = None
|
||||
im_height = None
|
||||
if "width" in processor.size:
|
||||
im_width = processor.size["width"]
|
||||
if "height" in processor.size:
|
||||
im_height = processor.size["height"]
|
||||
|
||||
# If both width and height are set, use a tuple
|
||||
if im_width is not None and im_height is not None:
|
||||
cfg.image_size = (im_width, im_height)
|
||||
# If only width is set, use as integer
|
||||
elif im_width is not None:
|
||||
cfg.image_size = im_width
|
||||
# If only height is set, use as integer
|
||||
elif im_height is not None:
|
||||
cfg.image_size = im_height
|
||||
|
||||
LOG.debug(f"Loaded image size: {cfg.image_size} from processor")
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
@@ -471,14 +531,19 @@ class ModelLoader:
|
||||
|
||||
# init model config
|
||||
self.model_config = load_model_config(cfg)
|
||||
if cfg.is_multimodal:
|
||||
self.text_model_config = self.model_config.text_config
|
||||
else:
|
||||
self.text_model_config = self.model_config
|
||||
|
||||
self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||
|
||||
def apply_patches(self) -> None:
|
||||
# patch gemma3 conditional generation forward before loading plugins
|
||||
# as it could be overridden by plugins
|
||||
if self.cfg.model_config_type == "gemma3":
|
||||
from axolotl.monkeypatch.gemma3 import (
|
||||
patch_gemma3conditionalgeneration_forward,
|
||||
)
|
||||
|
||||
patch_gemma3conditionalgeneration_forward()
|
||||
|
||||
# load any patches from plugins
|
||||
from axolotl.integrations.base import PluginManager
|
||||
|
||||
@@ -547,6 +612,17 @@ class ModelLoader:
|
||||
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||
|
||||
# Initialize ring attn for sequence parallelism. This must be done after
|
||||
# model init but before the first forward pass, since it modifies flash
|
||||
# attn to use ring comm for SP training across multiple GPUs.
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
|
||||
heads_k_stride=self.cfg.heads_k_stride,
|
||||
)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
if hasattr(self.model_config, "model_type"):
|
||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||
@@ -603,7 +679,7 @@ class ModelLoader:
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
def patch_llama_derived_model(self) -> None:
|
||||
def patch_llama_derived_model(self):
|
||||
"""Modify all llama derived models in one block"""
|
||||
self.patch_loss_llama()
|
||||
|
||||
@@ -653,25 +729,16 @@ class ModelLoader:
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
)
|
||||
|
||||
def set_auto_model_loader(self) -> None:
|
||||
"""set self.AutoModelLoader
|
||||
- default value: AutoModelForCausalLM (set at __init__)
|
||||
- when using a multi modality model, self.AutoModelLoader should
|
||||
be set according to model type of the model
|
||||
def set_auto_model_loader(self):
|
||||
"""
|
||||
Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
|
||||
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
||||
should be set according to the type of the model.
|
||||
"""
|
||||
if self.cfg.is_multimodal:
|
||||
if self.model_config.model_type == "llava":
|
||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||
LlavaForConditionalGeneration
|
||||
)
|
||||
elif self.model_config.model_type == "mllama":
|
||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||
MllamaForConditionalGeneration
|
||||
)
|
||||
else:
|
||||
self.AutoModelLoader = (
|
||||
AutoModelForVision2Seq # pylint: disable=invalid-name
|
||||
)
|
||||
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
|
||||
self.model_config.model_type, AutoModelForVision2Seq
|
||||
)
|
||||
|
||||
def set_device_map_config(self) -> None:
|
||||
device_map = self.cfg.device_map
|
||||
@@ -695,7 +762,7 @@ class ModelLoader:
|
||||
from accelerate import infer_auto_device_map
|
||||
|
||||
with init_empty_weights():
|
||||
model_canvas = self.AutoModelLoader.from_config(
|
||||
model_canvas = self.auto_model_loader.from_config(
|
||||
self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
)
|
||||
@@ -892,8 +959,6 @@ class ModelLoader:
|
||||
quantization_config = (
|
||||
quantization_config or self.model_kwargs["quantization_config"]
|
||||
)
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
self.model = load_sharded_model_quant(
|
||||
self.base_model,
|
||||
self.model_config,
|
||||
@@ -914,13 +979,26 @@ class ModelLoader:
|
||||
|
||||
_ = _configure_zero3_memory_efficient_loading()
|
||||
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
self.model = self.AutoModelLoader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
# Load model with random initialization if specified
|
||||
if self.cfg.random_init_weights:
|
||||
# AutoModel classes support the from_config method
|
||||
if self.auto_model_loader in [
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForVision2Seq,
|
||||
]:
|
||||
self.model = self.auto_model_loader.from_config(
|
||||
config=self.model_config,
|
||||
)
|
||||
else:
|
||||
self.model = self.auto_model_loader(
|
||||
config=self.model_config,
|
||||
)
|
||||
else:
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
|
||||
# TODO (MengqingCao) split these patches seperately
|
||||
if self.cfg.flash_attention and not self.inference:
|
||||
@@ -955,10 +1033,8 @@ class ModelLoader:
|
||||
and self.model_type != "AutoModelForCausalLM"
|
||||
and not self.cfg.trust_remote_code
|
||||
):
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
if self.cfg.gptq:
|
||||
self.model = self.AutoModelLoader.from_pretrained(
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
@@ -972,26 +1048,8 @@ class ModelLoader:
|
||||
**self.model_kwargs,
|
||||
)
|
||||
else:
|
||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||
# when training starts
|
||||
if (
|
||||
hasattr(self.text_model_config, "max_seq_len")
|
||||
and self.text_model_config.max_seq_len
|
||||
and self.cfg.sequence_len > self.text_model_config.max_seq_len
|
||||
):
|
||||
self.text_model_config.max_seq_len = self.cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
|
||||
elif (
|
||||
hasattr(self.text_model_config, "max_sequence_length")
|
||||
and self.text_model_config.max_sequence_length
|
||||
and self.cfg.sequence_len > self.text_model_config.max_sequence_length
|
||||
):
|
||||
self.text_model_config.max_sequence_length = self.cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
|
||||
if self.cfg.gptq:
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
self.model = self.AutoModelLoader.from_pretrained(
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
@@ -1009,9 +1067,7 @@ class ModelLoader:
|
||||
|
||||
_ = _configure_zero3_memory_efficient_loading()
|
||||
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
self.model = self.AutoModelLoader.from_pretrained(
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
@@ -1174,7 +1230,9 @@ class ModelLoader:
|
||||
)
|
||||
):
|
||||
resize_kwargs = {}
|
||||
if self.cfg.mean_resizing_embeddings is not None:
|
||||
if self.cfg.mean_resizing_embeddings is not None and not (
|
||||
self.model_config.model_type == "llava"
|
||||
):
|
||||
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
||||
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
||||
else:
|
||||
@@ -1273,8 +1331,6 @@ class ModelLoader:
|
||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||
if len(requires_grad) == 0:
|
||||
LOG.warning("there are no parameters that require gradient updates")
|
||||
if hasattr(self.model, "config"):
|
||||
self.model.config.use_cache = False
|
||||
|
||||
if self.cfg.flash_optimum:
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
@@ -1307,7 +1363,7 @@ def load_model(
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
loader = ModelLoader(
|
||||
model_loader = ModelLoader(
|
||||
cfg,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -1315,7 +1371,7 @@ def load_model(
|
||||
reference_model=reference_model,
|
||||
**kwargs,
|
||||
)
|
||||
return loader.load_model()
|
||||
return model_loader.load_model()
|
||||
|
||||
|
||||
def load_adapter(model, cfg, adapter, inference=False):
|
||||
|
||||
@@ -8,11 +8,13 @@ from typing import Any, Iterable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler, Sampler
|
||||
from torch.utils.data import BatchSampler, Sampler, SequentialSampler
|
||||
|
||||
from axolotl.utils.distributed import reduce_and_broadcast
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
|
||||
@numba.njit
|
||||
@@ -103,10 +105,57 @@ def allocate(
|
||||
return result, s, len(result) * c * n
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
|
||||
"""
|
||||
Sequential allocator that preserves example order
|
||||
|
||||
Parameters:
|
||||
- lengths: The lengths of all examples
|
||||
- rank: The current rank (for distributed training)
|
||||
- c: The capacity of each bin (maximum sequence length)
|
||||
- n: Number of ranks
|
||||
|
||||
Returns:
|
||||
- result: List of batches for the current rank
|
||||
- total_used: Number of actual example tokens
|
||||
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
|
||||
"""
|
||||
result = []
|
||||
total_used = 0
|
||||
|
||||
# First, do sequential packing into bins
|
||||
all_bins = []
|
||||
current_bin = [0 for i in range(0)] # numba hint
|
||||
remaining_capacity = c
|
||||
|
||||
for idx, size in enumerate(lengths):
|
||||
if size <= remaining_capacity:
|
||||
# Example fits in current bin
|
||||
current_bin.append(idx)
|
||||
remaining_capacity -= size
|
||||
total_used += size
|
||||
else:
|
||||
# Example doesn't fit, start a new bin
|
||||
if current_bin: # Add non-empty bin to all_bins
|
||||
all_bins.append(current_bin)
|
||||
current_bin = [idx]
|
||||
remaining_capacity = c - size
|
||||
total_used += size
|
||||
|
||||
# Add the last bin if not empty
|
||||
if current_bin:
|
||||
all_bins.append(current_bin)
|
||||
|
||||
# Assign bins to ranks - each rank gets every n-th bin
|
||||
for bin_idx in range(rank, len(all_bins), n):
|
||||
result.append(all_bins[bin_idx])
|
||||
|
||||
return result, total_used, len(all_bins) * c
|
||||
|
||||
|
||||
class MultipackBatchSampler(BatchSampler):
|
||||
"""
|
||||
Batch Sampler class for multipack
|
||||
"""
|
||||
"""Batch sampler class for multipack"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -117,6 +166,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
drop_last: bool = False,
|
||||
num_count_samples: int = 16,
|
||||
sequential: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
@@ -124,6 +174,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.sequential = sequential
|
||||
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
@@ -138,6 +189,11 @@ class MultipackBatchSampler(BatchSampler):
|
||||
# the minimum packed dataset length across all ranks determined by a gather/broadcast
|
||||
self.len_across_ranks = None
|
||||
|
||||
if self.sequential and not isinstance(sampler, SequentialSampler):
|
||||
LOG.warn(
|
||||
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
|
||||
)
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
self.epoch = epoch
|
||||
|
||||
@@ -147,13 +203,21 @@ class MultipackBatchSampler(BatchSampler):
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
if self.sequential:
|
||||
batches, total_used, total_slots = allocate_sequentially(
|
||||
lengths=lengths,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
else:
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
|
||||
batches = [
|
||||
[
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Main Axolotl input configuration Pydantic models"""
|
||||
"""Module with Pydantic models for configuration."""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
@@ -42,9 +42,11 @@ from axolotl.utils.schemas.model import (
|
||||
ModelOutputConfig,
|
||||
SpecialTokensConfig,
|
||||
)
|
||||
from axolotl.utils.schemas.multimodal import MultiModalConfig
|
||||
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
||||
from axolotl.utils.schemas.training import HyperparametersConfig
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
from axolotl.utils.schemas.vllm import VllmConfig
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,6 +66,7 @@ class AxolotlInputConfig(
|
||||
LISAConfig,
|
||||
GradioConfig,
|
||||
RayConfig,
|
||||
MultiModalConfig,
|
||||
RemappedParameters,
|
||||
DeprecatedParameters,
|
||||
BaseModel,
|
||||
@@ -84,6 +87,9 @@ class AxolotlInputConfig(
|
||||
trl: TRLConfig | None = Field(
|
||||
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
vllm: VllmConfig | None = Field(
|
||||
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
reward_model: bool | None = None
|
||||
process_reward_model: bool | None = None
|
||||
num_labels: int | None = None
|
||||
@@ -186,6 +192,7 @@ class AxolotlInputConfig(
|
||||
sample_packing: bool | None = None
|
||||
sample_packing_group_size: int | None = 100_000
|
||||
sample_packing_bin_size: int | None = 200
|
||||
sample_packing_sequentially: bool | None = None
|
||||
eval_sample_packing: bool | None = None
|
||||
pad_to_sequence_len: bool | None = None
|
||||
curriculum_sampling: bool | None = None
|
||||
@@ -245,6 +252,9 @@ class AxolotlInputConfig(
|
||||
|
||||
val_set_size: float | None = Field(default=0.0)
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
heads_k_stride: int | None = None
|
||||
|
||||
special_tokens: SpecialTokensConfig | None = None
|
||||
tokens: list[str] | None = None
|
||||
added_tokens_overrides: dict[int, str] | None = None
|
||||
@@ -1102,6 +1112,40 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@field_validator("sequence_parallel_degree", mode="before")
|
||||
@classmethod
|
||||
def check_sequence_parallel_degree(cls, value, info):
|
||||
if not value:
|
||||
value = 1
|
||||
|
||||
if value > 1:
|
||||
if not info.data.get("flash_attention"):
|
||||
raise ValueError(
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
)
|
||||
|
||||
try:
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
|
||||
return value
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_muon_deepspeed_fsdp(cls, data):
|
||||
if data.get("optimizer") == "muon" and (
|
||||
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
|
||||
):
|
||||
raise ValueError(
|
||||
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
@@ -1180,17 +1224,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
):
|
||||
capabilities = data.get("capabilities")
|
||||
is_fsdp = data.get("fsdp") is not None
|
||||
is_deepspeed = data.get("deepspeed") is not None
|
||||
|
||||
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
||||
if is_fsdp:
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
|
||||
)
|
||||
if is_deepspeed:
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -1237,3 +1276,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
if data["beta"] != data["trl"]["beta"]:
|
||||
raise ValueError("beta and trl.beta must match or one must be removed")
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_min_torch_version(self):
|
||||
if self.env_capabilities and self.env_capabilities.torch_version:
|
||||
torch_version = self.env_capabilities.torch_version
|
||||
if version.parse(torch_version) < version.parse("2.5.1"):
|
||||
LOG.warning(
|
||||
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@@ -22,6 +22,7 @@ class ChatTemplate(str, Enum):
|
||||
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
||||
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
||||
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
||||
mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name
|
||||
gemma = "gemma" # pylint: disable=invalid-name
|
||||
cohere = "cohere" # pylint: disable=invalid-name
|
||||
llama3 = "llama3" # pylint: disable=invalid-name
|
||||
@@ -36,6 +37,10 @@ class ChatTemplate(str, Enum):
|
||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||
exaone = "exaone" # pylint: disable=invalid-name
|
||||
metharme = "metharme" # pylint: disable=invalid-name
|
||||
pixtral = "pixtral" # pylint: disable=invalid-name
|
||||
llava = "llava" # pylint: disable=invalid-name
|
||||
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
|
||||
gemma3 = "gemma3" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CustomSupportedOptimizers(str, Enum):
|
||||
|
||||
48
src/axolotl/utils/schemas/multimodal.py
Normal file
48
src/axolotl/utils/schemas/multimodal.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Pydantic models for multimodal-related configuration"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from PIL.Image import Resampling
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class MultiModalConfig(BaseModel):
|
||||
"""Multi-modal configuration subset"""
|
||||
|
||||
image_size: int | tuple[int, int] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"The size of the image to resize to. It can be an integer (resized into padded-square image) or a tuple (width, height)."
|
||||
"If not provided, we will attempt to load from preprocessor.size, otherwise, images won't be resized."
|
||||
)
|
||||
},
|
||||
)
|
||||
image_resize_algorithm: (
|
||||
Literal["bilinear", "bicubic", "lanczos"] | Resampling | None
|
||||
) = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details."
|
||||
},
|
||||
)
|
||||
|
||||
@field_validator("image_resize_algorithm", mode="before")
|
||||
@classmethod
|
||||
def convert_image_resize_algorithm(cls, image_resize_algorithm):
|
||||
"""
|
||||
Convert the image resize algorithm to a PIL.Image.Resampling enum.
|
||||
"""
|
||||
if isinstance(image_resize_algorithm, str):
|
||||
image_resize_algorithm = image_resize_algorithm.lower()
|
||||
if image_resize_algorithm == "bilinear":
|
||||
image_resize_algorithm = Resampling.BILINEAR
|
||||
elif image_resize_algorithm == "bicubic":
|
||||
image_resize_algorithm = Resampling.BICUBIC
|
||||
elif image_resize_algorithm == "lanczos":
|
||||
image_resize_algorithm = Resampling.LANCZOS
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid image resize algorithm: {image_resize_algorithm}"
|
||||
)
|
||||
return image_resize_algorithm
|
||||
@@ -20,27 +20,30 @@ class TRLConfig(BaseModel):
|
||||
)
|
||||
|
||||
# GRPO specific args
|
||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||
use_vllm: bool | None = Field(
|
||||
# Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
|
||||
use_vllm: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||
)
|
||||
vllm_device: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Device to use for VLLM"},
|
||||
vllm_server_host: str | None = Field(
|
||||
default="0.0.0.0", # nosec B104
|
||||
json_schema_extra={"description": "Host of the vLLM server to connect to"},
|
||||
)
|
||||
vllm_gpu_memory_utilization: float | None = Field(
|
||||
default=0.9,
|
||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||
vllm_server_port: int | None = Field(
|
||||
default=8000,
|
||||
json_schema_extra={"description": "Port of the vLLM server to connect to"},
|
||||
)
|
||||
vllm_dtype: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Data type for VLLM"},
|
||||
)
|
||||
vllm_max_model_len: int | None = Field(
|
||||
vllm_server_timeout: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Maximum length of the model context for VLLM"
|
||||
"description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
|
||||
"after the timeout, a `ConnectionError` is raised."
|
||||
},
|
||||
)
|
||||
vllm_guided_decoding_regex: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
|
||||
},
|
||||
)
|
||||
|
||||
@@ -85,3 +88,48 @@ class TRLConfig(BaseModel):
|
||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||
},
|
||||
)
|
||||
scale_rewards: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
|
||||
},
|
||||
)
|
||||
|
||||
temperature: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Sampling temperature for the GRPO policy."},
|
||||
)
|
||||
top_p: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Top-p sampling probability for the generation policy."
|
||||
},
|
||||
)
|
||||
top_k: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Top-k sampling for the generation policy."},
|
||||
)
|
||||
min_p: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Minimum probability for the generation policy."
|
||||
},
|
||||
)
|
||||
repetition_penalty: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
|
||||
},
|
||||
)
|
||||
num_iterations: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
|
||||
},
|
||||
)
|
||||
epsilon: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Epsilon value for clipping in the GRPO algorithm."
|
||||
},
|
||||
)
|
||||
|
||||
38
src/axolotl/utils/schemas/vllm.py
Normal file
38
src/axolotl/utils/schemas/vllm.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Pydantic models for VLLM configuration, used primarily for RL training with TRL + grpo
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VllmConfig(BaseModel):
|
||||
"""
|
||||
Configuration for VLLM server
|
||||
"""
|
||||
|
||||
device: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Device to use for VLLM"},
|
||||
)
|
||||
tensor_parallel_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Tensor parallel size for VLLM"},
|
||||
)
|
||||
gpu_memory_utilization: float | None = Field(
|
||||
default=0.9,
|
||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||
)
|
||||
dtype: str | None = Field(
|
||||
default="auto",
|
||||
json_schema_extra={"description": "Data type for VLLM"},
|
||||
)
|
||||
max_model_len: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Maximum length of the model context for VLLM"
|
||||
},
|
||||
)
|
||||
enable_prefix_caching: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Enable prefix caching for VLLM"},
|
||||
)
|
||||
@@ -13,7 +13,7 @@ import torch
|
||||
import torch.cuda
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import IterableDataset, disable_caching, enable_caching
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
@@ -235,7 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if cfg.model_config_type == "mamba":
|
||||
if cfg.model_config_type in ["mamba", "gemma3"]:
|
||||
LOG.info("dropping attention_mask column")
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (PoSE)",
|
||||
)
|
||||
elif cfg.sample_packing:
|
||||
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if cfg.eval_sample_packing is not False:
|
||||
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.map(
|
||||
add_position_ids,
|
||||
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
)
|
||||
LOG.debug(
|
||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
||||
@@ -455,13 +456,18 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
else:
|
||||
sampler_batch_size = cfg.micro_batch_size
|
||||
batch_max_len = cfg.sequence_len
|
||||
if cfg.curriculum_sampling:
|
||||
sampler = SequentialSampler(train_dataset)
|
||||
else:
|
||||
sampler = RandomSampler(train_dataset)
|
||||
sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
sampler=sampler,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
batch_size=sampler_batch_size,
|
||||
batch_max_len=batch_max_len,
|
||||
group_size=cfg.sample_packing_group_size,
|
||||
bin_size=cfg.sample_packing_bin_size,
|
||||
sequential=cfg.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
@@ -473,7 +479,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
# on the agreed on value for sample_packing_eff_est
|
||||
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||
)
|
||||
)
|
||||
|
||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||
@@ -494,7 +504,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
)
|
||||
else:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
/ cfg.batch_size
|
||||
)
|
||||
)
|
||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
||||
return total_num_steps
|
||||
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
@@ -8,10 +8,16 @@ import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
import requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
|
||||
|
||||
|
||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
@@ -25,9 +31,11 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
except (
|
||||
requests.exceptions.ReadTimeout,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.HTTPError,
|
||||
) as exc:
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(delay)
|
||||
wait = 2**attempt * delay # in seconds
|
||||
time.sleep(wait)
|
||||
else:
|
||||
raise exc
|
||||
|
||||
@@ -37,26 +45,35 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
|
||||
|
||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||
@disable_hf_offline
|
||||
def snapshot_download_w_retry(*args, **kwargs):
|
||||
return snapshot_download(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_ds_fixture_bundle():
|
||||
ds_dir = snapshot_download_w_retry(
|
||||
"axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
|
||||
)
|
||||
return Path(ds_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_smollm2_135m_model():
|
||||
# download the model
|
||||
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M")
|
||||
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_llama_68m_random_model():
|
||||
# download the model
|
||||
snapshot_download_w_retry("JackFram/llama-68m")
|
||||
snapshot_download_w_retry("JackFram/llama-68m", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_qwen_2_5_half_billion_model():
|
||||
# download the model
|
||||
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B")
|
||||
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@@ -94,13 +111,52 @@ def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
||||
def download_argilla_distilabel_intel_orca_dpo_dataset():
|
||||
# download the dataset
|
||||
snapshot_download_w_retry(
|
||||
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
|
||||
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
|
||||
)
|
||||
|
||||
|
||||
# @pytest.fixture(scope="session", autouse=True)
|
||||
# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
||||
# # download the dataset
|
||||
# snapshot_download_w_retry(
|
||||
# "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
|
||||
# )
|
||||
|
||||
|
||||
# @pytest.fixture(scope="session", autouse=True)
|
||||
# def download_fozzie_alpaca_dpo_dataset():
|
||||
# # download the dataset
|
||||
# snapshot_download_w_retry(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
||||
# )
|
||||
# snapshot_download_w_retry(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
# repo_type="dataset",
|
||||
# revision="ea82cff",
|
||||
# )
|
||||
|
||||
|
||||
# @pytest.fixture(scope="session")
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||
#
|
||||
#
|
||||
# @pytest.fixture(scope="session")
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# return load_dataset(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||
# )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
||||
# download the dataset
|
||||
@@ -109,10 +165,192 @@ def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_argilla_dpo_pairs_dataset():
|
||||
# download the dataset
|
||||
snapshot_download_w_retry(
|
||||
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_tiny_shakespeare_dataset():
|
||||
# download the dataset
|
||||
snapshot_download_w_retry("Trelis/tiny-shakespeare", repo_type="dataset")
|
||||
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_deepseek_model_fixture():
|
||||
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_huggyllama_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"huggyllama/llama-7b",
|
||||
repo_type="model",
|
||||
allow_patterns=["*token*", "config.json"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_llama_1b_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"NousResearch/Llama-3.2-1B",
|
||||
repo_type="model",
|
||||
allow_patterns=["*token*", "config.json"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_llama3_8b_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"NousResearch/Meta-Llama-3-8B", repo_type="model", allow_patterns=["*token*"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_llama3_8b_instruct_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"NousResearch/Meta-Llama-3-8B-Instruct",
|
||||
repo_type="model",
|
||||
allow_patterns=["*token*"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_phi_35_mini_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"microsoft/Phi-3.5-mini-instruct", repo_type="model", allow_patterns=["*token*"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_phi_3_medium_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"microsoft/Phi-3-medium-128k-instruct",
|
||||
repo_type="model",
|
||||
allow_patterns=["*token*"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_mistral_7b_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"casperhansen/mistral-7b-instruct-v0.1-awq",
|
||||
repo_type="model",
|
||||
allow_patterns=["*token*", "config.json"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_gemma_2b_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"unsloth/gemma-2b-it",
|
||||
revision="703fb4a",
|
||||
repo_type="model",
|
||||
allow_patterns=["*token*", "config.json"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_gemma2_9b_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"mlx-community/gemma-2-9b-it-4bit",
|
||||
repo_type="model",
|
||||
allow_patterns=["*token*", "config.json"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_mlx_mistral_7b_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"mlx-community/Mistral-7B-Instruct-v0.3-4bit",
|
||||
repo_type="model",
|
||||
allow_patterns=["*token*", "config.json"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def download_llama2_model_fixture():
|
||||
# download the tokenizer only
|
||||
snapshot_download_w_retry(
|
||||
"NousResearch/Llama-2-7b-hf",
|
||||
repo_type="model",
|
||||
allow_patterns=["*token*", "config.json"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama(
|
||||
download_huggyllama_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
tokenizer.pad_token = "</s>"
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama_w_special_tokens(
|
||||
tokenizer_huggyllama,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
tokenizer_huggyllama.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
return tokenizer_huggyllama
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_llama2_7b(
|
||||
download_llama2_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_mistral_7b_instruct(
|
||||
download_mlx_mistral_7b_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
|
||||
tokenizer_mistral_7b_instruct.add_special_tokens(
|
||||
{
|
||||
"eos_token": AddedToken(
|
||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||
)
|
||||
}
|
||||
)
|
||||
tokenizer_mistral_7b_instruct.add_tokens(
|
||||
[
|
||||
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
|
||||
]
|
||||
)
|
||||
return tokenizer_mistral_7b_instruct
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -178,3 +416,88 @@ def cleanup_monkeypatches():
|
||||
module_globals = module_name_tuple[1]
|
||||
for module_global in module_globals:
|
||||
globals().pop(module_global, None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_winglian_tiny_shakespeare(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
|
||||
return datasets.load_from_disk(ds_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_tatsu_lab_alpaca(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_mhenrichsen_alpaca_2k_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
|
||||
)
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
|
||||
)
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
# # pylint: disable=redefined-outer-name,unused-argument
|
||||
# def test_load_fixtures(
|
||||
# download_smollm2_135m_model,
|
||||
# download_llama_68m_random_model,
|
||||
# download_qwen_2_5_half_billion_model,
|
||||
# download_tatsu_lab_alpaca_dataset,
|
||||
# download_mhenrichsen_alpaca_2k_dataset,
|
||||
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
|
||||
# download_mlabonne_finetome_100k_dataset,
|
||||
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
|
||||
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
|
||||
# download_argilla_dpo_pairs_dataset,
|
||||
# download_tiny_shakespeare_dataset,
|
||||
# download_deepseek_model_fixture,
|
||||
# download_huggyllama_model_fixture,
|
||||
# download_llama_1b_model_fixture,
|
||||
# download_llama3_8b_model_fixture,
|
||||
# download_llama3_8b_instruct_model_fixture,
|
||||
# download_phi_35_mini_model_fixture,
|
||||
# download_phi_3_medium_model_fixture,
|
||||
# download_mistral_7b_model_fixture,
|
||||
# download_gemma_2b_model_fixture,
|
||||
# download_gemma2_9b_model_fixture,
|
||||
# download_mlx_mistral_7b_model_fixture,
|
||||
# download_llama2_model_fixture,
|
||||
# ):
|
||||
# pass
|
||||
|
||||
@@ -10,10 +10,13 @@ from transformers import AddedToken, AutoTokenizer
|
||||
from axolotl.core.chat.format.chatml import format_message
|
||||
from axolotl.core.chat.messages import ChatFormattedChats, Chats
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline # noqa
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", name="llama_tokenizer")
|
||||
@enable_hf_offline
|
||||
def llama_tokenizer_fixture():
|
||||
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
|
||||
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", name="chatml_tokenizer")
|
||||
|
||||
@@ -5,7 +5,6 @@ e2e tests for kd trainer support in Axolotl
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
@@ -13,6 +12,8 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||
|
||||
|
||||
@pytest.fixture(name="kd_min_cfg")
|
||||
def min_cfg(temp_dir):
|
||||
|
||||
@@ -2,15 +2,13 @@
|
||||
Simple end-to-end test for Liger integration
|
||||
"""
|
||||
|
||||
from e2e.utils import require_torch_2_4_1
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists
|
||||
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
||||
|
||||
|
||||
class LigerIntegrationTestCase:
|
||||
|
||||
0
tests/e2e/multigpu/solo/__init__.py
Normal file
0
tests/e2e/multigpu/solo/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user