Compare commits

3 Commits (author and date columns did not survive capture):

- 798c8fba89
- 17fc747f99
- 901f2356bc
.github/CONTRIBUTING.md (5 changes)
@@ -31,10 +31,11 @@ PRs are **greatly welcome**!
 Please run below to setup env
 ```bash
-# Install axolotl + dev and test dependencies from lockfile
+# Install axolotl + dev and test dependencies
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv sync --extra flash-attn --extra deepspeed --group dev --group test
+uv venv --no-project --relocatable
 source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
 pre-commit install

 # test
.github/workflows/base.yml (16 changes)
@@ -30,14 +30,6 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: 2.9.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-base"
-            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -168,14 +160,6 @@ jobs:
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-uv-base"
             platforms: "linux/amd64,linux/arm64"
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: 2.9.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-            dockerfile: "Dockerfile-uv-base"
-            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
.github/workflows/main.yml (12 changes)
@@ -18,12 +18,6 @@ jobs:
      fail-fast: false
      matrix:
        include:
-         - cuda: 128
-           cuda_version: 12.8.1
-           python_version: "3.11"
-           pytorch: 2.9.0
-           axolotl_extras:
-           platforms: "linux/amd64,linux/arm64"
          - cuda: 128
            cuda_version: 12.8.1
            python_version: "3.11"
@@ -180,12 +174,6 @@ jobs:
      fail-fast: false
      matrix:
        include:
-         - cuda: 128
-           cuda_version: 12.8.1
-           python_version: "3.11"
-           pytorch: 2.9.0
-           axolotl_extras:
-           platforms: "linux/amd64,linux/arm64"
          - cuda: 128
            cuda_version: 12.8.1
            python_version: "3.11"
@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
 RUN pip uninstall -y causal_conv1d
 RUN if [ "$TARGETARCH" = "arm64" ]; then \
-        BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="optimizers,ray"; \
     else \
-        BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="deepspeed,optimizers,ray"; \
     fi && \
     if [ "$AXOLOTL_EXTRAS" != "" ]; then \
         pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
@@ -58,19 +58,3 @@ RUN git lfs install --skip-repo && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10 && \
     pip3 cache purge
-
-# Map Python version (e.g., 3.12 -> cp312)
-RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
-    # Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
-    TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
-    # Map architecture
-    case "$TARGETARCH" in \
-        amd64) ARCH_TAG="x86_64" ;; \
-        arm64) ARCH_TAG="aarch64" ;; \
-        *) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
-    esac && \
-    WHL_VERSION="v0.7.16" && \
-    WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
-    wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
-    pip3 install --no-cache-dir "${WHL_FILE}" && \
-    rm "${WHL_FILE}"
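For reference, the removed steps resolved a prebuilt flash-attn wheel name from the build args. A minimal Python sketch of that mapping; the input values below are assumed examples, not values taken from this diff:

```python
# Mirrors the removed shell logic; inputs stand in for the Docker build args
# PYTHON_VERSION, PYTORCH_VERSION, CUDA, and TARGETARCH.
python_version, pytorch_version, cuda, targetarch = "3.11", "2.9.1", "128", "amd64"

python_cp = "cp" + python_version.replace(".", "")              # 3.11  -> cp311
torch_tag = "torch" + ".".join(pytorch_version.split(".")[:2])  # 2.9.1 -> torch2.9
arch_tag = {"amd64": "x86_64", "arm64": "aarch64"}[targetarch]  # amd64 -> x86_64

whl_file = (
    f"flash_attn-2.8.3+cu{cuda}{torch_tag}-{python_cp}-{python_cp}-linux_{arch_tag}.whl"
)
print(whl_file)  # flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl
```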
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \

 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,mamba-ssm] $AXOLOTL_ARGS; \
     fi

 # So we can test the Docker image
@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
 RUN uv pip uninstall causal_conv1d
 RUN if [ "$TARGETARCH" = "arm64" ]; then \
-        BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="optimizers,ray"; \
     else \
-        BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
+        BASE_EXTRAS="deepspeed,optimizers,ray"; \
     fi && \
     if [ "$AXOLOTL_EXTRAS" != "" ]; then \
         uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
@@ -38,20 +38,3 @@ RUN uv pip install packaging setuptools wheel psutil \
 RUN if [ "$TARGETARCH" = "amd64" ]; then \
         MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
     fi
-
-# Map Python version (e.g., 3.12 -> cp312)
-RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
-    # Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
-    TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
-    LINUX_TAG="manylinux_" && \
-    # Map architecture
-    case "$TARGETARCH" in \
-        amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
-        arm64) ARCH_TAG="2_34_aarch64" ;; \
-        *) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
-    esac && \
-    WHL_VERSION="v0.7.16" && \
-    WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
-    wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
-    uv pip install --no-cache-dir "${WHL_FILE}" && \
-    rm "${WHL_FILE}"
@@ -77,8 +77,9 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us

 ```bash
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv sync --extra flash-attn --extra deepspeed --group dev --group test
+uv venv --no-project --relocatable
 source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
 ```

 #### Remote Hosts
@@ -218,8 +219,9 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
 You will now be in the container. Next, install Axolotl with dev dependencies:

 ```bash
-uv sync --extra flash-attn --extra deepspeed --group dev --group test
+uv venv --no-project --relocatable
 source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
 ```

 ### Attach To Container
@@ -10,13 +10,16 @@ This section describes the different Docker images that are released by AxolotlA
 [Docker Hub](https://hub.docker.com/u/axolotlai).

 ::: {.callout-important}
-For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
-:::
+### Switch to the `-uv` images

-::: {.callout-tip}
-Each image below ships a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with a relocatable venv
-(`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
-(e.g. `axolotlai/axolotl-uv`, `axolotlai/axolotl-base-uv`, `axolotlai/axolotl-cloud-uv`). Tags follow the
-same format as their non-uv counterparts.
+Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
+a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
+(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
+
+**We recommend switching to the `-uv` images early.** In the near future we will publish the uv-based
+build to the non-uv tags as well. The non-uv names will continue to work, but they will start serving
+the uv image.
 :::

 ## Base
@@ -85,7 +88,7 @@ Tags examples:
 - `main-py3.12-cu130-2.10.0`
 - `main-latest`
 - `main-20260315-py3.11-cu128-2.9.1`
-- `0.12.0`
+- `0.16.1`

 ## Cloud
@@ -57,7 +57,7 @@ description: Frequently asked questions

 **Q: vLLM is not working with Axolotl**

-> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
+> A: We currently recommend torch 2.10 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.12-cu128-2.10.0` tag (note: torch 2.10 images are built with Python 3.12).

 **Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir

 - NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
 - Python ≥3.11
-- PyTorch ≥2.9.0
+- PyTorch ≥2.9.1

 ## Installation {#sec-installation}
@@ -36,9 +36,9 @@ source $HOME/.local/bin/env
 Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
 ```{.bash}
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv venv --no-project --relocatable
+uv venv
 source .venv/bin/activate
-uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
+uv pip install --no-build-isolation axolotl[deepspeed]
 ```

 ### Edge/Development Build {#sec-edge-build}
@@ -49,12 +49,11 @@ For the latest features between releases:
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv sync --extra flash-attn --extra deepspeed
+uv venv
 source .venv/bin/activate
+uv pip install --no-build-isolation -e '.[deepspeed]'
 ```

-`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
-
 ### Docker {#sec-docker}

 ```{.bash}
@@ -132,11 +131,11 @@ source $HOME/.local/bin/env

 # Create a fresh venv (recommended for a clean start)
 export UV_TORCH_BACKEND=cu128 # or cu130
-uv venv --no-project --relocatable
+uv venv
 source .venv/bin/activate

 # Reinstall axolotl
-uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
+uv pip install --no-build-isolation axolotl[deepspeed]
 ```

 ## Using pip (Alternative) {#sec-pip}
@@ -151,13 +150,13 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p

 ```{.bash}
 pip3 install -U packaging setuptools wheel ninja
-pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
+pip3 install --no-build-isolation axolotl[deepspeed]
 ```

 For editable/development installs:
 ```{.bash}
 pip3 install -U packaging setuptools wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
+pip3 install --no-build-isolation -e '.[deepspeed]'
 ```

 ## Troubleshooting {#sec-troubleshooting}
@@ -15,7 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
 Here is an example of how to install from pip:
 ```bash
 # Ensure you have a compatible version of Pytorch installed
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. Run one of the finetuning examples below.
@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 Here is an example of how to install from main for pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl

-uv pip install --no-build-isolation -e '.[flash-attn]'
+uv pip install --no-build-isolation -e '.'

 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
 python scripts/cutcrossentropy_install.py | sh
@@ -13,11 +13,11 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
 Here is an example of how to install from main for pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl

-uv pip install --no-build-isolation -e '.[flash-attn]'
+uv pip install --no-build-isolation -e '.'

 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
 python scripts/cutcrossentropy_install.py | sh
@@ -36,12 +36,7 @@
     "id": "msOCO4NRmRLa"
    },
    "outputs": [],
-   "source": [
-    "%%capture\n",
-    "# This step can take ~5-10 minutes to install dependencies\n",
-    "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
-   ]
+   "source": "%%capture\n# This step can take ~5-10 minutes to install dependencies\n!pip install --no-build-isolation \"axolotl>=0.16.1\"\n!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
   },
   {
    "cell_type": "markdown",
@@ -15,8 +15,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
 Here is an example of how to install from pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
@@ -9,8 +9,8 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
 Here is an example of how to install from pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. In addition to Axolotl's requirements, Gemma-3n requires:
@@ -13,8 +13,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 Here is an example of how to install from pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 Here is an example of how to install from main for pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.7.1 min)
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl

-uv pip install --no-build-isolation -e '.[flash-attn]'
+uv pip install --no-build-isolation -e '.'

 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
 python scripts/cutcrossentropy_install.py | sh
@@ -9,11 +9,11 @@ Tencent released a family of opensource models called HunYuan with varying param
 Here is an example of how to install from main for pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl

-uv pip install --no-build-isolation -e '.[flash-attn]'
+uv pip install --no-build-isolation -e '.'

 # Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
 python scripts/cutcrossentropy_install.py | sh
@@ -13,8 +13,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
 Here is an example of how to install from pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
@@ -11,7 +11,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 Here is an example of how to install from pip:
 ```bash
 # Ensure you have a compatible version of Pytorch installed
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+uv pip install --no-build-isolation 'axolotl>=0.16.1'

 # Install Cut Cross Entropy
 python scripts/cutcrossentropy_install.py | sh
@@ -13,7 +13,7 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
 Here is an example of how to install from pip:
 ```bash
 # Ensure you have a compatible version of Pytorch installed
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. Install an extra dependency:
@@ -11,8 +11,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
 Here is an example of how to install from pip:

 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
+uv pip install --no-build-isolation 'axolotl>=0.16.1'
 ```

 2. Please install the below.
@@ -12,7 +12,7 @@ requires-python = ">=3.10"

 dependencies = [
     # Core ML stack
-    "torch>=2.6.0",
+    "torch>=2.9.1",
     "packaging==26.0",
     "huggingface_hub>=1.1.7",
     "peft>=0.19.1,<0.20.0",
@@ -79,7 +79,7 @@ dependencies = [
     # Platform-specific (Linux only)
     "bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
     "triton>=3.4.0 ; sys_platform != 'darwin'",
-    "xformers>=0.0.23.post1 ; sys_platform != 'darwin'",
+    "xformers>=0.0.33.post2 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
     "liger-kernel==0.7.0 ; sys_platform != 'darwin'",
     "torchao==0.17.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
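The updated `xformers` requirement adds a `platform_machine != 'aarch64'` guard via a PEP 508 environment marker. A quick sketch of how such a marker evaluates, using the `packaging` library; the platform values below are assumed examples:

```python
from packaging.markers import Marker

# The marker string from the updated xformers requirement above.
marker = Marker("sys_platform != 'darwin' and platform_machine != 'aarch64'")

# Evaluate against assumed example platforms (the dict overrides defaults).
print(marker.evaluate({"sys_platform": "linux", "platform_machine": "x86_64"}))   # True
print(marker.evaluate({"sys_platform": "linux", "platform_machine": "aarch64"}))  # False
```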
@@ -370,7 +370,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
"padding": True, # True/"longest" is the default
|
"padding": True, # True/"longest" is the default
|
||||||
}
|
}
|
||||||
multiple = 64
|
multiple = getattr(self.cfg, "pad_to_multiple_of", None) or 64
|
||||||
if self.cfg.pad_to_sequence_len:
|
if self.cfg.pad_to_sequence_len:
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||||
self.cfg.sequence_len / multiple
|
self.cfg.sequence_len / multiple
|
||||||
|
|||||||
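With this change, a configured `pad_to_multiple_of` overrides the previous hard-coded multiple of 64. A small sketch of the resulting padding computation; the config values are assumed examples:

```python
import math

def collator_pad_to_multiple_of(sequence_len: int, pad_to_multiple_of: int | None) -> int:
    # New behavior: the config value wins when set, otherwise fall back to 64.
    multiple = pad_to_multiple_of or 64
    # Round sequence_len up to the nearest multiple, as in the hunk above.
    return multiple * math.ceil(sequence_len / multiple)

print(collator_pad_to_multiple_of(1000, None))  # 1024 (64 * 16)
print(collator_pad_to_multiple_of(1000, 128))   # 1024 (128 * 8)
print(collator_pad_to_multiple_of(1000, 100))   # 1000 (100 * 10)
```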
@@ -228,9 +228,47 @@ class HFRLTrainerBuilder(TrainerBuilderBase):

         return training_args, trainer_kwargs

+    def build_collator(self, **kwargs):
+        """Build a data collator for preference-tuning trainers.
+
+        Returns None for RL types that provide their own collator (e.g. GRPO,
+        KTO), letting the trainer construct its default. For DPO/IPO/ORPO/SIMPO
+        returns an ``AxolotlDPODataCollatorWithPadding`` when
+        ``pad_to_multiple_of`` is set, otherwise None (so the trainer
+        falls back to the TRL default).
+        """
+        if self.cfg.rl not in (
+            RLType.DPO,
+            RLType.IPO,
+            RLType.ORPO,
+            RLType.SIMPO,
+        ):
+            return None
+
+        pad_to_multiple_of = getattr(self.cfg, "pad_to_multiple_of", None)
+        if not pad_to_multiple_of:
+            return None
+
+        from axolotl.utils.collators.dpo import AxolotlDPODataCollatorWithPadding
+
+        LOG.info(
+            f"Using AxolotlDPODataCollatorWithPadding with pad_to_multiple_of="
+            f"{pad_to_multiple_of}"
+        )
+        is_enc_dec = getattr(self.model.config, "is_encoder_decoder", False)
+        return AxolotlDPODataCollatorWithPadding(
+            pad_token_id=self.tokenizer.pad_token_id,
+            is_encoder_decoder=is_enc_dec,
+            pad_to_multiple_of=pad_to_multiple_of,
+            **kwargs,
+        )
+
     def build(self, total_num_steps):
         training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
+
+        if (data_collator := self.build_collator()) is not None:
+            trainer_kwargs["data_collator"] = data_collator
+
         if self.eval_dataset:
             trainer_kwargs["eval_dataset"] = self.eval_dataset
         if (
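A sketch of the gating `build_collator` applies, using a stand-in config object; the lowercase strings stand in for the `RLType` members named above:

```python
from types import SimpleNamespace

def uses_axolotl_dpo_collator(cfg) -> bool:
    # Only DPO-family trainers get the custom collator, and only when
    # pad_to_multiple_of is set; everything else keeps the TRL default.
    if cfg.rl not in ("dpo", "ipo", "orpo", "simpo"):
        return False
    return bool(getattr(cfg, "pad_to_multiple_of", None))

print(uses_axolotl_dpo_collator(SimpleNamespace(rl="dpo", pad_to_multiple_of=64)))    # True
print(uses_axolotl_dpo_collator(SimpleNamespace(rl="grpo", pad_to_multiple_of=64)))   # False
print(uses_axolotl_dpo_collator(SimpleNamespace(rl="dpo", pad_to_multiple_of=None)))  # False
```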
@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
 kd_alpha: 0.9
 kd_temperature: 1.0

-torch_compile: True  # torch>=2.6.0, recommended to reduce vram
+torch_compile: True  # recommended to reduce vram

 datasets:
 - path: ...
@@ -407,7 +407,10 @@ def selective_log_softmax(logits, index) -> torch.Tensor:
     K = index.shape[-1]
     original_index_shape = index.shape

-    flat_logits = logits.reshape(-1, V).contiguous()
+    try:
+        flat_logits = logits.view(-1, V)
+    except RuntimeError:
+        flat_logits = logits.reshape(-1, V).contiguous()
     flat_index = index.reshape(-1, K).contiguous()

     BLOCK_V = 4096
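The `try`/`except` avoids an unconditional copy: `view` is zero-copy but only succeeds on compatible memory layouts, while `reshape(...).contiguous()` may materialize a copy. A minimal demonstration:

```python
import torch

contiguous = torch.randn(4, 8, 16)
flat = contiguous.view(-1, 16)  # zero-copy: shares storage with `contiguous`

non_contiguous = contiguous.transpose(0, 1)  # same data, strided layout
try:
    non_contiguous.view(-1, 16)  # view cannot rearrange strided memory
except RuntimeError:
    # reshape falls back to copying into a contiguous buffer
    flat = non_contiguous.reshape(-1, 16).contiguous()
print(flat.shape)  # torch.Size([32, 16])
```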
@@ -6,6 +6,7 @@ from .batching import (
     PretrainingBatchSamplerDataCollatorForSeq2Seq,
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
+from .dpo import AxolotlDPODataCollatorWithPadding
 from .mamba import MambaDataCollator

 __all__ = [
@@ -13,5 +14,6 @@ __all__ = [
     "BatchSamplerDataCollatorForSeq2Seq",
     "V2BatchSamplerDataCollatorForSeq2Seq",
     "PretrainingBatchSamplerDataCollatorForSeq2Seq",
+    "AxolotlDPODataCollatorWithPadding",
     "MambaDataCollator",
 ]
src/axolotl/utils/collators/dpo.py (new file, 128 lines)
@@ -0,0 +1,128 @@
+"""DPO/ORPO/IPO/KTO data collator with pad_to_multiple_of support.
+
+Extends TRL's DPODataCollatorWithPadding to round padded sequence lengths
+up to a fixed multiple. This stabilizes Triton autotune caches for kernels
+that key on sequence length (e.g. fla's linear attention kernels used by
+Qwen3.5), which otherwise re-autotune on every distinct batch length.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from trl.experimental.utils import DPODataCollatorWithPadding
+from trl.trainer.utils import pad
+
+
+def _round_up(length: int, multiple: int) -> int:
+    return ((length + multiple - 1) // multiple) * multiple
+
+
+@dataclass
+class AxolotlDPODataCollatorWithPadding(DPODataCollatorWithPadding):
+    """DPO data collator that pads to a multiple of ``pad_to_multiple_of``.
+
+    Args:
+        pad_token_id: Tokenizer pad token id (inherited).
+        is_encoder_decoder: Whether the model is encoder-decoder (inherited).
+        pad_to_multiple_of: If set, padded lengths are rounded up to this
+            multiple. Helps stabilize Triton autotune caches.
+    """
+
+    pad_to_multiple_of: int | None = None
+
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
+        pad_to_mult = self.pad_to_multiple_of
+
+        padded_batch: dict[str, Any] = {}
+        for k in features[0].keys():
+            if k.endswith(
+                ("_input_ids", "_attention_mask", "_labels", "_pixel_values")
+            ):
+                if self.is_encoder_decoder:
+                    if k.endswith("_pixel_values"):
+                        to_pad = [
+                            torch.tensor(ex[k], dtype=torch.float32) for ex in features
+                        ]
+                    else:
+                        to_pad = [torch.LongTensor(ex[k]) for ex in features]
+
+                    if k.startswith("prompt") and k.endswith("input_ids"):
+                        if self.pad_token_id is None:
+                            raise ValueError(
+                                "Padding is enabled, but the tokenizer is not configured with a padding token."
+                            )
+                        padding_value = self.pad_token_id
+                    elif k.endswith("_attention_mask"):
+                        padding_value = 0
+                    elif k.endswith("_pixel_values"):
+                        padding_value = 0
+                    elif (
+                        k.startswith(("chosen", "rejected", "completion"))
+                        or "decoder" in k
+                    ):
+                        padding_value = -100
+                    else:
+                        raise ValueError(f"Unexpected key in batch '{k}'")
+
+                    padded = pad_sequence(
+                        to_pad, batch_first=True, padding_value=padding_value
+                    )
+                    if pad_to_mult:
+                        cur = padded.shape[1]
+                        target = _round_up(cur, pad_to_mult)
+                        if target > cur:
+                            extra = target - cur
+                            pad_shape = list(padded.shape)
+                            pad_shape[1] = extra
+                            filler = torch.full(
+                                pad_shape,
+                                padding_value,
+                                dtype=padded.dtype,
+                                device=padded.device,
+                            )
+                            padded = torch.cat([padded, filler], dim=1)
+                    padded_batch[k] = padded
+                else:
+                    if k.endswith("_input_ids"):
+                        if self.pad_token_id is None:
+                            raise ValueError(
+                                "Padding is enabled, but the tokenizer is not configured with a padding token."
+                            )
+                        padding_value = self.pad_token_id
+                    elif k.endswith("_labels"):
+                        padding_value = -100
+                    elif k.endswith("_attention_mask"):
+                        padding_value = 0
+                    elif k.endswith("_pixel_values"):
+                        padding_value = 0
+                    else:
+                        raise ValueError(f"Unexpected key in batch '{k}'")
+
+                    padding_side = (
+                        "left"
+                        if k in ("prompt_input_ids", "prompt_attention_mask")
+                        else "right"
+                    )
+
+                    dtype = (
+                        torch.float32 if k.endswith("_pixel_values") else torch.int64
+                    )
+                    to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features]
+
+                    # trl.pad() natively supports pad_to_multiple_of
+                    padded_batch[k] = pad(
+                        to_pad,
+                        padding_value=padding_value,
+                        padding_side=padding_side,
+                        pad_to_multiple_of=pad_to_mult,
+                    )
+            elif k.endswith("_logps"):
+                padded_batch[k] = torch.tensor([ex[k] for ex in features])
+            else:
+                padded_batch[k] = [ex[k] for ex in features]
+
+        return padded_batch
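A usage sketch for the new collator, assuming a build with this change installed; the toy features and token ids are made up, and the key names follow TRL's DPO convention used in the code above:

```python
from axolotl.utils.collators.dpo import AxolotlDPODataCollatorWithPadding

collator = AxolotlDPODataCollatorWithPadding(
    pad_token_id=0,
    is_encoder_decoder=False,
    pad_to_multiple_of=64,
)

# Two ragged examples; the longest sequence here is 5 tokens.
features = [
    {
        "prompt_input_ids": [5, 6, 7],
        "prompt_attention_mask": [1, 1, 1],
        "chosen_input_ids": [5, 6, 7, 8, 9],
        "chosen_labels": [-100, -100, -100, 8, 9],
        "rejected_input_ids": [5, 6, 7, 4],
        "rejected_labels": [-100, -100, -100, 4],
    },
    {
        "prompt_input_ids": [5, 6],
        "prompt_attention_mask": [1, 1],
        "chosen_input_ids": [5, 6, 9],
        "chosen_labels": [-100, -100, 9],
        "rejected_input_ids": [5, 6, 4],
        "rejected_labels": [-100, -100, 4],
    },
]

batch = collator(features)
# Every padded tensor's sequence dim is rounded up to the multiple (5 -> 64),
# so Triton kernels see a stable set of sequence lengths across batches.
assert batch["chosen_input_ids"].shape[1] == 64
```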
@@ -673,6 +673,12 @@ class AxolotlInputConfig(
             "description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to True if `sample_packing` enabled"
         },
     )
+    pad_to_multiple_of: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": ("Pad each batch to a multiple of this value.")
+        },
+    )
     curriculum_sampling: bool | None = Field(
         default=None,
         json_schema_extra={
@@ -1010,7 +1016,7 @@ class AxolotlInputConfig(
     torch_compile: Literal["auto"] | bool | None = Field(
         default=None,
         json_schema_extra={
-            "description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.6.0"
+            "description": "Whether to use torch.compile and which backend to use."
         },
     )
     torch_compile_backend: str | None = Field(