Compare commits


10 Commits

Author SHA1 Message Date
Dan Saunders
156fede4f7 Update .pre-commit-config.yaml
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-03-21 10:36:18 -04:00
Dan Saunders
dcbbd7af79 sorry to revert, but pylint complained 2025-03-21 10:36:18 -04:00
Dan Saunders
21bac7ce1a running updated pre-commit plugins 2025-03-21 10:36:18 -04:00
Dan Saunders
aaa4571826 adding pre-commit auto-update GH action and bumping plugin versions 2025-03-21 10:36:17 -04:00
salman
187227d837 Fixing KTO+QLoRA+multi-GPU (#2420)
* WIP

* removing artifacts

* adding error

* adding adapter check

* linting

* simplifying check

* linting v2

* config fix -___-
2025-03-21 10:18:28 -04:00
NanoCode012
f8de8bb4f2 chore(doc): add instructions on adding custom integrations (#2422) [skip ci]
* chore(doc): add instructions on adding custom integrations

* chore: add warning help

* feat: add note about integration path

* fix: adjust text per suggestion
2025-03-21 10:18:01 -04:00
hugo
8e604848a4 add run on novita ai (#2421) [skip ci]
* add run on novita ai

* Revert "add run on novita ai"

This reverts commit 4d5df1ac6b.

* add run axolotl on novita ai
2025-03-21 10:17:47 -04:00
Wing Lian
aae4337f40 add 12.8.1 cuda to the base matrix (#2426)
* add 12.8.1 cuda to the base matrix

* use nightly

* bump deepspeed and set no binary

* deepspeed binary fixes hopefully

* install deepspeed by itself

* multiline fix

* make sure ninja is installed

* try with reversion of packaging/setuptools/wheel install

* use license instead of license-file

* try rolling back packaging and setuptools versions

* comment out license for validation for now

* make sure packaging version is consistent

* more parity across tests and docker images for packaging/setuptools
2025-03-21 10:17:25 -04:00
Wing Lian
38df5a36ea bump HF versions except for trl (#2427) 2025-03-20 10:22:05 -04:00
Wing Lian
4d92a68a96 use default torch fused adamw optimizer as default as adamw_hf is deprecated (#2425)
* use default torch fused adamw optimizer as default as adamw_hf is deprecated

* make sure to have latest packaging installed

* bump packaging in requirements.txt too
2025-03-19 23:58:33 -04:00
149 changed files with 608 additions and 324 deletions

View File

@@ -40,6 +40,12 @@ jobs:
             python_version: "3.11"
             pytorch: 2.6.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+          - cuda: "128"
+            cuda_version: 12.8.1
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: nightly
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -61,7 +67,7 @@ jobs:
         uses: docker/build-push-action@v4
         with:
           context: .
-          file: ./docker/Dockerfile-base
+          file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
           push: ${{ github.event_name != 'pull_request' }}
           tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
           labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -0,0 +1,49 @@
name: Pre-commit auto-update

on:
  schedule:
    - cron: '0 0 * * 0' # Run weekly
  workflow_dispatch: # Manual kickoff

jobs:
  auto-update:
    runs-on: ubuntu-latest
    permissions:
      contents: write
      pull-requests: write
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Update pre-commit hooks
        id: update
        run: |
          pip install pre-commit
          pre-commit autoupdate

          if [[ -n $(git status --porcelain) ]]; then
            echo "changes=true" >> $GITHUB_OUTPUT
            git diff .pre-commit-config.yaml > pre-commit-update.diff
          fi

      - name: Create Pull Request
        if: steps.update.outputs.changes == 'true'
        uses: peter-evans/create-pull-request@v6
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          branch: update/pre-commit-hooks
          delete-branch: true
          title: "chore: update pre-commit hooks"
          commit-message: "chore: update pre-commit hooks"
          body: |
            Automated PR to update pre-commit hooks to their latest versions.

            <details>
            <summary>Changes:</summary>

            ```diff
            ${{ steps.update.outputs.diff }}
            ```
            </details>

View File

@@ -40,7 +40,7 @@ jobs:
       - name: Install dependencies
         run: |
-          pip3 install wheel packaging
+          pip3 install wheel packaging==23.2
           pip3 install --no-build-isolation -e .
           pip3 install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -42,7 +42,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging setuptools wheel
+          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
       - name: Install PyTorch
         run: |
@@ -59,7 +59,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging
+          pip3 install --upgrade packaging==23.2
           pip3 install --no-build-isolation -U -e .
           python scripts/unsloth_install.py | sh
           python scripts/cutcrossentropy_install.py | sh

View File

@@ -74,7 +74,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging setuptools wheel
+          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
       - name: Install PyTorch
         run: |
@@ -147,7 +147,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging setuptools setuptools_scm build wheel
+          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
       - name: Install PyTorch
         run: |

View File

@@ -3,7 +3,7 @@ default_language_version:
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
@@ -11,23 +11,23 @@ repos:
       - id: no-commit-to-branch
        args: ['--branch', 'main']
   - repo: https://github.com/psf/black
-    rev: 23.3.0
+    rev: 25.1.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 6.0.1
     hooks:
       - id: isort
   - repo: https://github.com/PyCQA/flake8
-    rev: 6.1.0
+    rev: 7.1.2
     hooks:
       - id: flake8
-  - repo: https://github.com/PyCQA/pylint
-    rev: v3.3.0
+  - repo: https://github.com/pylint-dev/pylint
+    rev: v3.3.6
     hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.3.0
+    rev: v1.15.0
     hooks:
       - id: mypy
        additional_dependencies:
@@ -36,7 +36,7 @@ repos:
            'pydantic>=2.5.3',
          ]
   - repo: https://github.com/PyCQA/bandit
-    rev: 1.7.5
+    rev: 1.8.3
     hooks:
       - id: bandit
        args: [

View File

@@ -55,7 +55,7 @@ Features:
 ### Installation

 ```bash
-pip3 install -U packaging setuptools wheel ninja
+pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]

 # Download example axolotl configs, deepspeed configs

View File

@@ -31,6 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
     sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
     fi
+RUN pip install packaging==23.2 setuptools==75.8.0
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
     pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
   else \

View File

@@ -1,6 +1,7 @@
""" """
modal application to run axolotl gpu tests in Modal modal application to run axolotl gpu tests in Modal
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import os import os

View File

@@ -1,4 +1,5 @@
"""Modal app to run axolotl GPU tests""" """Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import os import os

View File

@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 WORKDIR /workspace

-RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
+RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
     python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
     python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
     python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

View File

@@ -0,0 +1,39 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4

FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

ENV PATH="/root/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"

ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

RUN apt-get update \
    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
    && wget \
        https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh \
    && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace

RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
    python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
    python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
    python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

RUN git lfs install --skip-repo && \
    pip3 install awscli && \
    # The base image ships with `pydantic==1.8.2` which is not working
    pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -85,6 +85,12 @@ gpu_memory_limit: 20GiB
 # Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
 lora_on_cpu: true

+# List[str]. Add plugins to extend the pipeline.
+# See `src/axolotl/integrations` for the available plugins or doc below for more details.
+# https://axolotl-ai-cloud.github.io/axolotl/docs/custom_integrations.html
+plugins:
+  # - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
 # A list of one or more datasets to finetune the model with
 datasets:
   # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files

View File

@@ -55,3 +55,47 @@ sections = [
 for section_name, folder_name in sections:
     print(print_section(section_name, folder_name))
 ```
+
+## Adding a new integration
+
+Plugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.
+
+To add a new integration, please follow these steps:
+
+1. Create a new folder in the `src/axolotl/integrations` directory.
+2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.
+3. Add `__init__.py` and `args.py` files to the new folder.
+    - `__init__.py` should import the integration and hook into the appropriate functions.
+    - `args.py` should define the arguments for the integration.
+4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.
+
+::: {.callout-tip}
+See [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.
+:::
+
+::: {.callout-warning}
+If your integration could not be loaded, make sure you have installed axolotl in editable mode
+
+```bash
+pip install -e .
+```
+
+and that the integration name is spelled correctly in the config file:
+
+```yaml
+plugins:
+  - axolotl.integrations.your_integration_name.YourIntegrationPlugin
+```
+:::
+
+::: {.callout-note}
+It is not necessary to place your integration in the `integrations` folder. It can live anywhere, as long as it is installed as a package in your Python environment.
+
+See this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)
+:::
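As a rough sketch of what the `args.py` / `__init__.py` pair described above might contain (package path, class names, and hook names here are hypothetical; consult `axolotl.integrations.base.BasePlugin` for the real hook signatures), a minimal plugin could look like:

```python
# args.py (hypothetical module my_package.my_plugin.args) -- config schema for the plugin
from typing import Optional

from pydantic import BaseModel


class MyPluginArgs(BaseModel):
    """Arguments that users can set in their axolotl YAML config."""

    my_plugin_enabled: Optional[bool] = None


# __init__.py (hypothetical module my_package.my_plugin) -- the plugin class axolotl loads
from axolotl.integrations.base import BasePlugin


class MyPlugin(BasePlugin):
    """Minimal illustrative plugin; hook names are assumptions, not the full API."""

    def get_input_args(self):
        # Dotted path to the pydantic args model above, so its fields become valid config keys.
        return "my_package.my_plugin.args.MyPluginArgs"

    def pre_model_load(self, cfg):
        # Assumed hook: runs before the base model is loaded.
        if cfg.my_plugin_enabled:
            print("MyPlugin: enabled via config")
```

Such a plugin would then be enabled by listing `my_package.my_plugin.MyPlugin` under the `plugins:` key of the training config.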

View File

@@ -79,6 +79,7 @@ For providers supporting Docker:
 - [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
 - [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
 - [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
+- [Novita](https://novita.ai/gpus-console?templateId=311)

 ### Google Colab {#sec-colab}

View File

@@ -1,5 +1,5 @@
 [build-system]
-requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
+requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
 build-backend = "setuptools.build_meta"

 [project]
@@ -8,6 +8,7 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
 description = "LLM Trainer"
 readme = "README.md"
 requires-python = ">=3.10"
+# license = "Apache-2.0"

 [project.scripts]
 axolotl = "axolotl.cli.main:main"

View File

@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 # START section of dependencies that don't install on Darwin/MacOS
-bitsandbytes==0.45.2
+bitsandbytes==0.45.3
 triton>=3.0.0
 mamba-ssm==1.2.0.post1
 flash-attn==2.7.4.post1
@@ -12,12 +12,12 @@ liger-kernel==0.5.3
 packaging==23.2
-peft==0.14.0
+peft==0.15.0
 transformers==4.49.0
-tokenizers>=0.21.0
-accelerate==1.3.0
-datasets==3.2.0
-deepspeed==0.16.1
+tokenizers>=0.21.1
+accelerate==1.5.2
+datasets==3.4.1
+deepspeed==0.16.4
 trl==0.15.1
 optimum==1.16.2

View File

@@ -1,6 +1,7 @@
""" """
helper script to parse chat datasets into a usable yaml helper script to parse chat datasets into a usable yaml
""" """
import click import click
import yaml import yaml
from datasets import load_dataset from datasets import load_dataset

View File

@@ -1,4 +1,5 @@
"""Script to output the correct installation command for cut-cross-entropy.""" """Script to output the correct installation command for cut-cross-entropy."""
import importlib.util import importlib.util
import sys import sys

View File

@@ -128,7 +128,7 @@ setup(
"flash-attn==2.7.4.post1", "flash-attn==2.7.4.post1",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed==0.16.1", "deepspeed==0.16.4",
"deepspeed-kernels", "deepspeed-kernels",
], ],
"mamba-ssm": [ "mamba-ssm": [

View File

@@ -1,6 +1,7 @@
""" """
launch axolotl in supported cloud platforms launch axolotl in supported cloud platforms
""" """
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union

View File

@@ -1,6 +1,7 @@
""" """
base class for cloud platforms from cli base class for cloud platforms from cli
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod

View File

@@ -1,6 +1,7 @@
""" """
Modal Cloud support from CLI Modal Cloud support from CLI
""" """
import copy import copy
import json import json
import os import os

View File

@@ -1,4 +1,5 @@
"""Click CLI definitions for various axolotl commands.""" """Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import logging import logging

View File

@@ -5,7 +5,6 @@ import dataclasses
 import hashlib
 import json
 import logging
-import typing
 from functools import wraps
 from pathlib import Path
 from types import NoneType
@@ -24,7 +23,7 @@ configure_logging()
 LOG = logging.getLogger(__name__)


-def strip_optional_type(field_type: type | typing._SpecialForm | None):
+def strip_optional_type(field_type: type | str | None):
     """
     Extracts the non-`None` type from an `Optional` / `Union` type.
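For context, a helper with that docstring typically unwraps `Optional[X]` back to `X`. The following stand-in is illustrative only (it is not the actual implementation in `axolotl.cli.utils`), but it shows the behaviour the docstring describes:

```python
from typing import Optional, Union, get_args, get_origin


def strip_optional_type_demo(field_type):
    """Illustrative only: return the non-None member of Optional[X] / Union[X, None]."""
    if get_origin(field_type) is Union:
        non_none = [arg for arg in get_args(field_type) if arg is not type(None)]
        if len(non_none) == 1:
            return non_none[0]
    return field_type


assert strip_optional_type_demo(Optional[int]) is int
assert strip_optional_type_demo(str) is str
```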

View File

@@ -1,6 +1,5 @@
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes""" """Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
import json import json
import sys import sys

View File

@@ -1,6 +1,7 @@
""" """
ChatML transformation functions for MessageContents ChatML transformation functions for MessageContents
""" """
from typing import Optional from typing import Optional
from ..messages import MessageContents, Messages from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
""" """
Llama 3.x chat formatting functions for MessageContents Llama 3.x chat formatting functions for MessageContents
""" """
from typing import Optional from typing import Optional
from ..messages import MessageContents, Messages from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
""" """
shared functions for format transforms shared functions for format transforms
""" """
from axolotl.core.chat.messages import MessageContents, Messages from axolotl.core.chat.messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
""" """
internal message representations of chat messages internal message representations of chat messages
""" """
import json import json
from enum import Enum from enum import Enum
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union

View File

@@ -1,6 +1,7 @@
""" """
chat dataset module chat dataset module
""" """
import os import os
from typing import Callable, Optional, Union from typing import Callable, Optional, Union

View File

@@ -1,6 +1,7 @@
""" """
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat. This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
""" """
from typing import Any, Mapping, Union from typing import Any, Mapping, Union

View File

@@ -332,9 +332,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs = {}
         if self.cfg.include_tokens_per_second is not None:
-            training_arguments_kwargs[
-                "include_tokens_per_second"
-            ] = self.cfg.include_tokens_per_second
+            training_arguments_kwargs["include_tokens_per_second"] = (
+                self.cfg.include_tokens_per_second
+            )

         if self.cfg.bf16 == "full":
             training_arguments_kwargs["bf16_full_eval"] = True
@@ -351,13 +351,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["seed"] = self.cfg.seed

         if self.cfg.gradient_checkpointing:
-            training_arguments_kwargs[
-                "gradient_checkpointing"
-            ] = self.cfg.gradient_checkpointing
+            training_arguments_kwargs["gradient_checkpointing"] = (
+                self.cfg.gradient_checkpointing
+            )
             if self.cfg.gradient_checkpointing_kwargs is not None:
-                training_arguments_kwargs[
-                    "gradient_checkpointing_kwargs"
-                ] = self.cfg.gradient_checkpointing_kwargs
+                training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
+                    self.cfg.gradient_checkpointing_kwargs
+                )
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:
@@ -373,9 +373,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
         if self.cfg.lr_quadratic_warmup is not None:
-            training_arguments_kwargs[
-                "lr_quadratic_warmup"
-            ] = self.cfg.lr_quadratic_warmup
+            training_arguments_kwargs["lr_quadratic_warmup"] = (
+                self.cfg.lr_quadratic_warmup
+            )

         if self.cfg.adam_beta1:
             training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
@@ -399,28 +399,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors

         if self.cfg.dataloader_pin_memory is not None:
-            training_arguments_kwargs[
-                "dataloader_pin_memory"
-            ] = self.cfg.dataloader_pin_memory
+            training_arguments_kwargs["dataloader_pin_memory"] = (
+                self.cfg.dataloader_pin_memory
+            )
         if self.cfg.dataloader_num_workers is not None:
-            training_arguments_kwargs[
-                "dataloader_num_workers"
-            ] = self.cfg.dataloader_num_workers
+            training_arguments_kwargs["dataloader_num_workers"] = (
+                self.cfg.dataloader_num_workers
+            )
         if self.cfg.dataloader_prefetch_factor is not None:
-            training_arguments_kwargs[
-                "dataloader_prefetch_factor"
-            ] = self.cfg.dataloader_prefetch_factor
+            training_arguments_kwargs["dataloader_prefetch_factor"] = (
+                self.cfg.dataloader_prefetch_factor
+            )
         if self.cfg.dataloader_drop_last is not None:
-            training_arguments_kwargs[
-                "dataloader_drop_last"
-            ] = self.cfg.dataloader_drop_last
+            training_arguments_kwargs["dataloader_drop_last"] = (
+                self.cfg.dataloader_drop_last
+            )
         elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
             training_arguments_kwargs["dataloader_drop_last"] = True

         if self.cfg.remove_unused_columns is not None:
-            training_arguments_kwargs[
-                "remove_unused_columns"
-            ] = self.cfg.remove_unused_columns
+            training_arguments_kwargs["remove_unused_columns"] = (
+                self.cfg.remove_unused_columns
+            )

         if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
             # no eval set, so don't eval
@@ -452,9 +452,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.do_causal_lm_eval:
             training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
         if self.cfg.metric_for_best_model:
-            training_arguments_kwargs[
-                "metric_for_best_model"
-            ] = self.cfg.metric_for_best_model
+            training_arguments_kwargs["metric_for_best_model"] = (
+                self.cfg.metric_for_best_model
+            )
         if self.cfg.greater_is_better:
             training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
@@ -467,13 +467,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             )
             training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
             if self.cfg.torch_compile_backend:
-                training_arguments_kwargs[
-                    "torch_compile_backend"
-                ] = self.cfg.torch_compile_backend
+                training_arguments_kwargs["torch_compile_backend"] = (
+                    self.cfg.torch_compile_backend
+                )
             if self.cfg.torch_compile_mode:
-                training_arguments_kwargs[
-                    "torch_compile_mode"
-                ] = self.cfg.torch_compile_mode
+                training_arguments_kwargs["torch_compile_mode"] = (
+                    self.cfg.torch_compile_mode
+                )

         # DDP Config
         if self.cfg.ddp_timeout:
@@ -482,32 +482,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.ddp_bucket_cap_mb:
             training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
         if self.cfg.ddp_broadcast_buffers is not None:
-            training_arguments_kwargs[
-                "ddp_broadcast_buffers"
-            ] = self.cfg.ddp_broadcast_buffers
+            training_arguments_kwargs["ddp_broadcast_buffers"] = (
+                self.cfg.ddp_broadcast_buffers
+            )

         # these are all the "standard" kwargs that are def used
         training_arguments_kwargs["max_steps"] = (
             total_num_steps if self.cfg.max_steps else -1
         )
         training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
-        training_arguments_kwargs[
-            "per_device_train_batch_size"
-        ] = self.cfg.micro_batch_size
+        training_arguments_kwargs["per_device_train_batch_size"] = (
+            self.cfg.micro_batch_size
+        )
         if self.cfg.eval_batch_size:
-            training_arguments_kwargs[
-                "per_device_eval_batch_size"
-            ] = self.cfg.eval_batch_size
+            training_arguments_kwargs["per_device_eval_batch_size"] = (
+                self.cfg.eval_batch_size
+            )
         if self.cfg.auto_find_batch_size is not None:
-            training_arguments_kwargs[
-                "auto_find_batch_size"
-            ] = self.cfg.auto_find_batch_size
+            training_arguments_kwargs["auto_find_batch_size"] = (
+                self.cfg.auto_find_batch_size
+            )
-        training_arguments_kwargs[
-            "gradient_accumulation_steps"
-        ] = self.cfg.gradient_accumulation_steps
+        training_arguments_kwargs["gradient_accumulation_steps"] = (
+            self.cfg.gradient_accumulation_steps
+        )
-        training_arguments_kwargs[
-            "eval_accumulation_steps"
-        ] = self.cfg.gradient_accumulation_steps
+        training_arguments_kwargs["eval_accumulation_steps"] = (
+            self.cfg.gradient_accumulation_steps
+        )
         training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
         training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
         training_arguments_kwargs["output_dir"] = self.cfg.output_dir
@@ -554,9 +554,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"
-            training_arguments_kwargs[
-                "alternate_lr_scheduler_type"
-            ] = self.cfg.lr_scheduler
+            training_arguments_kwargs["alternate_lr_scheduler_type"] = (
+                self.cfg.lr_scheduler
+            )
         else:
             training_arguments_kwargs["lr_scheduler_type"] = (
                 self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
@@ -565,9 +565,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
         training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
-        training_arguments_kwargs[
-            "cosine_constant_lr_ratio"
-        ] = self.cfg.cosine_constant_lr_ratio
+        training_arguments_kwargs["cosine_constant_lr_ratio"] = (
+            self.cfg.cosine_constant_lr_ratio
+        )
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
@@ -580,40 +580,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.eval_sample_packing
             )

         if self.cfg.sample_packing_bin_size is not None:
-            training_arguments_kwargs[
-                "sample_packing_bin_size"
-            ] = self.cfg.sample_packing_bin_size
+            training_arguments_kwargs["sample_packing_bin_size"] = (
+                self.cfg.sample_packing_bin_size
+            )
         if self.cfg.sample_packing_group_size is not None:
-            training_arguments_kwargs[
-                "sample_packing_group_size"
-            ] = self.cfg.sample_packing_group_size
+            training_arguments_kwargs["sample_packing_group_size"] = (
+                self.cfg.sample_packing_group_size
+            )
         if self.cfg.sample_packing_eff_est:
-            training_arguments_kwargs[
-                "sample_packing_efficiency"
-            ] = self.cfg.sample_packing_eff_est
+            training_arguments_kwargs["sample_packing_efficiency"] = (
+                self.cfg.sample_packing_eff_est
+            )

         if self.cfg.relora_steps:
             training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
-            training_arguments_kwargs[
-                "relora_warmup_steps"
-            ] = self.cfg.relora_warmup_steps
+            training_arguments_kwargs["relora_warmup_steps"] = (
+                self.cfg.relora_warmup_steps
+            )
             if self.cfg.relora_anneal_steps:
-                training_arguments_kwargs[
-                    "relora_anneal_steps"
-                ] = self.cfg.relora_anneal_steps
+                training_arguments_kwargs["relora_anneal_steps"] = (
+                    self.cfg.relora_anneal_steps
+                )
             if self.cfg.relora_prune_ratio:
-                training_arguments_kwargs[
-                    "relora_prune_ratio"
-                ] = self.cfg.relora_prune_ratio
+                training_arguments_kwargs["relora_prune_ratio"] = (
+                    self.cfg.relora_prune_ratio
+                )

         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
-            training_arguments_kwargs[
-                "lisa_step_interval"
-            ] = self.cfg.lisa_step_interval
+            training_arguments_kwargs["lisa_step_interval"] = (
+                self.cfg.lisa_step_interval
+            )
-            training_arguments_kwargs[
-                "lisa_layers_attribute"
-            ] = self.cfg.lisa_layers_attribute
+            training_arguments_kwargs["lisa_layers_attribute"] = (
+                self.cfg.lisa_layers_attribute
+            )

         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
@@ -627,9 +627,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )

         if self.cfg.neftune_noise_alpha is not None:
-            training_arguments_kwargs[
-                "neftune_noise_alpha"
-            ] = self.cfg.neftune_noise_alpha
+            training_arguments_kwargs["neftune_noise_alpha"] = (
+                self.cfg.neftune_noise_alpha
+            )

         trainer_kwargs = {}
@@ -731,23 +731,23 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             importlib.import_module("torchdistx")

         if self.cfg.optim_target_modules:
-            training_arguments_kwargs[
-                "optim_target_modules"
-            ] = self.cfg.optim_target_modules
+            training_arguments_kwargs["optim_target_modules"] = (
+                self.cfg.optim_target_modules
+            )

         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
         training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
-        training_arguments_kwargs[
-            "loraplus_lr_embedding"
-        ] = self.cfg.loraplus_lr_embedding
+        training_arguments_kwargs["loraplus_lr_embedding"] = (
+            self.cfg.loraplus_lr_embedding
+        )
         training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups

         if self.cfg.accelerator_config:
-            training_arguments_kwargs[
-                "accelerator_config"
-            ] = self.cfg.accelerator_config
+            training_arguments_kwargs["accelerator_config"] = (
+                self.cfg.accelerator_config
+            )

         if self.cfg.kd_ce_alpha is not None:
             training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
@@ -756,13 +756,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.kd_temperature is not None:
             training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
         if self.cfg.kd_zscore_base_temp is not None:
-            training_arguments_kwargs[
-                "kd_zscore_base_temp"
-            ] = self.cfg.kd_zscore_base_temp
+            training_arguments_kwargs["kd_zscore_base_temp"] = (
+                self.cfg.kd_zscore_base_temp
+            )
         if self.cfg.kd_top_k_before_softmax is not None:
-            training_arguments_kwargs[
-                "kd_top_k_before_softmax"
-            ] = self.cfg.kd_top_k_before_softmax
+            training_arguments_kwargs["kd_top_k_before_softmax"] = (
+                self.cfg.kd_top_k_before_softmax
+            )

         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
@@ -972,32 +972,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )

         if self.cfg.remove_unused_columns is not None:
-            training_args_kwargs[
-                "remove_unused_columns"
-            ] = self.cfg.remove_unused_columns
+            training_args_kwargs["remove_unused_columns"] = (
+                self.cfg.remove_unused_columns
+            )
         else:
             training_args_kwargs["remove_unused_columns"] = False

         if self.cfg.dataloader_pin_memory is not None:
-            training_args_kwargs[
-                "dataloader_pin_memory"
-            ] = self.cfg.dataloader_pin_memory
+            training_args_kwargs["dataloader_pin_memory"] = (
+                self.cfg.dataloader_pin_memory
+            )
         if self.cfg.dataloader_num_workers is not None:
-            training_args_kwargs[
-                "dataloader_num_workers"
-            ] = self.cfg.dataloader_num_workers
+            training_args_kwargs["dataloader_num_workers"] = (
+                self.cfg.dataloader_num_workers
+            )
         if self.cfg.dataloader_prefetch_factor is not None:
-            training_args_kwargs[
-                "dataloader_prefetch_factor"
-            ] = self.cfg.dataloader_prefetch_factor
+            training_args_kwargs["dataloader_prefetch_factor"] = (
+                self.cfg.dataloader_prefetch_factor
+            )
         if self.cfg.gradient_checkpointing:
-            training_args_kwargs[
-                "gradient_checkpointing"
-            ] = self.cfg.gradient_checkpointing
+            training_args_kwargs["gradient_checkpointing"] = (
+                self.cfg.gradient_checkpointing
+            )
             if self.cfg.gradient_checkpointing_kwargs is not None:
-                training_args_kwargs[
-                    "gradient_checkpointing_kwargs"
-                ] = self.cfg.gradient_checkpointing_kwargs
+                training_args_kwargs["gradient_checkpointing_kwargs"] = (
+                    self.cfg.gradient_checkpointing_kwargs
+                )
             else:
                 training_args_kwargs["gradient_checkpointing_kwargs"] = {
                     "use_reentrant": False
@@ -1071,9 +1071,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.dpo_use_weighting is not None:
             training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
         if self.cfg.dpo_use_logits_to_keep is not None:
-            training_args_kwargs[
-                "use_logits_to_keep"
-            ] = self.cfg.dpo_use_logits_to_keep
+            training_args_kwargs["use_logits_to_keep"] = (
+                self.cfg.dpo_use_logits_to_keep
+            )

         for blocklist_key in blocklist_args_kwargs:
             if blocklist_key in training_args_kwargs:
@@ -1108,9 +1108,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.adapter and self.peft_config:
             dpo_trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:
-            dpo_trainer_kwargs[
-                "precompute_ref_log_probs"
-            ] = self.cfg.precompute_ref_log_probs
+            dpo_trainer_kwargs["precompute_ref_log_probs"] = (
+                self.cfg.precompute_ref_log_probs
+            )
         if self.cfg.rl == "grpo":
             trainer_cls = GRPOStrategy.get_trainer_class()
             trainer_cls_args = [self.model]

View File

@@ -462,9 +462,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             "pin_memory": self.args.dataloader_pin_memory,
         }
         if self.args.dataloader_prefetch_factor:
-            dataloader_params[
-                "prefetch_factor"
-            ] = self.args.dataloader_prefetch_factor
+            dataloader_params["prefetch_factor"] = (
+                self.args.dataloader_prefetch_factor
+            )

         sampler = self._get_train_sampler()
         if isinstance(sampler, BatchSampler):
@@ -509,9 +509,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             "pin_memory": self.args.dataloader_pin_memory,
         }
         if self.args.dataloader_prefetch_factor:
-            dataloader_params[
-                "prefetch_factor"
-            ] = self.args.dataloader_prefetch_factor
+            dataloader_params["prefetch_factor"] = (
+                self.args.dataloader_prefetch_factor
+            )

         if isinstance(eval_sampler, BatchSampler):
             dataloader_params["batch_sampler"] = eval_sampler

View File

@@ -1,6 +1,7 @@
""" """
DPO Specific Strategy for training DPO Specific Strategy for training
""" """
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer

View File

@@ -1,6 +1,7 @@
""" """
Axolotl specific DPO args Axolotl specific DPO args
""" """
from dataclasses import dataclass from dataclasses import dataclass
from trl import DPOConfig from trl import DPOConfig

View File

@@ -1,6 +1,7 @@
""" """
DPO trainer for axolotl DPO trainer for axolotl
""" """
import gc import gc
from functools import wraps from functools import wraps
from typing import Any, Dict, Union from typing import Any, Dict, Union

View File

@@ -45,9 +45,9 @@ class GRPOStrategy:
             )
         if trl.vllm_gpu_memory_utilization:
-            grpo_args_kwargs[
-                "vllm_gpu_memory_utilization"
-            ] = trl.vllm_gpu_memory_utilization
+            grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
+                trl.vllm_gpu_memory_utilization
+            )
         if trl.vllm_max_model_len:
             grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
@@ -86,9 +86,9 @@ class GRPOStrategy:
     def set_trainer_kwargs(cls, cfg):
         trainer_kwargs = {}
         if cfg.trl and cfg.trl.reward_processing_classes:
-            trainer_kwargs[
-                "reward_processing_classes"
-            ] = cfg.trl.reward_processing_classes
+            trainer_kwargs["reward_processing_classes"] = (
+                cfg.trl.reward_processing_classes
+            )
         return trainer_kwargs

     @classmethod

View File

@@ -1,6 +1,7 @@
""" """
Axolotl Specific Training Args Axolotl Specific Training Args
""" """
from dataclasses import dataclass from dataclasses import dataclass
from trl import GRPOConfig from trl import GRPOConfig

View File

@@ -1,6 +1,7 @@
""" """
Axolotl GRPO trainer Axolotl GRPO trainer
""" """
from accelerate.utils import is_peft_model from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel from transformers import PreTrainedModel

View File

@@ -1,6 +1,7 @@
""" """
module for TRL PPO training module for TRL PPO training
""" """
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from trl import PPOTrainer from trl import PPOTrainer

View File

@@ -1,6 +1,7 @@
""" """
extra axolotl specific training args extra axolotl specific training args
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional

View File

@@ -1,6 +1,7 @@
""" """
Grokfast plugin for Axolotl Grokfast plugin for Axolotl
""" """
import logging import logging
from transformers.trainer_callback import TrainerCallback from transformers.trainer_callback import TrainerCallback

View File

@@ -1,6 +1,7 @@
""" """
config args for grokfast plugin config args for grokfast plugin
""" """
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
     """

     kd_trainer: Optional[bool] = None  # whether to use KD trainer
-    kd_ce_alpha: Optional[
-        float
-    ] = None  # loss coefficient for cross-entropy loss during KD
+    kd_ce_alpha: Optional[float] = (
+        None  # loss coefficient for cross-entropy loss during KD
+    )
     kd_alpha: Optional[float] = None  # loss coefficient for KD loss
     kd_temperature: Optional[float] = None  # temperature for sampling during KD
     kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
-    kd_top_k_before_softmax: Optional[
-        bool
-    ] = None  # whether to sample top k before softmax during KD
+    kd_top_k_before_softmax: Optional[bool] = (
+        None  # whether to sample top k before softmax during KD
+    )

View File

@@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin):
         if "cross_entropy" in liger_fn_sig.parameters:
             kwargs["cross_entropy"] = cfg.liger_cross_entropy
         if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
-            kwargs[
-                "fused_linear_cross_entropy"
-            ] = cfg.liger_fused_linear_cross_entropy
+            kwargs["fused_linear_cross_entropy"] = (
+                cfg.liger_fused_linear_cross_entropy
+            )
         if "rms_norm" in liger_fn_sig.parameters:
             kwargs["rms_norm"] = cfg.liger_rms_norm
         if "layer_norm" in liger_fn_sig.parameters:

View File

@@ -1,6 +1,7 @@
""" """
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
""" """
Jamba model with LigerFusedLinearCrossEntropyLoss Jamba model with LigerFusedLinearCrossEntropyLoss
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
""" """
Module for the Plugin for LM Eval Harness Module for the Plugin for LM Eval Harness
""" """
import subprocess # nosec import subprocess # nosec
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin

View File

@@ -1,6 +1,7 @@
""" """
Module for handling lm eval harness input arguments. Module for handling lm eval harness input arguments.
""" """
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -1,6 +1,7 @@
""" """
axolotl CLI for running lm_eval tasks axolotl CLI for running lm_eval tasks
""" """
import subprocess # nosec import subprocess # nosec
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
 Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
 """
+
 # pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
 import torch

View File

@@ -6,6 +6,7 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
 Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
 """
+
 # pylint: disable=invalid-name
 from typing import Callable

View File

@@ -1,4 +1,5 @@
"""Dequantization utilities for `bitsandbytes` integration.""" """Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement # pylint: disable=invalid-name,global-statement
import ctypes import ctypes

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
 Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
 """
+
 import torch
 import triton
 import triton.language as tl

View File

@@ -1,6 +1,7 @@
""" """
HF Transformers MambaConfig HF Transformers MambaConfig
""" """
from transformers import PretrainedConfig from transformers import PretrainedConfig

View File

@@ -1,6 +1,7 @@
""" """
Monkeypatch for Vision Llama for FA2 support Monkeypatch for Vision Llama for FA2 support
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -220,10 +221,10 @@ def patch_mllama():
True True
) )
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2 MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[ MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = (
"flash_attention_2" MllamaTextCrossFlashAttention2
] = MllamaTextCrossFlashAttention2 )
# fallback to SDPA # fallback to SDPA
MLLAMA_VISION_ATTENTION_CLASSES[ MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = (
"flash_attention_2" MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"] )

View File

@@ -1,4 +1,5 @@
"""monkey patches for the dataset fetcher to handle batches of packed indexes""" """monkey patches for the dataset fetcher to handle batches of packed indexes"""
# pylint: disable=protected-access # pylint: disable=protected-access
import torch import torch

View File

@@ -12,7 +12,9 @@ import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+)
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OriginalLlamaDecoderLayer,
 )
@@ -490,9 +492,11 @@ def flashattn_forward(
             # We have disabled _prepare_decoder_attention_mask in LlamaModel
             # the attention_mask should be the same as the key_padding_mask
             key_padding_mask=attention_mask,
-            query_padding_mask=attention_mask[:, -query_states.size(1) :]
-            if attention_mask is not None
-            else None,
+            query_padding_mask=(
+                attention_mask[:, -query_states.size(1) :]
+                if attention_mask is not None
+                else None
+            ),
         )
         output_unpad = flash_attn_varlen_qkvpacked_func(
             qkv_unpad,
@@ -531,9 +535,11 @@ def flashattn_forward(
             value_states,
             kvpacked=True,
             key_padding_mask=attention_mask,
-            query_padding_mask=attention_mask[:, -query_states.size(1) :]
-            if attention_mask is not None
-            else None,
+            query_padding_mask=(
+                attention_mask[:, -query_states.size(1) :]
+                if attention_mask is not None
+                else None
+            ),
         )
         if q_unpad.dtype != kv_unpad.dtype:
             kv_unpad = kv_unpad.to(q_unpad.dtype)

View File

@@ -1,6 +1,7 @@
""" """
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
""" """
from typing import Optional from typing import Optional
import torch import torch

View File

@@ -1,4 +1,5 @@
"""Flash attention monkey patch for mistral model""" """Flash attention monkey patch for mistral model"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import logging import logging
@@ -21,7 +22,10 @@ from transformers.models.mistral.modeling_mistral import (
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer, MistralDecoderLayer as OriginalMistralDecoderLayer,
) )
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv from transformers.models.mistral.modeling_mistral import (
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
@@ -243,9 +247,11 @@ def flashattn_forward(
# We have disabled _prepare_decoder_attention_mask in LlamaModel # We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask # the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask, key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :] query_padding_mask=(
if attention_mask is not None attention_mask[:, -query_states.size(1) :]
else None, if attention_mask is not None
else None
),
) )
output_unpad = flash_attn_varlen_qkvpacked_func( output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad, qkv_unpad,
@@ -286,9 +292,11 @@ def flashattn_forward(
value_states, value_states,
kvpacked=True, kvpacked=True,
key_padding_mask=attention_mask, key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :] query_padding_mask=(
if attention_mask is not None attention_mask[:, -query_states.size(1) :]
else None, if attention_mask is not None
else None
),
) )
if q_unpad.dtype != kv_unpad.dtype: if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype) kv_unpad = kv_unpad.to(q_unpad.dtype)

View File

@@ -1,6 +1,7 @@
""" """
Patches to support multipack for mixtral Patches to support multipack for mixtral
""" """
import torch import torch

View File

@@ -1,4 +1,5 @@
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune.""" """Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
import glob import glob
import json import json
import logging import logging
@@ -411,7 +412,10 @@ def merge_and_save(
if shard_path.endswith(".safetensors"): if shard_path.endswith(".safetensors"):
in_tensors = st.load_file(str(Path(model_src) / shard_path)) in_tensors = st.load_file(str(Path(model_src) / shard_path))
else: else:
in_tensors = torch.load(Path(model_src) / shard_path) in_tensors = torch.load(
Path(model_src) / shard_path,
weights_only=True, # to prevent arbitrary code execution
)
if "state_dict" in in_tensors: if "state_dict" in in_tensors:
in_tensors = in_tensors["state_dict"] in_tensors = in_tensors["state_dict"]

View File

@@ -17,7 +17,7 @@
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
 # pylint: disable=duplicate-code
-""" PyTorch StableLM Epoch model. """
+"""PyTorch StableLM Epoch model."""
 import importlib
 import math
 from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
""" """
fix for FSDP optimizer save in trainer w 4.47.0 fix for FSDP optimizer save in trainer w 4.47.0
""" """
import inspect import inspect
import logging import logging

View File

@@ -1,6 +1,7 @@
""" """
Shared utils for the monkeypatches Shared utils for the monkeypatches
""" """
import re import re
from typing import Optional, Tuple from typing import Optional, Tuple

View File

@@ -1,6 +1,7 @@
""" """
Fused MLP layer for incrementally improved training efficiency Fused MLP layer for incrementally improved training efficiency
""" """
import torch import torch
from transformers.models.llama.modeling_llama import LlamaMLP from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU from xformers.ops import SwiGLU

View File

@@ -1,6 +1,7 @@
""" """
Prompt strategies loader for alpaca instruction datasets with system prompts Prompt strategies loader for alpaca instruction datasets with system prompts
""" """
from typing import Generator, Tuple, Union from typing import Generator, Tuple, Union
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy

View File

@@ -1,6 +1,7 @@
""" """
Basic completion text Basic completion text
""" """
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple from typing import Any, Dict, Generator, Optional, Tuple

View File

@@ -1,4 +1,5 @@
"""Module containing the classes for Context QA Prompt Tokenization Strategies""" """Module containing the classes for Context QA Prompt Tokenization Strategies"""
from typing import Tuple from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy

View File

@@ -1,6 +1,7 @@
""" """
module for DPO style dataset transform strategies module for DPO style dataset transform strategies
""" """
from functools import partial from functools import partial
from ..base import load as load_base from ..base import load as load_base

View File

@@ -33,9 +33,9 @@ def default(
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>" sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>" sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
return sample return sample
@@ -52,9 +52,9 @@ def argilla_chat(
""" """
def transform_fn(sample): def transform_fn(sample):
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample return sample
@@ -78,9 +78,9 @@ def icr(
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample return sample
@@ -100,9 +100,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample return sample
@@ -120,9 +120,9 @@ def prompt_pairs(
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample return sample
@@ -142,9 +142,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample return sample
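Every hunk in this file is the same mechanical rewrite: black now prefers assigning into sample["prompt"] with a parenthesized value rather than splitting the subscript across lines. Functionally, each transform still just maps a raw preference row into chatml-formatted prompt/chosen/rejected strings. A standalone sketch of that shape (the wrapper function name and sample row are illustrative, not from the repo):

def chatml_prompt_pairs(sample: dict) -> dict:
    # illustrative re-implementation of a prompt_pairs-style transform
    sample["prompt"] = (
        f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
    )
    sample["chosen"] = f"{sample['chosen']}<|im_end|>"
    sample["rejected"] = f"{sample['rejected']}<|im_end|>"
    return sample


row = {"prompt": "What is 2 + 2?", "chosen": "4", "rejected": "5"}
print(chatml_prompt_pairs(row)["prompt"])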

View File

@@ -34,9 +34,9 @@ def default(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>" sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>" sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
return sample return sample
@@ -53,9 +53,9 @@ def argilla_chat(
""" """
def transform_fn(sample): def transform_fn(sample):
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample return sample
@@ -79,9 +79,9 @@ def icr(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample return sample
@@ -101,9 +101,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample return sample
@@ -121,9 +121,9 @@ def prompt_pairs(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample return sample
@@ -143,9 +143,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample return sample

View File

@@ -1,4 +1,5 @@
"""Module for plain input/output prompt pairs""" """Module for plain input/output prompt pairs"""
from typing import Generator, Tuple from typing import Generator, Tuple
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy

View File

@@ -1,4 +1,5 @@
"""Module for inspect jinja templates for the variables they use""" """Module for inspect jinja templates for the variables they use"""
from typing import Dict, Optional, Set, TypedDict, Union from typing import Dict, Optional, Set, TypedDict, Union
from jinja2 import Environment, meta, nodes from jinja2 import Environment, meta, nodes

View File

@@ -1,6 +1,7 @@
""" """
KTO strategies for chatml KTO strategies for chatml
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -15,9 +16,9 @@ def argilla(
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" )
sample["completion"] = f"{sample['completion']}<|im_end|>" sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample return sample
@@ -33,9 +34,9 @@ def argilla_chat(
""" """
def transform_fn(sample): def transform_fn(sample):
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" )
sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>" sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
return sample return sample
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" )
sample["completion"] = f"{sample['completion']}<|im_end|>" sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample return sample
@@ -74,9 +75,9 @@ def prompt_pairs(
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" )
sample["completion"] = f"{sample['completion']}<|im_end|>" sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample return sample
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" )
sample["completion"] = f"{sample['completion']}<|im_end|>" sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample return sample

View File

@@ -1,6 +1,7 @@
""" """
KTO strategies for llama-3 chat template KTO strategies for llama-3 chat template
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -15,9 +16,9 @@ def argilla(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["completion"] = f"{sample['completion']}<|eot_id|>" sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample return sample
@@ -33,9 +34,9 @@ def argilla_chat(
""" """
def transform_fn(sample): def transform_fn(sample):
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>" sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
return sample return sample
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["completion"] = f"{sample['completion']}<|eot_id|>" sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample return sample
@@ -74,9 +75,9 @@ def prompt_pairs(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["completion"] = f"{sample['completion']}<|eot_id|>" sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample return sample
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["completion"] = f"{sample['completion']}<|eot_id|>" sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample return sample

View File

@@ -1,6 +1,7 @@
""" """
User-defined KTO strategies User-defined KTO strategies
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code

View File

@@ -1,6 +1,7 @@
""" """
Chat dataset wrapping strategy for new internal messages representations Chat dataset wrapping strategy for new internal messages representations
""" """
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Optional
from axolotl.core.datasets.chat import TokenizedChatDataset from axolotl.core.datasets.chat import TokenizedChatDataset

View File

@@ -9,6 +9,7 @@ this one specifies the system prompt with "### System:".
Not suited/tested for multiple-turn conversations without further adjustments.
"""
+
from typing import Generator, Union
from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy

View File

@@ -1,4 +1,5 @@
"""chatml prompt tokenization strategy for ORPO""" """chatml prompt tokenization strategy for ORPO"""
from typing import Any, Dict, Generator, List, Optional, Tuple from typing import Any, Dict, Generator, List, Optional, Tuple
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -1,4 +1,5 @@
"""pretraining prompt strategies""" """pretraining prompt strategies"""
from typing import Generator from typing import Generator
from transformers import BatchEncoding from transformers import BatchEncoding

View File

@@ -406,9 +406,7 @@ def handle_untrained_tokens_fix(
)


-def setup_model_and_trainer(
-    cfg: DictDefault, dataset_meta: TrainDatasetMeta
-) -> tuple[
+def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
    HFRLTrainerBuilder | HFCausalTrainerBuilder,
    PeftModel | PreTrainedModel,
    PreTrainedTokenizer,

View File

@@ -40,6 +40,6 @@ def set_pytorch_cuda_alloc_conf():
    torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
    if torch_major == 2 and torch_minor >= 2:
        if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
-            os.environ[
-                "PYTORCH_CUDA_ALLOC_CONF"
-            ] = "expandable_segments:True,roundup_power2_divisions:16"
+            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
+                "expandable_segments:True,roundup_power2_divisions:16"
+            )

View File

@@ -1,4 +1,5 @@
"""Benchmarking and measurement utilities""" """Benchmarking and measurement utilities"""
import functools import functools
import torch import torch

View File

@@ -343,9 +343,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
            bench_refs.extend(combined_bench_names[bench_name]["refs"])
            bench_preds.extend(combined_bench_names[bench_name]["preds"])
            if not pd.isna(bench_score):
-                results[
-                    f"{bench_split}_bench_accuracy_{bench_name}"
-                ] = bench_score
+                results[f"{bench_split}_bench_accuracy_{bench_name}"] = (
+                    bench_score
+                )
                bench_scores.append(bench_score)
            else:
                results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0

View File

@@ -1,4 +1,5 @@
"""MLFlow module for trainer callbacks""" """MLFlow module for trainer callbacks"""
import logging import logging
from shutil import copyfile from shutil import copyfile
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile

View File

@@ -1,4 +1,5 @@
"""callback to calculate perplexity as an evaluation metric.""" """callback to calculate perplexity as an evaluation metric."""
from typing import Dict, List, Optional from typing import Dict, List, Optional
import torch import torch

View File

@@ -1,6 +1,7 @@
""" """
HF Trainer callback for creating pytorch profiling snapshots HF Trainer callback for creating pytorch profiling snapshots
""" """
from pathlib import Path from pathlib import Path
from pickle import dump # nosec B403 from pickle import dump # nosec B403

View File

@@ -2,6 +2,7 @@
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""
+
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional

View File

@@ -1,6 +1,7 @@
""" """
shared axolotl collators for multipack, mamba, multimodal shared axolotl collators for multipack, mamba, multimodal
""" """
from .batching import ( # noqa: F401 from .batching import ( # noqa: F401
BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,

View File

@@ -1,4 +1,5 @@
""" """
basic shared collator constants basic shared collator constants
""" """
IGNORE_INDEX = -100 IGNORE_INDEX = -100

View File

@@ -1,6 +1,7 @@
""" """
collators for Mamba collators for Mamba
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Sequence from typing import Dict, Sequence

View File

@@ -18,7 +18,11 @@ from axolotl.utils.config.models.input.v0_4_1 import (
from axolotl.utils.config.models.input.v0_4_1 import (
    AxolotlInputConfig as AxolotlInputConfigBase,
)
-from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset
+from axolotl.utils.config.models.input.v0_4_1 import (
+    DPODataset,
+    KTODataset,
+    SFTDataset,
+)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config

View File

@@ -200,12 +200,12 @@ class SFTDataset(BaseModel):
    field_human: Optional[str] = None
    field_model: Optional[str] = None
    field_messages: Optional[str] = None
-    message_field_role: Optional[
-        str
-    ] = None  # deprecated, use message_property_mappings
-    message_field_content: Optional[
-        str
-    ] = None  # deprecated, use message_property_mappings
+    message_field_role: Optional[str] = (
+        None  # deprecated, use message_property_mappings
+    )
+    message_field_content: Optional[str] = (
+        None  # deprecated, use message_property_mappings
+    )
    message_property_mappings: Optional[Dict[str, str]] = None
    message_field_training: Optional[str] = None
    message_field_training_detail: Optional[str] = None
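Since both fields above are flagged as deprecated in favor of message_property_mappings, a dataset entry written against the old keys would presumably migrate along these lines (this equivalence is an assumption based only on the comment in this hunk, not on the docs; the field values are made up):

legacy_dataset_cfg = {
    "path": "my/sharegpt-style-dataset",
    "field_messages": "conversations",
    "message_field_role": "from",      # deprecated
    "message_field_content": "value",  # deprecated
}

migrated_dataset_cfg = {
    "path": "my/sharegpt-style-dataset",
    "field_messages": "conversations",
    # assumed equivalent: map the internal "role"/"content" properties
    # onto the dataset's own key names
    "message_property_mappings": {"role": "from", "content": "value"},
}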
@@ -505,9 +505,9 @@ class HyperparametersConfig(BaseModel):
    embedding_lr: Optional[float] = None
    embedding_lr_scale: Optional[float] = None
    weight_decay: Optional[float] = 0.0
-    optimizer: Optional[
-        Union[OptimizerNames, CustomSupportedOptimizers]
-    ] = OptimizerNames.ADAMW_HF
+    optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = (
+        OptimizerNames.ADAMW_TORCH_FUSED
+    )
    optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
        default=None,
        json_schema_extra={"description": "Optional arguments to supply to optimizer."},
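Unlike the surrounding hunks, this one changes behavior: the default optimizer moves from the deprecated adamw_hf to adamw_torch_fused. That name corresponds to PyTorch's own AdamW with its fused CUDA kernel enabled, roughly as in this sketch (a standalone illustration that assumes a CUDA device, not axolotl's trainer code):

import torch

model = torch.nn.Linear(16, 16).cuda()  # fused AdamW requires CUDA parameters
# "adamw_torch_fused" selects torch.optim.AdamW with fused=True, which performs
# the parameter update in a single fused kernel instead of a per-tensor loop.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.0, fused=True)

loss = model(torch.randn(4, 16, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()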
@@ -699,9 +699,9 @@ class AxolotlInputConfig(
    reward_model: Optional[bool] = None
    process_reward_model: Optional[bool] = None
    num_labels: Optional[int] = None
-    dpo_use_weighting: Optional[
-        bool
-    ] = None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
+    dpo_use_weighting: Optional[bool] = (
+        None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
+    )
    dpo_use_logits_to_keep: Optional[bool] = None
    datasets: Optional[
@@ -780,9 +780,9 @@ class AxolotlInputConfig(
    # torch_dtype: Optional[torch.dtype]
-    gradient_checkpointing: Optional[
-        Union[Literal["unsloth", "offload"], bool]
-    ] = Field(default=False)
+    gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = (
+        Field(default=False)
+    )
    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
    unfrozen_parameters: Optional[List[str]] = None
@@ -894,9 +894,9 @@ class AxolotlInputConfig(
    kto_undesirable_weight: Optional[float] = None
    rl_beta: Optional[float] = None
-    max_memory: Optional[
-        Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
-    ] = None
+    max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = (
+        None
+    )
    gpu_memory_limit: Optional[Union[int, str]] = None
    low_cpu_mem_usage: Optional[bool] = None

View File

@@ -1,4 +1,5 @@
"""module for gpu capabilities""" """module for gpu capabilities"""
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

View File

@@ -1,6 +1,7 @@
""" """
Data processing modules Data processing modules
""" """
from axolotl.utils.data.pretraining import ( # noqa: F401 from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining, encode_pretraining,
wrap_pretraining_dataset, wrap_pretraining_dataset,

View File

@@ -2,6 +2,7 @@
import functools
import logging
+import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
@@ -344,6 +345,7 @@ def load_tokenized_prepared_datasets(
            )
            ds_from_iter.save_to_disk(str(prepared_ds_path))
        else:
+            os.makedirs(prepared_ds_path, exist_ok=True)
            dataset.save_to_disk(str(prepared_ds_path))
        if cfg.push_dataset_to_hub:
            LOG.info(
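The new os.makedirs call (paired with the import os hunk above) makes creation of the prepared-dataset directory explicit and idempotent before the tokenized dataset is written. A minimal round-trip of the same two calls (the path and contents are made up for illustration):

import os

from datasets import Dataset, load_from_disk

prepared_ds_path = "last_run_prepared/example"  # illustrative cache location
dataset = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]], "labels": [[1, 2, 3], [4, 5]]})

os.makedirs(prepared_ds_path, exist_ok=True)  # no-op if the directory already exists
dataset.save_to_disk(str(prepared_ds_path))

reloaded = load_from_disk(prepared_ds_path)
print(len(reloaded))  # 2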

View File

@@ -1,6 +1,7 @@
""" """
utility helpers for distributed checks utility helpers for distributed checks
""" """
import os import os
import pickle # nosec import pickle # nosec
from contextlib import contextmanager from contextlib import contextmanager

View File

@@ -1,10 +1,13 @@
""" """
utils to get GPU info for the current environment utils to get GPU info for the current environment
""" """
from accelerate.utils.environment import ( from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support, check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
) )
from accelerate.utils.environment import get_gpu_info from accelerate.utils.environment import (
get_gpu_info,
)
def check_cuda_p2p_ib_support(): def check_cuda_p2p_ib_support():

Some files were not shown because too many files have changed in this diff