Compare commits

...

8 Commits

Author SHA1 Message Date
Dan Saunders
156fede4f7 Update .pre-commit-config.yaml
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-03-21 10:36:18 -04:00
Dan Saunders
dcbbd7af79 sorry to revert, but pylint complained 2025-03-21 10:36:18 -04:00
Dan Saunders
21bac7ce1a running updated pre-commit plugins 2025-03-21 10:36:18 -04:00
Dan Saunders
aaa4571826 adding pre-commit auto-update GH action and bumping plugin versions 2025-03-21 10:36:17 -04:00
salman
187227d837 Fixing KTO+QLoRA+multi-GPU (#2420)
* WIP

* removing artifacts

* adding error

* adding adapter check

* linting

* simplifying check

* linting v2

* config fix -___-
2025-03-21 10:18:28 -04:00
NanoCode012
f8de8bb4f2 chore(doc): add instructions on adding custom integrations (#2422) [skip ci]
* chore(doc): add instructions on adding custom integrations

* chore: add warning help

* feat: add note about integration path

* fix: adjust text per suggestion
2025-03-21 10:18:01 -04:00
hugo
8e604848a4 add run on novita ai (#2421) [skip ci]
* add run on novita ai

* Revert "add run on novita ai"

This reverts commit 4d5df1ac6b.

* add run axolotl on novita ai
2025-03-21 10:17:47 -04:00
Wing Lian
aae4337f40 add 12.8.1 cuda to the base matrix (#2426)
* add 12.8.1 cuda to the base matrix

* use nightly

* bump deepspeed and set no binary

* deepspeed binary fixes hopefully

* install deepspeed by itself

* multiline fix

* make sure ninja is installed

* try with reversion of packaging/setuptools/wheel install

* use license instead of license-file

* try rolling back packaging and setuptools versions

* comment out license for validation for now

* make sure packaging version is consistent

* more parity across tests and docker images for packaging/setuptools
2025-03-21 10:17:25 -04:00
147 changed files with 615 additions and 326 deletions

View File

@@ -40,6 +40,12 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -61,7 +67,7 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/Dockerfile-base
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -0,0 +1,49 @@
name: Pre-commit auto-update
on:
schedule:
- cron: '0 0 * * 0' # Run weekly
workflow_dispatch: # Manual kickoff
jobs:
auto-update:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Update pre-commit hooks
id: update
run: |
pip install pre-commit
pre-commit autoupdate
if [[ -n $(git status --porcelain) ]]; then
echo "changes=true" >> $GITHUB_OUTPUT
git diff .pre-commit-config.yaml > pre-commit-update.diff
fi
- name: Create Pull Request
if: steps.update.outputs.changes == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
branch: update/pre-commit-hooks
delete-branch: true
title: "chore: update pre-commit hooks"
commit-message: "chore: update pre-commit hooks"
body: |
Automated PR to update pre-commit hooks to their latest versions.
<details>
<summary>Changes:</summary>
```diff
${{ steps.update.outputs.diff }}
```
</details>

View File

@@ -40,7 +40,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install wheel packaging
pip3 install wheel packaging==23.2
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -42,7 +42,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -59,7 +59,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install --upgrade packaging==23.2
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh

View File

@@ -74,7 +74,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -147,7 +147,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
- name: Install PyTorch
run: |

View File

@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
@@ -11,23 +11,23 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 25.1.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 6.0.1
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
rev: 7.1.2
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: c8c96d20cde3552a79858c7456bb1483bf83d633
rev: v3.3.6
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
rev: v1.15.0
hooks:
- id: mypy
additional_dependencies:
@@ -36,7 +36,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.5
rev: 1.8.3
hooks:
- id: bandit
args: [

View File

@@ -55,7 +55,7 @@ Features:
### Installation
```bash
pip3 install -U packaging setuptools wheel ninja
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs

View File

@@ -31,7 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip3 install -U packaging setuptools wheel
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -1,6 +1,7 @@
"""
modal application to run axolotl gpu tests in Modal
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os

View File

@@ -1,4 +1,5 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os

View File

@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging setuptools wheel && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

View File

@@ -0,0 +1,39 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -85,6 +85,12 @@ gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
lora_on_cpu: true
# List[str]. Add plugins to extend the pipeline.
# See `src/axolotl/integrations` for the available plugins or doc below for more details.
# https://axolotl-ai-cloud.github.io/axolotl/docs/custom_integrations.html
plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# A list of one or more datasets to finetune the model with
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files

View File

@@ -55,3 +55,47 @@ sections = [
for section_name, folder_name in sections:
print(print_section(section_name, folder_name))
```
## Adding a new integration
Plugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.
To add a new integration, please follow these steps:
1. Create a new folder in the `src/axolotl/integrations` directory.
2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.
3. Add `__init__.py` and `args.py` files to the new folder.
- `__init__.py` should import the integration and hook into the appropriate functions.
- `args.py` should define the arguments for the integration.
4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.
::: {.callout-tip}
See [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.
:::
::: {.callout-warning}
If you could not load your integration, please ensure you are pip installing in editable mode.
```bash
pip install -e .
```
and correctly spelled the integration name in the config file.
```yaml
plugins:
- axolotl.integrations.your_integration_name.YourIntegrationPlugin
```
:::
::: {.callout-note}
It is not necessary to place your integration in the `integrations` folder. It can be in any location, so long as it's installed in a package in your python env.
See this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)
:::

View File

@@ -79,6 +79,7 @@ For providers supporting Docker:
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Novita](https://novita.ai/gpus-console?templateId=311)
### Google Colab {#sec-colab}

View File

@@ -55,7 +55,7 @@ tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging>=24.2"]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
build-backend = "setuptools.build_meta"
[project]
@@ -8,7 +8,7 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10"
license-files = ["LICENSE"]
# license = "Apache-2.0"
[project.scripts]
axolotl = "axolotl.cli.main:main"

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.2
bitsandbytes==0.45.3
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.4.post1
@@ -10,14 +10,14 @@ autoawq==0.2.7.post3
liger-kernel==0.5.3
# END section
packaging==24.2
packaging==23.2
peft==0.15.0
transformers==4.49.0
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.4.1
deepspeed==0.16.1
deepspeed==0.16.4
trl==0.15.1
optimum==1.16.2

View File

@@ -1,6 +1,7 @@
"""
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset

View File

@@ -1,4 +1,5 @@
"""Script to output the correct installation command for cut-cross-entropy."""
import importlib.util
import sys
@@ -17,12 +18,12 @@ if v < V("2.4.0"):
cce_spec = importlib.util.find_spec("cut_cross_entropy")
uninstall_prefix = ""
UNINSTALL_PREFIX = ""
if cce_spec:
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
uninstall_prefix = "pip uninstall -y cut-cross-entropy && "
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
print(
uninstall_prefix
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
)

View File

@@ -128,7 +128,7 @@ setup(
"flash-attn==2.7.4.post1",
],
"deepspeed": [
"deepspeed==0.16.1",
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -1,6 +1,7 @@
"""
launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union

View File

@@ -1,6 +1,7 @@
"""
base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod

View File

@@ -1,6 +1,7 @@
"""
Modal Cloud support from CLI
"""
import copy
import json
import os

View File

@@ -1,4 +1,5 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import logging

View File

@@ -5,7 +5,6 @@ import dataclasses
import hashlib
import json
import logging
import typing
from functools import wraps
from pathlib import Path
from types import NoneType
@@ -24,7 +23,7 @@ configure_logging()
LOG = logging.getLogger(__name__)
def strip_optional_type(field_type: type | typing._SpecialForm | None):
def strip_optional_type(field_type: type | str | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.

View File

@@ -1,6 +1,5 @@
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
import json
import sys

View File

@@ -1,6 +1,7 @@
"""
ChatML transformation functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
"""
Llama 3.x chat formatting functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
"""
shared functions for format transforms
"""
from axolotl.core.chat.messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
"""
internal message representations of chat messages
"""
import json
from enum import Enum
from typing import Any, Callable, List, Optional, Union

View File

@@ -1,6 +1,7 @@
"""
chat dataset module
"""
import os
from typing import Callable, Optional, Union

View File

@@ -1,6 +1,7 @@
"""
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
"""
from typing import Any, Mapping, Union

View File

@@ -332,9 +332,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs = {}
if self.cfg.include_tokens_per_second is not None:
training_arguments_kwargs[
"include_tokens_per_second"
] = self.cfg.include_tokens_per_second
training_arguments_kwargs["include_tokens_per_second"] = (
self.cfg.include_tokens_per_second
)
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
@@ -351,13 +351,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_arguments_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
training_arguments_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_arguments_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
@@ -373,9 +373,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs[
"lr_quadratic_warmup"
] = self.cfg.lr_quadratic_warmup
training_arguments_kwargs["lr_quadratic_warmup"] = (
self.cfg.lr_quadratic_warmup
)
if self.cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
@@ -399,28 +399,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
training_arguments_kwargs["dataloader_pin_memory"] = (
self.cfg.dataloader_pin_memory
)
if self.cfg.dataloader_num_workers is not None:
training_arguments_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
training_arguments_kwargs["dataloader_num_workers"] = (
self.cfg.dataloader_num_workers
)
if self.cfg.dataloader_prefetch_factor is not None:
training_arguments_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor
training_arguments_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.dataloader_drop_last is not None:
training_arguments_kwargs[
"dataloader_drop_last"
] = self.cfg.dataloader_drop_last
training_arguments_kwargs["dataloader_drop_last"] = (
self.cfg.dataloader_drop_last
)
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True
if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs[
"remove_unused_columns"
] = self.cfg.remove_unused_columns
training_arguments_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval
@@ -452,9 +452,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
] = self.cfg.metric_for_best_model
training_arguments_kwargs["metric_for_best_model"] = (
self.cfg.metric_for_best_model
)
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
@@ -467,13 +467,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
training_arguments_kwargs["torch_compile_backend"] = (
self.cfg.torch_compile_backend
)
if self.cfg.torch_compile_mode:
training_arguments_kwargs[
"torch_compile_mode"
] = self.cfg.torch_compile_mode
training_arguments_kwargs["torch_compile_mode"] = (
self.cfg.torch_compile_mode
)
# DDP Config
if self.cfg.ddp_timeout:
@@ -482,32 +482,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs[
"ddp_broadcast_buffers"
] = self.cfg.ddp_broadcast_buffers
training_arguments_kwargs["ddp_broadcast_buffers"] = (
self.cfg.ddp_broadcast_buffers
)
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs[
"per_device_train_batch_size"
] = self.cfg.micro_batch_size
training_arguments_kwargs["per_device_train_batch_size"] = (
self.cfg.micro_batch_size
)
if self.cfg.eval_batch_size:
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_arguments_kwargs["per_device_eval_batch_size"] = (
self.cfg.eval_batch_size
)
if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs[
"auto_find_batch_size"
] = self.cfg.auto_find_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs[
"eval_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs["auto_find_batch_size"] = (
self.cfg.auto_find_batch_size
)
training_arguments_kwargs["gradient_accumulation_steps"] = (
self.cfg.gradient_accumulation_steps
)
training_arguments_kwargs["eval_accumulation_steps"] = (
self.cfg.gradient_accumulation_steps
)
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
@@ -554,9 +554,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
"alternate_lr_scheduler_type"
] = self.cfg.lr_scheduler
training_arguments_kwargs["alternate_lr_scheduler_type"] = (
self.cfg.lr_scheduler
)
else:
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
@@ -565,9 +565,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs[
"cosine_constant_lr_ratio"
] = self.cfg.cosine_constant_lr_ratio
training_arguments_kwargs["cosine_constant_lr_ratio"] = (
self.cfg.cosine_constant_lr_ratio
)
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
@@ -580,40 +580,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.eval_sample_packing
)
if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs[
"sample_packing_bin_size"
] = self.cfg.sample_packing_bin_size
training_arguments_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size
)
if self.cfg.sample_packing_group_size is not None:
training_arguments_kwargs[
"sample_packing_group_size"
] = self.cfg.sample_packing_group_size
training_arguments_kwargs["sample_packing_group_size"] = (
self.cfg.sample_packing_group_size
)
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est
training_arguments_kwargs["sample_packing_efficiency"] = (
self.cfg.sample_packing_eff_est
)
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs[
"relora_warmup_steps"
] = self.cfg.relora_warmup_steps
training_arguments_kwargs["relora_warmup_steps"] = (
self.cfg.relora_warmup_steps
)
if self.cfg.relora_anneal_steps:
training_arguments_kwargs[
"relora_anneal_steps"
] = self.cfg.relora_anneal_steps
training_arguments_kwargs["relora_anneal_steps"] = (
self.cfg.relora_anneal_steps
)
if self.cfg.relora_prune_ratio:
training_arguments_kwargs[
"relora_prune_ratio"
] = self.cfg.relora_prune_ratio
training_arguments_kwargs["relora_prune_ratio"] = (
self.cfg.relora_prune_ratio
)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
training_arguments_kwargs[
"lisa_step_interval"
] = self.cfg.lisa_step_interval
training_arguments_kwargs[
"lisa_layers_attribute"
] = self.cfg.lisa_layers_attribute
training_arguments_kwargs["lisa_step_interval"] = (
self.cfg.lisa_step_interval
)
training_arguments_kwargs["lisa_layers_attribute"] = (
self.cfg.lisa_layers_attribute
)
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
@@ -627,9 +627,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
"neftune_noise_alpha"
] = self.cfg.neftune_noise_alpha
training_arguments_kwargs["neftune_noise_alpha"] = (
self.cfg.neftune_noise_alpha
)
trainer_kwargs = {}
@@ -731,23 +731,23 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
importlib.import_module("torchdistx")
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["optim_target_modules"] = (
self.cfg.optim_target_modules
)
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["loraplus_lr_embedding"] = (
self.cfg.loraplus_lr_embedding
)
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.accelerator_config:
training_arguments_kwargs[
"accelerator_config"
] = self.cfg.accelerator_config
training_arguments_kwargs["accelerator_config"] = (
self.cfg.accelerator_config
)
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
@@ -756,13 +756,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs[
"kd_zscore_base_temp"
] = self.cfg.kd_zscore_base_temp
training_arguments_kwargs["kd_zscore_base_temp"] = (
self.cfg.kd_zscore_base_temp
)
if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs[
"kd_top_k_before_softmax"
] = self.cfg.kd_top_k_before_softmax
training_arguments_kwargs["kd_top_k_before_softmax"] = (
self.cfg.kd_top_k_before_softmax
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -972,32 +972,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
if self.cfg.remove_unused_columns is not None:
training_args_kwargs[
"remove_unused_columns"
] = self.cfg.remove_unused_columns
training_args_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
else:
training_args_kwargs["remove_unused_columns"] = False
if self.cfg.dataloader_pin_memory is not None:
training_args_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
training_args_kwargs["dataloader_pin_memory"] = (
self.cfg.dataloader_pin_memory
)
if self.cfg.dataloader_num_workers is not None:
training_args_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
training_args_kwargs["dataloader_num_workers"] = (
self.cfg.dataloader_num_workers
)
if self.cfg.dataloader_prefetch_factor is not None:
training_args_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor
training_args_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.gradient_checkpointing:
training_args_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
@@ -1071,9 +1071,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs[
"use_logits_to_keep"
] = self.cfg.dpo_use_logits_to_keep
training_args_kwargs["use_logits_to_keep"] = (
self.cfg.dpo_use_logits_to_keep
)
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
@@ -1108,9 +1108,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.adapter and self.peft_config:
dpo_trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model]

View File

@@ -462,9 +462,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
@@ -509,9 +509,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler

View File

@@ -1,6 +1,7 @@
"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer

View File

@@ -1,6 +1,7 @@
"""
Axolotl specific DPO args
"""
from dataclasses import dataclass
from trl import DPOConfig

View File

@@ -1,6 +1,7 @@
"""
DPO trainer for axolotl
"""
import gc
from functools import wraps
from typing import Any, Dict, Union

View File

@@ -45,9 +45,9 @@ class GRPOStrategy:
)
if trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = trl.vllm_gpu_memory_utilization
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
trl.vllm_gpu_memory_utilization
)
if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
@@ -86,9 +86,9 @@ class GRPOStrategy:
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
] = cfg.trl.reward_processing_classes
trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes
)
return trainer_kwargs
@classmethod

View File

@@ -1,6 +1,7 @@
"""
Axolotl Specific Training Args
"""
from dataclasses import dataclass
from trl import GRPOConfig

View File

@@ -1,6 +1,7 @@
"""
Axolotl GRPO trainer
"""
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel

View File

@@ -1,6 +1,7 @@
"""
module for TRL PPO training
"""
import torch
from tqdm import tqdm
from trl import PPOTrainer

View File

@@ -1,6 +1,7 @@
"""
extra axolotl specific training args
"""
from dataclasses import dataclass, field
from typing import Optional

View File

@@ -1,6 +1,7 @@
"""
Grokfast plugin for Axolotl
"""
import logging
from transformers.trainer_callback import TrainerCallback

View File

@@ -1,6 +1,7 @@
"""
config args for grokfast plugin
"""
from typing import Optional
from pydantic import BaseModel

View File

@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
"""
kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[
float
] = None # loss coefficient for cross-entropy loss during KD
kd_ce_alpha: Optional[float] = (
None # loss coefficient for cross-entropy loss during KD
)
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[
bool
] = None # whether to sample top k before softmax during KD
kd_top_k_before_softmax: Optional[bool] = (
None # whether to sample top k before softmax during KD
)

View File

@@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin):
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs[
"fused_linear_cross_entropy"
] = cfg.liger_fused_linear_cross_entropy
kwargs["fused_linear_cross_entropy"] = (
cfg.liger_fused_linear_cross_entropy
)
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:

View File

@@ -1,6 +1,7 @@
"""
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import List, Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
"""
Jamba model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
"""
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from axolotl.integrations.base import BasePlugin

View File

@@ -1,6 +1,7 @@
"""
Module for handling lm eval harness input arguments.
"""
from typing import List, Optional
from pydantic import BaseModel

View File

@@ -1,6 +1,7 @@
"""
axolotl CLI for running lm_eval tasks
"""
import subprocess # nosec
from collections import defaultdict
from datetime import datetime

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
import torch

View File

@@ -6,6 +6,7 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name
from typing import Callable

View File

@@ -1,4 +1,5 @@
"""Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement
import ctypes

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
import torch
import triton
import triton.language as tl

View File

@@ -1,6 +1,7 @@
"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig

View File

@@ -1,6 +1,7 @@
"""
Monkeypatch for Vision Llama for FA2 support
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple
@@ -220,10 +221,10 @@ def patch_mllama():
True
)
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
"flash_attention_2"
] = MllamaTextCrossFlashAttention2
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = (
MllamaTextCrossFlashAttention2
)
# fallback to SDPA
MLLAMA_VISION_ATTENTION_CLASSES[
"flash_attention_2"
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = (
MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
)

View File

@@ -1,4 +1,5 @@
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
# pylint: disable=protected-access
import torch

View File

@@ -12,7 +12,9 @@ import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import (
LlamaAttention,
)
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
@@ -490,9 +492,11 @@ def flashattn_forward(
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
@@ -531,9 +535,11 @@ def flashattn_forward(
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)

View File

@@ -1,6 +1,7 @@
"""
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
"""
from typing import Optional
import torch

View File

@@ -1,4 +1,5 @@
"""Flash attention monkey patch for mistral model"""
# pylint: disable=duplicate-code
import logging
@@ -21,7 +22,10 @@ from transformers.models.mistral.modeling_mistral import (
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
from transformers.models.mistral.modeling_mistral import (
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
@@ -243,9 +247,11 @@ def flashattn_forward(
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
@@ -286,9 +292,11 @@ def flashattn_forward(
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)

View File

@@ -1,6 +1,7 @@
"""
Patches to support multipack for mixtral
"""
import torch

View File

@@ -1,4 +1,5 @@
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
import glob
import json
import logging
@@ -411,7 +412,10 @@ def merge_and_save(
if shard_path.endswith(".safetensors"):
in_tensors = st.load_file(str(Path(model_src) / shard_path))
else:
in_tensors = torch.load(Path(model_src) / shard_path)
in_tensors = torch.load(
Path(model_src) / shard_path,
weights_only=True, # to prevent arbitrary code execution
)
if "state_dict" in in_tensors:
in_tensors = in_tensors["state_dict"]

View File

@@ -17,7 +17,7 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# pylint: disable=duplicate-code
""" PyTorch StableLM Epoch model. """
"""PyTorch StableLM Epoch model."""
import importlib
import math
from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
"""
fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging

View File

@@ -1,6 +1,7 @@
"""
Shared utils for the monkeypatches
"""
import re
from typing import Optional, Tuple

View File

@@ -1,6 +1,7 @@
"""
Fused MLP layer for incrementally improved training efficiency
"""
import torch
from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU

View File

@@ -1,6 +1,7 @@
"""
Prompt strategies loader for alpaca instruction datasets with system prompts
"""
from typing import Generator, Tuple, Union
from axolotl.prompt_tokenizers import PromptTokenizingStrategy

View File

@@ -1,6 +1,7 @@
"""
Basic completion text
"""
from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple

View File

@@ -1,4 +1,5 @@
"""Module containing the classes for Context QA Prompt Tokenization Strategies"""
from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy

View File

@@ -1,6 +1,7 @@
"""
module for DPO style dataset transform strategies
"""
from functools import partial
from ..base import load as load_base

View File

@@ -33,9 +33,9 @@ def default(
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
return sample
@@ -52,9 +52,9 @@ def argilla_chat(
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample
@@ -78,9 +78,9 @@ def icr(
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample
@@ -100,9 +100,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample
@@ -120,9 +120,9 @@ def prompt_pairs(
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample
@@ -142,9 +142,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample

View File

@@ -34,9 +34,9 @@ def default(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
return sample
@@ -53,9 +53,9 @@ def argilla_chat(
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample
@@ -79,9 +79,9 @@ def icr(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
@@ -101,9 +101,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
@@ -121,9 +121,9 @@ def prompt_pairs(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
@@ -143,9 +143,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample

View File

@@ -1,4 +1,5 @@
"""Module for plain input/output prompt pairs"""
from typing import Generator, Tuple
from axolotl.prompt_tokenizers import PromptTokenizingStrategy

View File

@@ -1,4 +1,5 @@
"""Module for inspect jinja templates for the variables they use"""
from typing import Dict, Optional, Set, TypedDict, Union
from jinja2 import Environment, meta, nodes

View File

@@ -1,6 +1,7 @@
"""
KTO strategies for chatml
"""
# pylint: disable=duplicate-code
@@ -15,9 +16,9 @@ def argilla(
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample
@@ -33,9 +34,9 @@ def argilla_chat(
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
return sample
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample
@@ -74,9 +75,9 @@ def prompt_pairs(
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample

View File

@@ -1,6 +1,7 @@
"""
KTO strategies for llama-3 chat template
"""
# pylint: disable=duplicate-code
@@ -15,9 +16,9 @@ def argilla(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample
@@ -33,9 +34,9 @@ def argilla_chat(
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
return sample
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample
@@ -74,9 +75,9 @@ def prompt_pairs(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample

View File

@@ -1,6 +1,7 @@
"""
User-defined KTO strategies
"""
# pylint: disable=duplicate-code

View File

@@ -1,6 +1,7 @@
"""
Chat dataset wrapping strategy for new internal messages representations
"""
from typing import Any, Callable, Dict, Optional
from axolotl.core.datasets.chat import TokenizedChatDataset

View File

@@ -9,6 +9,7 @@ this one specifies the system prompt with "### System:".
Not suited/tested for multiple-turn conversations without further adjustments.
"""
from typing import Generator, Union
from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy

View File

@@ -1,4 +1,5 @@
"""chatml prompt tokenization strategy for ORPO"""
from typing import Any, Dict, Generator, List, Optional, Tuple
from pydantic import BaseModel

View File

@@ -1,4 +1,5 @@
"""pretraining prompt strategies"""
from typing import Generator
from transformers import BatchEncoding

View File

@@ -406,9 +406,7 @@ def handle_untrained_tokens_fix(
)
def setup_model_and_trainer(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
HFRLTrainerBuilder | HFCausalTrainerBuilder,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,

View File

@@ -40,6 +40,6 @@ def set_pytorch_cuda_alloc_conf():
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
"expandable_segments:True,roundup_power2_divisions:16"
)

View File

@@ -1,4 +1,5 @@
"""Benchmarking and measurement utilities"""
import functools
import torch

View File

@@ -343,9 +343,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
bench_refs.extend(combined_bench_names[bench_name]["refs"])
bench_preds.extend(combined_bench_names[bench_name]["preds"])
if not pd.isna(bench_score):
results[
f"{bench_split}_bench_accuracy_{bench_name}"
] = bench_score
results[f"{bench_split}_bench_accuracy_{bench_name}"] = (
bench_score
)
bench_scores.append(bench_score)
else:
results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0

View File

@@ -1,4 +1,5 @@
"""MLFlow module for trainer callbacks"""
import logging
from shutil import copyfile
from tempfile import NamedTemporaryFile

View File

@@ -1,4 +1,5 @@
"""callback to calculate perplexity as an evaluation metric."""
from typing import Dict, List, Optional
import torch

View File

@@ -1,6 +1,7 @@
"""
HF Trainer callback for creating pytorch profiling snapshots
"""
from pathlib import Path
from pickle import dump # nosec B403

View File

@@ -2,6 +2,7 @@
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional

View File

@@ -1,6 +1,7 @@
"""
shared axolotl collators for multipack, mamba, multimodal
"""
from .batching import ( # noqa: F401
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,

View File

@@ -1,4 +1,5 @@
"""
basic shared collator constants
"""
IGNORE_INDEX = -100

View File

@@ -1,6 +1,7 @@
"""
collators for Mamba
"""
from dataclasses import dataclass
from typing import Dict, Sequence

View File

@@ -18,7 +18,11 @@ from axolotl.utils.config.models.input.v0_4_1 import (
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset
from axolotl.utils.config.models.input.v0_4_1 import (
DPODataset,
KTODataset,
SFTDataset,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config

View File

@@ -200,12 +200,12 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None
field_model: Optional[str] = None
field_messages: Optional[str] = None
message_field_role: Optional[
str
] = None # deprecated, use message_property_mappings
message_field_content: Optional[
str
] = None # deprecated, use message_property_mappings
message_field_role: Optional[str] = (
None # deprecated, use message_property_mappings
)
message_field_content: Optional[str] = (
None # deprecated, use message_property_mappings
)
message_property_mappings: Optional[Dict[str, str]] = None
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
@@ -505,9 +505,9 @@ class HyperparametersConfig(BaseModel):
embedding_lr: Optional[float] = None
embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[OptimizerNames, CustomSupportedOptimizers]
] = OptimizerNames.ADAMW_TORCH_FUSED
optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = (
OptimizerNames.ADAMW_TORCH_FUSED
)
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
@@ -699,9 +699,9 @@ class AxolotlInputConfig(
reward_model: Optional[bool] = None
process_reward_model: Optional[bool] = None
num_labels: Optional[int] = None
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
dpo_use_weighting: Optional[bool] = (
None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
)
dpo_use_logits_to_keep: Optional[bool] = None
datasets: Optional[
@@ -780,9 +780,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[
Union[Literal["unsloth", "offload"], bool]
] = Field(default=False)
gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = (
Field(default=False)
)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
@@ -894,9 +894,9 @@ class AxolotlInputConfig(
kto_undesirable_weight: Optional[float] = None
rl_beta: Optional[float] = None
max_memory: Optional[
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
] = None
max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = (
None
)
gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None
@@ -1679,6 +1679,30 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_rl_config_gradient_checkpointing(cls, data):
# TODO: SalmanMohammadi
# Distributed RL with QLoRA + gradient checkpointing
# and use_reentrant = True is broken upstream in TRL
# pylint: disable=too-many-boolean-expressions
if (
data.get("rl")
and data.get("gradient_checkpointing")
and data.get("gradient_checkpointing_kwargs")
and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
and data.get("load_in_4bit")
and data.get("adapter") == "qlora"
and data.get("capabilities")
and data.get("capabilities").get("n_gpu", 1) > 1
):
raise ValueError(
"The `use_reentrant: True` implementation of gradient checkpointing "
"is not supported for distributed RL training with QLoRA. Please set "
"`use_reentrant: False` in `gradient_checkpointing_kwargs`."
)
return data
@model_validator(mode="before")
@classmethod
def check_kto_config(cls, data):
@@ -1689,15 +1713,6 @@ class AxolotlInputConfig(
if data.get("remove_unused_columns") is not False:
raise ValueError("Set `remove_unused_columns: False` when using kto")
if data.get("gradient_checkpointing") and not (
data.get("gradient_checkpointing_kwargs")
and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
and data["gradient_checkpointing_kwargs"].get("use_reentrant")
):
raise ValueError(
"Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
)
return data

View File

@@ -1,4 +1,5 @@
"""module for gpu capabilities"""
from typing import Optional
from pydantic import BaseModel, Field

View File

@@ -1,6 +1,7 @@
"""
Data processing modules
"""
from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining,
wrap_pretraining_dataset,

View File

@@ -1,6 +1,7 @@
"""
utility helpers for distributed checks
"""
import os
import pickle # nosec
from contextlib import contextmanager

View File

@@ -1,10 +1,13 @@
"""
utils to get GPU info for the current environment
"""
from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
)
from accelerate.utils.environment import get_gpu_info
from accelerate.utils.environment import (
get_gpu_info,
)
def check_cuda_p2p_ib_support():

Some files were not shown because too many files have changed in this diff Show More