Compare commits

1 commit: pre-commit...kto_fix

| Author | SHA1 | Date |
|---|---|---|
| | 92c217677c | |
.github/workflows/base.yml (8 changes, vendored)

@@ -40,12 +40,6 @@ jobs:
             python_version: "3.11"
             pytorch: 2.6.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-          - cuda: "128"
-            cuda_version: 12.8.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: nightly
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -67,7 +61,7 @@ jobs:
         uses: docker/build-push-action@v4
         with:
           context: .
-          file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
+          file: ./docker/Dockerfile-base
           push: ${{ github.event_name != 'pull_request' }}
           tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
           labels: ${{ steps.metadata.outputs.labels }}
.github/workflows/precommit-autoupdate.yml (49 changes, vendored, file deleted)

@@ -1,49 +0,0 @@
-name: Pre-commit auto-update
-
-on:
-  schedule:
-    - cron: '0 0 * * 0' # Run weekly
-  workflow_dispatch: # Manual kickoff
-
-jobs:
-  auto-update:
-    runs-on: ubuntu-latest
-    permissions:
-      contents: write
-      pull-requests: write
-    steps:
-      - uses: actions/checkout@v4
-
-      - uses: actions/setup-python@v5
-        with:
-          python-version: '3.11'
-
-      - name: Update pre-commit hooks
-        id: update
-        run: |
-          pip install pre-commit
-          pre-commit autoupdate
-          if [[ -n $(git status --porcelain) ]]; then
-            echo "changes=true" >> $GITHUB_OUTPUT
-            git diff .pre-commit-config.yaml > pre-commit-update.diff
-          fi
-
-      - name: Create Pull Request
-        if: steps.update.outputs.changes == 'true'
-        uses: peter-evans/create-pull-request@v6
-        with:
-          token: ${{ secrets.GITHUB_TOKEN }}
-          branch: update/pre-commit-hooks
-          delete-branch: true
-          title: "chore: update pre-commit hooks"
-          commit-message: "chore: update pre-commit hooks"
-          body: |
-            Automated PR to update pre-commit hooks to their latest versions.
-
-            <details>
-            <summary>Changes:</summary>
-
-            ```diff
-            ${{ steps.update.outputs.diff }}
-            ```
-
-            </details>
.github/workflows/pypi.yml (2 changes, vendored)

@@ -40,7 +40,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip3 install wheel packaging==23.2
+          pip3 install wheel packaging
           pip3 install --no-build-isolation -e .
           pip3 install -r requirements-dev.txt -r requirements-tests.txt
 
.github/workflows/tests-nightly.yml (4 changes, vendored)

@@ -42,7 +42,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
+          pip3 install --upgrade packaging setuptools wheel
 
       - name: Install PyTorch
         run: |
@@ -59,7 +59,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging==23.2
+          pip3 install --upgrade packaging
           pip3 install --no-build-isolation -U -e .
           python scripts/unsloth_install.py | sh
           python scripts/cutcrossentropy_install.py | sh
.github/workflows/tests.yml (4 changes, vendored)

@@ -74,7 +74,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
+          pip3 install --upgrade packaging setuptools wheel
 
       - name: Install PyTorch
         run: |
@@ -147,7 +147,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
+          pip3 install --upgrade packaging setuptools setuptools_scm build wheel
 
       - name: Install PyTorch
         run: |
.pre-commit-config.yaml

@@ -3,7 +3,7 @@ default_language_version:
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0
+    rev: v4.4.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
@@ -11,23 +11,23 @@ repos:
      - id: no-commit-to-branch
        args: ['--branch', 'main']
  - repo: https://github.com/psf/black
-    rev: 25.1.0
+    rev: 23.3.0
    hooks:
      - id: black
  - repo: https://github.com/pycqa/isort
-    rev: 6.0.1
+    rev: 5.12.0
    hooks:
      - id: isort
  - repo: https://github.com/PyCQA/flake8
-    rev: 7.1.2
+    rev: 6.1.0
    hooks:
      - id: flake8
-  - repo: https://github.com/pylint-dev/pylint
-    rev: v3.3.6
+  - repo: https://github.com/PyCQA/pylint
+    rev: v3.3.0
    hooks:
      - id: pylint
  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.15.0
+    rev: v1.3.0
    hooks:
      - id: mypy
        additional_dependencies:
@@ -36,7 +36,7 @@ repos:
          'pydantic>=2.5.3',
        ]
  - repo: https://github.com/PyCQA/bandit
-    rev: 1.8.3
+    rev: 1.7.5
    hooks:
      - id: bandit
        args: [
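Note on the hook pins above: black moves from 25.1.0 back to 23.3.0, and that downgrade is what produces the large, purely mechanical reformatting in the `HFCausalTrainerBuilder` hunks near the end of this diff. Black 24 and later prefer parenthesizing a long right-hand side, while 23.x splits the subscript key instead. A minimal sketch of the two equivalent styles (the `cfg` stand-in is illustrative, not taken from the diff):

```python
from types import SimpleNamespace

cfg = SimpleNamespace(gradient_checkpointing=True)  # illustrative stand-in
training_arguments_kwargs = {}

# Style emitted by black >= 24 (base branch, pinned 25.1.0):
training_arguments_kwargs["gradient_checkpointing"] = (
    cfg.gradient_checkpointing
)

# Style emitted by black 23.x (head branch, pinned 23.3.0):
training_arguments_kwargs[
    "gradient_checkpointing"
] = cfg.gradient_checkpointing
```

Both statements are identical at runtime; only the line wrapping differs.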
README.md

@@ -55,7 +55,6 @@ Features:
 ### Installation
 
 ```bash
-pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
 
 # Download example axolotl configs, deepspeed configs
_quarto.yml

@@ -32,9 +32,8 @@ website:
     contents:
       - docs/getting-started.qmd
       - docs/installation.qmd
-      - docs/inference.qmd
       - docs/cli.qmd
-      - docs/config.qmd
+      - docs/inference.qmd
 
   - section: "Dataset Formats"
     contents: docs/dataset-formats/*
@@ -75,6 +74,10 @@ website:
       - docs/debugging.qmd
       - docs/nccl.qmd
 
+  - section: "Reference"
+    contents:
+      - docs/config.qmd
+
 format:
   html:
     theme: darkly
docker/Dockerfile

@@ -31,7 +31,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
         sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
     fi
 
-RUN pip install packaging==23.2 setuptools==75.8.0
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
(file name not shown)

@@ -1,7 +1,6 @@
 """
 modal application to run axolotl gpu tests in Modal
 """
-
 # pylint: disable=duplicate-code
 
 import os
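The lone blank line removed after this module docstring recurs across the docstring-only hunks below, and it follows from the same black downgrade: black 24 and later normalize to exactly one blank line after a module docstring, while 23.x keeps whatever spacing the file originally had (one hunk further down restores a second blank line for the same reason). A sketch of the normalized form:

```python
"""Example module docstring (illustrative, not a file from this diff)."""

# black >= 24 keeps exactly one blank line between the docstring above and
# the first statement; black 23.x leaves the original spacing untouched.
import os

print(os.name)
```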
(file name not shown)

@@ -1,5 +1,4 @@
 """Modal app to run axolotl GPU tests"""
-
 # pylint: disable=duplicate-code
 
 import os
docker/Dockerfile-base

@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 
 WORKDIR /workspace
 
-RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
+RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
     python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
     python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
     python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
docker/Dockerfile-base-nightly (file deleted)

@@ -1,39 +0,0 @@
-ARG CUDA_VERSION="12.8.1"
-ARG CUDNN_VERSION="8"
-ARG UBUNTU_VERSION="22.04"
-ARG MAX_JOBS=4
-
-FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
-
-ENV PATH="/root/miniconda3/bin:${PATH}"
-
-ARG PYTHON_VERSION="3.11"
-ARG PYTORCH_VERSION="nightly"
-ARG CUDA="128"
-ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
-
-ENV PYTHON_VERSION=$PYTHON_VERSION
-ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
-
-RUN apt-get update \
-    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
-    && wget \
-        https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
-    && mkdir /root/.conda \
-    && bash Miniconda3-latest-Linux-x86_64.sh -b \
-    && rm -f Miniconda3-latest-Linux-x86_64.sh \
-    && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
-
-ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
-
-WORKDIR /workspace
-
-RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
-    python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
-    python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
-    python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
-
-RUN git lfs install --skip-repo && \
-    pip3 install awscli && \
-    # The base image ships with `pydantic==1.8.2` which is not working
-    pip3 install -U --no-cache-dir pydantic==1.10.10
docs/config.qmd

@@ -1,5 +1,5 @@
 ---
-title: Config Reference
+title: Config options
 description: A complete list of all configuration options.
 ---
 
@@ -30,8 +30,6 @@ tokenizer_legacy:
 # Resize the model embeddings when new tokens are added to multiples of 32
 # This is reported to improve training speed on some models
 resize_token_embeddings_to_32x:
-# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
-shrink_embeddings:
 
 # (Internal use only)
 # Used to identify which the model is based on
@@ -85,12 +83,6 @@ gpu_memory_limit: 20GiB
 # Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
 lora_on_cpu: true
 
-# List[str]. Add plugins to extend the pipeline.
-# See `src/axolotl/integrations` for the available plugins or doc below for more details.
-# https://axolotl-ai-cloud.github.io/axolotl/docs/custom_integrations.html
-plugins:
-# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
-
 # A list of one or more datasets to finetune the model with
 datasets:
   # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
@@ -213,46 +205,10 @@ test_datasets:
     data_files:
       - /workspace/data/eval.jsonl
 
-# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
+# use RL training: 'dpo', 'ipo', 'kto'
 rl:
-rl_beta: # Optional[float]. The beta parameter for the RL training.
-
-# dpo
-dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
-rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
-
-# orpo
-orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
-
-# kto
-kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
-kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
-
-# simpo
-cpo_alpha: 1.0 # Weight of the BC regularizer
-simpo_gamma: 0.5 # Target reward margin for the SimPO loss
-
-# grpo
-trl:
-  use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
-  vllm_device: # Optional[str]. Device to use for VLLM.
-  vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
-  vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
-  vllm_dtype: # Optional[str]. Data type for VLLM.
-
-  beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
-  max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
-
-  reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
-  reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
-
-  num_generations: # Optional[int]. Number of generations to sample.
-  log_completions: # Optional[bool]. Whether to log completions.
-
-  sync_ref_model: # Optional[bool]. Whether to sync the reference model.
-  ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
-  ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
+# whether to perform weighting if doing DPO training. Boolean.
+dpo_use_weighting:
 
 # reward modelling: `True` or `False`
 reward_model:
@@ -276,7 +232,7 @@ default_system_message: You are a helpful assistant. Please give a long and deta
 # subsequent training attempts load faster, relative path
 dataset_prepared_path: data/last_run_prepared
 # Push prepared dataset to hub
-push_dataset_to_hub: # Optional[str] repo_org/repo_name
+push_dataset_to_hub: # repo path
 # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
 # if not set.
 dataset_processes: # defaults to os.cpu_count() if not set
docs/custom_integrations.qmd

@@ -55,47 +55,3 @@ sections = [
 for section_name, folder_name in sections:
     print(print_section(section_name, folder_name))
 ```
-
-## Adding a new integration
-
-Plugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.
-
-To add a new integration, please follow these steps:
-
-1. Create a new folder in the `src/axolotl/integrations` directory.
-2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.
-3. Add `__init__.py` and `args.py` files to the new folder.
-    - `__init__.py` should import the integration and hook into the appropriate functions.
-    - `args.py` should define the arguments for the integration.
-4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.
-
-::: {.callout-tip}
-
-See [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.
-
-:::
-
-::: {.callout-warning}
-
-If you could not load your integration, please ensure you are pip installing in editable mode.
-
-```bash
-pip install -e .
-```
-
-and correctly spelled the integration name in the config file.
-
-```yaml
-plugins:
-  - axolotl.integrations.your_integration_name.YourIntegrationPlugin
-```
-
-:::
-
-::: {.callout-note}
-
-It is not necessary to place your integration in the `integrations` folder. It can be in any location, so long as it's installed in a package in your python env.
-
-See this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)
-
-:::
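Since the section deleted above was the only walkthrough of the plugin layout, a compressed sketch of what it described: a plugin is a class, importable from the training environment, whose dotted path is listed under `plugins:` in the config. The hook shown here is an assumption based on the deleted step 3, not a verified `BasePlugin` method:

```python
"""my_integration/__init__.py, a hypothetical minimal plugin layout."""

from axolotl.integrations.base import BasePlugin


class MyIntegrationPlugin(BasePlugin):
    """Loaded via `plugins: [my_integration.MyIntegrationPlugin]` in the config."""

    def get_input_args(self):
        # Dotted path to the pydantic args model defined in args.py
        # (hook name assumed from the deleted instructions above).
        return "my_integration.args.MyIntegrationArgs"
```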
docs/faq.qmd (10 changes)

@@ -27,16 +27,6 @@ description: Frequently asked questions
 
 > A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
 
-**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
-
-> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
-
-> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
-
-**Q: How to call Axolotl via custom python scripts?**
-
-> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
-
 ### Chat templates
 
 **Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
docs/getting-started.qmd

@@ -36,9 +36,7 @@ The YAML configuration file controls everything about your training. Here's what
 
 ```yaml
 base_model: NousResearch/Llama-3.2-1B
+# hub_model_id: username/custom_model_name
 
-load_in_8bit: true
-adapter: lora
-
 datasets:
   - path: teknium/GPT4-LLM-Cleaned
@@ -46,15 +44,11 @@ datasets:
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.1
 output_dir: ./outputs/lora-out
+
+adapter: lora
+lora_model_dir:
 ```
 
-::: {.callout-tip}
-`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
-
-- To perform Full finetuning, remove these two lines.
-- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
-:::
-
 See our [Config options](config.qmd) for more details.
 
 ### Training {#sec-training}
@@ -62,7 +56,7 @@ See our [Config options](config.qmd) for more details.
 When you run `axolotl train`, Axolotl:
 
 1. Downloads the base model
-2. (If specified) applies QLoRA/LoRA adapter layers
+2. (If specified) applies LoRA adapter layers
 3. Loads and processes the dataset
 4. Runs the training loop
 5. Saves the trained model and / or LoRA weights
@@ -75,8 +69,6 @@ Let's modify the example for your own data:
 
 ```yaml
 base_model: NousResearch/Nous-Hermes-llama-1b-v1
-
-load_in_8bit: true
 adapter: lora
 
 # Training settings
@@ -112,6 +104,8 @@ format):
 {"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
 ```
 
+Please consult the supported [Dataset Formats](dataset-formats/) for more details.
+
 3. Run the training:
 
 ```bash
docs/inference.qmd

@@ -1,5 +1,5 @@
 ---
-title: "Inference and Merging"
+title: "Inference"
 format:
   html:
     toc: true
@@ -9,14 +9,10 @@ execute:
   enabled: false
 ---
 
-This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.
+This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
 
 ## Quick Start {#sec-quickstart}
 
-::: {.callout-tip}
-Use the same config used for training on inference/merging.
-:::
-
 ### Basic Inference {#sec-basic}
 
 ::: {.panel-tabset}
docs/installation.qmd

@@ -22,7 +22,6 @@ This guide covers all the ways you can install and set up Axolotl for your envir
 ### PyPI Installation (Recommended) {#sec-pypi}
 
 ```{.bash}
-pip3 install -U packaging setuptools wheel ninja
 pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
 ```
 
@@ -38,7 +37,7 @@ For the latest features between releases:
 ```{.bash}
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
-pip3 install -U packaging setuptools wheel ninja
+pip3 install packaging ninja
 pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
 ```
 
@@ -79,7 +78,6 @@ For providers supporting Docker:
 - [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
 - [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
 - [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
-- [Novita](https://novita.ai/gpus-console?templateId=311)
 
 ### Google Colab {#sec-colab}
 
@@ -109,7 +107,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
 2. Install PyTorch: https://pytorch.org/get-started/locally/
 3. Install Axolotl:
 ```{.bash}
-pip3 install -U packaging setuptools wheel ninja
+pip3 install packaging
 pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
 ```
 4. (Optional) Login to Hugging Face:
docs/lora_optims.qmd

@@ -66,10 +66,6 @@ logic to be compatible with more of them.
 
 </details>
 
-::: {.callout-tip}
-Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
-:::
-
 ## Usage
 
 These optimizations can be enabled in your Axolotl config YAML file. The
docs/reward_modelling.qmd

@@ -41,10 +41,6 @@ Bradley-Terry chat templates expect single-turn conversations in the following f
 
 ### Process Reward Models (PRM)
 
-::: {.callout-tip}
-Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
-:::
-
 Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
 ```yaml
 base_model: Qwen/Qwen2.5-3B
docs/rlhf.qmd

@@ -298,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab
 
 ### IPO
 
-As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
+As IPO is just DPO with a different loss function, all supported options for DPO works here.
 
 ```yaml
 rl: ipo
@@ -344,9 +344,8 @@ ORPO supports the following types with the following dataset format:
 
 ```yaml
 rl: kto
-rl_beta: 0.1 # default
-kto_desirable_weight: 1.0 # default
-kto_undesirable_weight: 1.0 # default
+rl_beta: 0.5
+kto_desirable_weight: 0.2
 
 remove_unused_columns: false
 
@@ -498,10 +497,6 @@ The input format is a simple JSON input with customizable fields based on the ab
 
 ### GRPO
 
-::: {.callout-tip}
-Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
-:::
-
 GRPO uses custom reward functions and transformations. Please have them ready locally.
 
 For ex, to load OpenAI's GSM8K and use a random reward for completions:
@@ -545,19 +540,6 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
 
 To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
 
-### SimPO
-
-SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
-
-```yaml
-rl: simpo
-rl_beta: 0.1 # default in CPOTrainer
-cpo_alpha: 1.0 # default in CPOTrainer
-simpo_gamma: 0.5 # default in CPOTrainer
-```
-
-This method uses the same dataset format as [DPO](#dpo).
-
 ### Using local dataset files
 
 ```yaml
(file name not shown; example training config)

@@ -55,7 +55,7 @@ tf32: true
 
 gradient_checkpointing: true
 gradient_checkpointing_kwargs:
-  use_reentrant: false
+  use_reentrant: true
 early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
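For context on the flag flipped above: `gradient_checkpointing_kwargs` is forwarded (via the Hugging Face Trainer) to `torch.utils.checkpoint.checkpoint`, whose `use_reentrant` argument selects between the legacy reentrant implementation (`true`, as set on the head branch) and the newer non-reentrant one (`false`). A standalone sketch of the underlying call, assuming PyTorch is installed:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

# Non-reentrant checkpointing (use_reentrant=False): recomputes the forward
# pass during backward and does not require grad-enabled inputs.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()

# Legacy reentrant variant, matching `use_reentrant: true` in the config.
y2 = checkpoint(layer, x, use_reentrant=True)
y2.sum().backward()
```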
pyproject.toml

@@ -1,5 +1,5 @@
 [build-system]
-requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
+requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
 build-backend = "setuptools.build_meta"
 
 [project]
@@ -8,7 +8,6 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
 description = "LLM Trainer"
 readme = "README.md"
 requires-python = ">=3.10"
-# license = "Apache-2.0"
 
 [project.scripts]
 axolotl = "axolotl.cli.main:main"
requirements.txt

@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 
 # START section of dependencies that don't install on Darwin/MacOS
-bitsandbytes==0.45.3
+bitsandbytes==0.45.2
 triton>=3.0.0
 mamba-ssm==1.2.0.post1
 flash-attn==2.7.4.post1
@@ -12,12 +12,12 @@ liger-kernel==0.5.3
 
 packaging==23.2
 
-peft==0.15.0
+peft==0.14.0
 transformers==4.49.0
-tokenizers>=0.21.1
-accelerate==1.5.2
-datasets==3.4.1
-deepspeed==0.16.4
+tokenizers>=0.21.0
+accelerate==1.3.0
+datasets==3.2.0
+deepspeed==0.16.1
 trl==0.15.1
 
 optimum==1.16.2
scripts/chat_datasets.py

@@ -1,7 +1,6 @@
 """
 helper script to parse chat datasets into a usable yaml
 """
-
 import click
 import yaml
 from datasets import load_dataset
scripts/cutcrossentropy_install.py

@@ -1,5 +1,4 @@
 """Script to output the correct installation command for cut-cross-entropy."""
-
 import importlib.util
 import sys
 
setup.py (2 changes)

@@ -128,7 +128,7 @@ setup(
         "flash-attn==2.7.4.post1",
     ],
     "deepspeed": [
-        "deepspeed==0.16.4",
+        "deepspeed==0.16.1",
         "deepspeed-kernels",
     ],
     "mamba-ssm": [
(file name not shown)

@@ -1,7 +1,6 @@
 """
 launch axolotl in supported cloud platforms
 """
-
 from pathlib import Path
 from typing import Union
 
(file name not shown)

@@ -1,7 +1,6 @@
 """
 base class for cloud platforms from cli
 """
-
 from abc import ABC, abstractmethod
 
 
(file name not shown)

@@ -1,7 +1,6 @@
 """
 Modal Cloud support from CLI
 """
-
 import copy
 import json
 import os
(file name not shown)

@@ -1,5 +1,4 @@
 """Click CLI definitions for various axolotl commands."""
-
 # pylint: disable=redefined-outer-name
 
 import logging
(file name not shown)

@@ -5,6 +5,7 @@ import dataclasses
 import hashlib
 import json
 import logging
+import typing
 from functools import wraps
 from pathlib import Path
 from types import NoneType
@@ -23,7 +24,7 @@ configure_logging()
 LOG = logging.getLogger(__name__)
 
 
-def strip_optional_type(field_type: type | str | None):
+def strip_optional_type(field_type: type | typing._SpecialForm | None):
     """
     Extracts the non-`None` type from an `Optional` / `Union` type.
 
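For context on the widened annotation: an unsubscripted `Optional` or `Union` is an instance of `typing._SpecialForm`, not a `type`, which is presumably what the head branch's signature accounts for. A quick self-contained check:

```python
import typing

# Unsubscripted special forms are _SpecialForm instances, not types.
print(isinstance(typing.Optional, typing._SpecialForm))  # True
print(isinstance(typing.Union, typing._SpecialForm))     # True

# A subscripted Optional[int] has NoneType among its args, which is what a
# strip-optional helper would filter out to recover `int`.
print(typing.get_args(typing.Optional[int]))  # (<class 'int'>, <class 'NoneType'>)
```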
(file name not shown)

@@ -1,5 +1,6 @@
 """Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
+
 
 import json
 import sys
 
(file name not shown)

@@ -1,7 +1,6 @@
 """
 ChatML transformation functions for MessageContents
 """
-
 from typing import Optional
 
 from ..messages import MessageContents, Messages
(file name not shown)

@@ -1,7 +1,6 @@
 """
 Llama 3.x chat formatting functions for MessageContents
 """
-
 from typing import Optional
 
 from ..messages import MessageContents, Messages
(file name not shown)

@@ -1,7 +1,6 @@
 """
 shared functions for format transforms
 """
-
 from axolotl.core.chat.messages import MessageContents, Messages
 
 
(file name not shown)

@@ -1,7 +1,6 @@
 """
 internal message representations of chat messages
 """
-
 import json
 from enum import Enum
 from typing import Any, Callable, List, Optional, Union
(file name not shown)

@@ -1,7 +1,6 @@
 """
 chat dataset module
 """
-
 import os
 from typing import Callable, Optional, Union
 
(file name not shown)

@@ -1,7 +1,6 @@
 """
 This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
 """
-
 from typing import Any, Mapping, Union
 
 
src/axolotl/core/trainer_builder.py

@@ -40,7 +40,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
 
 from axolotl.core.trainers.base import (
     AxolotlCPOTrainer,
-    AxolotlKTOTrainer,
     AxolotlMambaTrainer,
     AxolotlORPOTrainer,
     AxolotlPRMTrainer,
@@ -51,6 +50,7 @@ from axolotl.core.trainers.base import (
 from axolotl.core.trainers.dpo import DPOStrategy
 from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
 from axolotl.core.trainers.grpo import GRPOStrategy
+from axolotl.core.trainers.kto import AxolotlKTOTrainer
 from axolotl.core.training_args import (
     AxolotlCPOConfig,
     AxolotlKTOConfig,
@@ -332,9 +332,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs = {}
 
         if self.cfg.include_tokens_per_second is not None:
-            training_arguments_kwargs["include_tokens_per_second"] = (
-                self.cfg.include_tokens_per_second
-            )
+            training_arguments_kwargs[
+                "include_tokens_per_second"
+            ] = self.cfg.include_tokens_per_second
 
         if self.cfg.bf16 == "full":
             training_arguments_kwargs["bf16_full_eval"] = True
@@ -351,13 +351,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["seed"] = self.cfg.seed
 
         if self.cfg.gradient_checkpointing:
-            training_arguments_kwargs["gradient_checkpointing"] = (
-                self.cfg.gradient_checkpointing
-            )
+            training_arguments_kwargs[
+                "gradient_checkpointing"
+            ] = self.cfg.gradient_checkpointing
             if self.cfg.gradient_checkpointing_kwargs is not None:
-                training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
-                    self.cfg.gradient_checkpointing_kwargs
-                )
+                training_arguments_kwargs[
+                    "gradient_checkpointing_kwargs"
+                ] = self.cfg.gradient_checkpointing_kwargs
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:
@@ -373,9 +373,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
 
         if self.cfg.lr_quadratic_warmup is not None:
-            training_arguments_kwargs["lr_quadratic_warmup"] = (
-                self.cfg.lr_quadratic_warmup
-            )
+            training_arguments_kwargs[
+                "lr_quadratic_warmup"
+            ] = self.cfg.lr_quadratic_warmup
 
         if self.cfg.adam_beta1:
             training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
@@ -399,28 +399,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
 
         if self.cfg.dataloader_pin_memory is not None:
-            training_arguments_kwargs["dataloader_pin_memory"] = (
-                self.cfg.dataloader_pin_memory
-            )
+            training_arguments_kwargs[
+                "dataloader_pin_memory"
+            ] = self.cfg.dataloader_pin_memory
         if self.cfg.dataloader_num_workers is not None:
-            training_arguments_kwargs["dataloader_num_workers"] = (
-                self.cfg.dataloader_num_workers
-            )
+            training_arguments_kwargs[
+                "dataloader_num_workers"
+            ] = self.cfg.dataloader_num_workers
         if self.cfg.dataloader_prefetch_factor is not None:
-            training_arguments_kwargs["dataloader_prefetch_factor"] = (
-                self.cfg.dataloader_prefetch_factor
-            )
+            training_arguments_kwargs[
+                "dataloader_prefetch_factor"
+            ] = self.cfg.dataloader_prefetch_factor
         if self.cfg.dataloader_drop_last is not None:
-            training_arguments_kwargs["dataloader_drop_last"] = (
-                self.cfg.dataloader_drop_last
-            )
+            training_arguments_kwargs[
+                "dataloader_drop_last"
+            ] = self.cfg.dataloader_drop_last
         elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
             training_arguments_kwargs["dataloader_drop_last"] = True
 
         if self.cfg.remove_unused_columns is not None:
-            training_arguments_kwargs["remove_unused_columns"] = (
-                self.cfg.remove_unused_columns
-            )
+            training_arguments_kwargs[
+                "remove_unused_columns"
+            ] = self.cfg.remove_unused_columns
 
         if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
             # no eval set, so don't eval
@@ -452,9 +452,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.do_causal_lm_eval:
             training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
         if self.cfg.metric_for_best_model:
-            training_arguments_kwargs["metric_for_best_model"] = (
-                self.cfg.metric_for_best_model
-            )
+            training_arguments_kwargs[
+                "metric_for_best_model"
+            ] = self.cfg.metric_for_best_model
         if self.cfg.greater_is_better:
             training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
 
@@ -467,13 +467,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             )
             training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
             if self.cfg.torch_compile_backend:
-                training_arguments_kwargs["torch_compile_backend"] = (
-                    self.cfg.torch_compile_backend
-                )
+                training_arguments_kwargs[
+                    "torch_compile_backend"
+                ] = self.cfg.torch_compile_backend
             if self.cfg.torch_compile_mode:
-                training_arguments_kwargs["torch_compile_mode"] = (
-                    self.cfg.torch_compile_mode
-                )
+                training_arguments_kwargs[
+                    "torch_compile_mode"
+                ] = self.cfg.torch_compile_mode
 
         # DDP Config
         if self.cfg.ddp_timeout:
@@ -482,32 +482,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.ddp_bucket_cap_mb:
             training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
         if self.cfg.ddp_broadcast_buffers is not None:
-            training_arguments_kwargs["ddp_broadcast_buffers"] = (
-                self.cfg.ddp_broadcast_buffers
-            )
+            training_arguments_kwargs[
+                "ddp_broadcast_buffers"
+            ] = self.cfg.ddp_broadcast_buffers
 
         # these are all the "standard" kwargs that are def used
         training_arguments_kwargs["max_steps"] = (
             total_num_steps if self.cfg.max_steps else -1
         )
         training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
-        training_arguments_kwargs["per_device_train_batch_size"] = (
-            self.cfg.micro_batch_size
-        )
+        training_arguments_kwargs[
+            "per_device_train_batch_size"
+        ] = self.cfg.micro_batch_size
         if self.cfg.eval_batch_size:
-            training_arguments_kwargs["per_device_eval_batch_size"] = (
-                self.cfg.eval_batch_size
-            )
+            training_arguments_kwargs[
+                "per_device_eval_batch_size"
+            ] = self.cfg.eval_batch_size
         if self.cfg.auto_find_batch_size is not None:
-            training_arguments_kwargs["auto_find_batch_size"] = (
-                self.cfg.auto_find_batch_size
-            )
+            training_arguments_kwargs[
+                "auto_find_batch_size"
+            ] = self.cfg.auto_find_batch_size
-        training_arguments_kwargs["gradient_accumulation_steps"] = (
-            self.cfg.gradient_accumulation_steps
-        )
-        training_arguments_kwargs["eval_accumulation_steps"] = (
-            self.cfg.gradient_accumulation_steps
-        )
+        training_arguments_kwargs[
+            "gradient_accumulation_steps"
+        ] = self.cfg.gradient_accumulation_steps
+        training_arguments_kwargs[
+            "eval_accumulation_steps"
+        ] = self.cfg.gradient_accumulation_steps
         training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
         training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
         training_arguments_kwargs["output_dir"] = self.cfg.output_dir
@@ -554,9 +554,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"
-            training_arguments_kwargs["alternate_lr_scheduler_type"] = (
-                self.cfg.lr_scheduler
-            )
+            training_arguments_kwargs[
+                "alternate_lr_scheduler_type"
+            ] = self.cfg.lr_scheduler
         else:
             training_arguments_kwargs["lr_scheduler_type"] = (
                 self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
@@ -565,9 +565,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
         training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
-        training_arguments_kwargs["cosine_constant_lr_ratio"] = (
-            self.cfg.cosine_constant_lr_ratio
-        )
+        training_arguments_kwargs[
+            "cosine_constant_lr_ratio"
+        ] = self.cfg.cosine_constant_lr_ratio
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
@@ -580,40 +580,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.eval_sample_packing
         )
         if self.cfg.sample_packing_bin_size is not None:
-            training_arguments_kwargs["sample_packing_bin_size"] = (
-                self.cfg.sample_packing_bin_size
-            )
+            training_arguments_kwargs[
+                "sample_packing_bin_size"
+            ] = self.cfg.sample_packing_bin_size
         if self.cfg.sample_packing_group_size is not None:
-            training_arguments_kwargs["sample_packing_group_size"] = (
-                self.cfg.sample_packing_group_size
-            )
+            training_arguments_kwargs[
+                "sample_packing_group_size"
+            ] = self.cfg.sample_packing_group_size
         if self.cfg.sample_packing_eff_est:
-            training_arguments_kwargs["sample_packing_efficiency"] = (
-                self.cfg.sample_packing_eff_est
-            )
+            training_arguments_kwargs[
+                "sample_packing_efficiency"
+            ] = self.cfg.sample_packing_eff_est
 
         if self.cfg.relora_steps:
             training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
-            training_arguments_kwargs["relora_warmup_steps"] = (
-                self.cfg.relora_warmup_steps
-            )
+            training_arguments_kwargs[
+                "relora_warmup_steps"
+            ] = self.cfg.relora_warmup_steps
             if self.cfg.relora_anneal_steps:
-                training_arguments_kwargs["relora_anneal_steps"] = (
-                    self.cfg.relora_anneal_steps
-                )
+                training_arguments_kwargs[
+                    "relora_anneal_steps"
+                ] = self.cfg.relora_anneal_steps
             if self.cfg.relora_prune_ratio:
-                training_arguments_kwargs["relora_prune_ratio"] = (
-                    self.cfg.relora_prune_ratio
-                )
+                training_arguments_kwargs[
+                    "relora_prune_ratio"
+                ] = self.cfg.relora_prune_ratio
 
         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
-            training_arguments_kwargs["lisa_step_interval"] = (
-                self.cfg.lisa_step_interval
-            )
+            training_arguments_kwargs[
+                "lisa_step_interval"
+            ] = self.cfg.lisa_step_interval
-            training_arguments_kwargs["lisa_layers_attribute"] = (
-                self.cfg.lisa_layers_attribute
-            )
+            training_arguments_kwargs[
+                "lisa_layers_attribute"
+            ] = self.cfg.lisa_layers_attribute
 
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
@@ -627,9 +627,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )
 
         if self.cfg.neftune_noise_alpha is not None:
-            training_arguments_kwargs["neftune_noise_alpha"] = (
-                self.cfg.neftune_noise_alpha
-            )
+            training_arguments_kwargs[
+                "neftune_noise_alpha"
+            ] = self.cfg.neftune_noise_alpha
 
         trainer_kwargs = {}
 
@@ -731,23 +731,23 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             importlib.import_module("torchdistx")
 
         if self.cfg.optim_target_modules:
-            training_arguments_kwargs["optim_target_modules"] = (
-                self.cfg.optim_target_modules
-            )
+            training_arguments_kwargs[
+                "optim_target_modules"
+            ] = self.cfg.optim_target_modules
 
         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
 
         training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
-        training_arguments_kwargs["loraplus_lr_embedding"] = (
-            self.cfg.loraplus_lr_embedding
-        )
+        training_arguments_kwargs[
+            "loraplus_lr_embedding"
+        ] = self.cfg.loraplus_lr_embedding
         training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
 
         if self.cfg.accelerator_config:
-            training_arguments_kwargs["accelerator_config"] = (
-                self.cfg.accelerator_config
-            )
+            training_arguments_kwargs[
+                "accelerator_config"
+            ] = self.cfg.accelerator_config
 
         if self.cfg.kd_ce_alpha is not None:
             training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
@@ -756,13 +756,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.kd_temperature is not None:
             training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
         if self.cfg.kd_zscore_base_temp is not None:
-            training_arguments_kwargs["kd_zscore_base_temp"] = (
-                self.cfg.kd_zscore_base_temp
-            )
+            training_arguments_kwargs[
+                "kd_zscore_base_temp"
+            ] = self.cfg.kd_zscore_base_temp
         if self.cfg.kd_top_k_before_softmax is not None:
-            training_arguments_kwargs["kd_top_k_before_softmax"] = (
-                self.cfg.kd_top_k_before_softmax
-            )
+            training_arguments_kwargs[
+                "kd_top_k_before_softmax"
+            ] = self.cfg.kd_top_k_before_softmax
 
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
@@ -972,32 +972,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
         if self.cfg.remove_unused_columns is not None:
|
||||||
training_args_kwargs["remove_unused_columns"] = (
|
training_args_kwargs[
|
||||||
self.cfg.remove_unused_columns
|
"remove_unused_columns"
|
||||||
)
|
] = self.cfg.remove_unused_columns
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["remove_unused_columns"] = False
|
training_args_kwargs["remove_unused_columns"] = False
|
||||||
|
|
||||||
if self.cfg.dataloader_pin_memory is not None:
|
if self.cfg.dataloader_pin_memory is not None:
|
||||||
training_args_kwargs["dataloader_pin_memory"] = (
|
training_args_kwargs[
|
||||||
self.cfg.dataloader_pin_memory
|
"dataloader_pin_memory"
|
||||||
)
|
] = self.cfg.dataloader_pin_memory
|
||||||
if self.cfg.dataloader_num_workers is not None:
|
if self.cfg.dataloader_num_workers is not None:
|
||||||
training_args_kwargs["dataloader_num_workers"] = (
|
training_args_kwargs[
|
||||||
self.cfg.dataloader_num_workers
|
"dataloader_num_workers"
|
||||||
)
|
] = self.cfg.dataloader_num_workers
|
||||||
if self.cfg.dataloader_prefetch_factor is not None:
|
if self.cfg.dataloader_prefetch_factor is not None:
|
||||||
training_args_kwargs["dataloader_prefetch_factor"] = (
|
training_args_kwargs[
|
||||||
self.cfg.dataloader_prefetch_factor
|
"dataloader_prefetch_factor"
|
||||||
)
|
] = self.cfg.dataloader_prefetch_factor
|
||||||
if self.cfg.gradient_checkpointing:
|
if self.cfg.gradient_checkpointing:
|
||||||
training_args_kwargs["gradient_checkpointing"] = (
|
training_args_kwargs[
|
||||||
self.cfg.gradient_checkpointing
|
"gradient_checkpointing"
|
||||||
)
|
] = self.cfg.gradient_checkpointing
|
||||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||||
training_args_kwargs["gradient_checkpointing_kwargs"] = (
|
training_args_kwargs[
|
||||||
self.cfg.gradient_checkpointing_kwargs
|
"gradient_checkpointing_kwargs"
|
||||||
)
|
] = self.cfg.gradient_checkpointing_kwargs
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
||||||
"use_reentrant": False
|
"use_reentrant": False
|
||||||
@@ -1071,9 +1071,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
if self.cfg.dpo_use_logits_to_keep is not None:
|
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||||
training_args_kwargs["use_logits_to_keep"] = (
|
training_args_kwargs[
|
||||||
self.cfg.dpo_use_logits_to_keep
|
"use_logits_to_keep"
|
||||||
)
|
] = self.cfg.dpo_use_logits_to_keep
|
||||||
|
|
||||||
for blocklist_key in blocklist_args_kwargs:
|
for blocklist_key in blocklist_args_kwargs:
|
||||||
if blocklist_key in training_args_kwargs:
|
if blocklist_key in training_args_kwargs:
|
||||||
@@ -1108,9 +1108,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
dpo_trainer_kwargs["peft_config"] = self.peft_config
|
dpo_trainer_kwargs["peft_config"] = self.peft_config
|
||||||
if self.cfg.precompute_ref_log_probs is not None:
|
if self.cfg.precompute_ref_log_probs is not None:
|
||||||
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
|
dpo_trainer_kwargs[
|
||||||
self.cfg.precompute_ref_log_probs
|
"precompute_ref_log_probs"
|
||||||
)
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl == "grpo":
|
if self.cfg.rl == "grpo":
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class()
|
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
|
|||||||
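Note: most hunks in this file (and many below) are the same mechanical revert of an auto-formatter's line-wrapping: the parenthesized right-hand side is unwrapped and the subscript is split instead (black's newer "hug the value" output versus its older split-subscript output, by the look of it). Both forms are semantically identical; for example, using a key that appears in the diff above:

```python
# style on the left-hand (pre-commit) side: value hugged in parentheses
training_arguments_kwargs["relora_warmup_steps"] = (
    self.cfg.relora_warmup_steps
)

# style this branch switches to: the subscript is wrapped instead
training_arguments_kwargs[
    "relora_warmup_steps"
] = self.cfg.relora_warmup_steps
```

The other recurring change in the hunks below removes the blank line that newer formatter versions insert after a module docstring; neither change affects runtime behavior.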
@@ -20,9 +20,10 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sequential
 from transformers import Trainer
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
+from trl import CPOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
 from trl.trainer.utils import pad_to_length

+from axolotl.core.trainers.kto import AxolotlKTOTrainer
 from axolotl.integrations.base import BaseOptimizerFactory
 from axolotl.monkeypatch.relora import ReLoRAScheduler
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -462,9 +463,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             "pin_memory": self.args.dataloader_pin_memory,
         }
         if self.args.dataloader_prefetch_factor:
-            dataloader_params["prefetch_factor"] = (
-                self.args.dataloader_prefetch_factor
-            )
+            dataloader_params[
+                "prefetch_factor"
+            ] = self.args.dataloader_prefetch_factor

         sampler = self._get_train_sampler()
         if isinstance(sampler, BatchSampler):
@@ -509,9 +510,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             "pin_memory": self.args.dataloader_pin_memory,
         }
         if self.args.dataloader_prefetch_factor:
-            dataloader_params["prefetch_factor"] = (
-                self.args.dataloader_prefetch_factor
-            )
+            dataloader_params[
+                "prefetch_factor"
+            ] = self.args.dataloader_prefetch_factor

         if isinstance(eval_sampler, BatchSampler):
             dataloader_params["batch_sampler"] = eval_sampler
@@ -874,14 +875,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
     tag_names = ["axolotl", "orpo"]


-class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
-    """
-    Extend the base KTOTrainer for axolotl helpers
-    """
-
-    tag_names = ["axolotl", "kto"]
-
-
 class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
     """
     Extend the base CPOTrainer for axolotl helpers
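Net effect of the hunks above: `AxolotlKTOTrainer` is deleted from this module along with the `KTOTrainer` import from `trl`, and is instead pulled in from the new `axolotl.core.trainers.kto` package introduced later in this diff, so existing call sites can keep importing it from here:

```python
# the replacement import added in the first hunk of this file
from axolotl.core.trainers.kto import AxolotlKTOTrainer
```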
@@ -1,7 +1,6 @@
 """
 DPO Specific Strategy for training
 """
-
 from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer


@@ -1,7 +1,6 @@
 """
 Axolotl specific DPO args
 """
-
 from dataclasses import dataclass

 from trl import DPOConfig
@@ -1,7 +1,6 @@
 """
 DPO trainer for axolotl
 """
-
 import gc
 from functools import wraps
 from typing import Any, Dict, Union
@@ -45,9 +45,9 @@ class GRPOStrategy:
         )

         if trl.vllm_gpu_memory_utilization:
-            grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
-                trl.vllm_gpu_memory_utilization
-            )
+            grpo_args_kwargs[
+                "vllm_gpu_memory_utilization"
+            ] = trl.vllm_gpu_memory_utilization

         if trl.vllm_max_model_len:
             grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
@@ -86,9 +86,9 @@ class GRPOStrategy:
     def set_trainer_kwargs(cls, cfg):
         trainer_kwargs = {}
         if cfg.trl and cfg.trl.reward_processing_classes:
-            trainer_kwargs["reward_processing_classes"] = (
-                cfg.trl.reward_processing_classes
-            )
+            trainer_kwargs[
+                "reward_processing_classes"
+            ] = cfg.trl.reward_processing_classes
         return trainer_kwargs

     @classmethod
@@ -1,7 +1,6 @@
 """
 Axolotl Specific Training Args
 """
-
 from dataclasses import dataclass

 from trl import GRPOConfig
@@ -1,7 +1,6 @@
 """
 Axolotl GRPO trainer
 """
-
 from accelerate.utils import is_peft_model
 from accelerate.utils.other import is_compiled_module
 from transformers import PreTrainedModel
7 src/axolotl/core/trainers/kto/__init__.py Normal file
@@ -0,0 +1,7 @@
+"""
+KTO package initialization.
+"""
+
+from axolotl.core.trainers.kto.trainer import AxolotlKTOTrainer
+
+__all__ = ["AxolotlKTOTrainer"]
512 src/axolotl/core/trainers/kto/trainer.py Normal file
@@ -0,0 +1,512 @@
+"""
+KTO trainer implementation for Axolotl.
+"""
+
+from __future__ import annotations
+
+import inspect
+import logging
+import warnings
+from contextlib import nullcontext
+from typing import Any, Callable, Literal, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from datasets import Dataset
+from torch.utils.data import DataLoader, SequentialSampler
+from transformers import (
+    AutoModelForCausalLM,
+    BaseImageProcessor,
+    DataCollator,
+    FeatureExtractionMixin,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+    ProcessorMixin,
+    Trainer,
+    TrainerCallback,
+    TrainingArguments,
+)
+from transformers.trainer_utils import EvalLoopOutput
+from trl import KTOTrainer
+from trl.trainer.kto_config import KTOConfig
+from trl.trainer.utils import KTODataCollatorWithPadding, pad_to_length
+
+from axolotl.core.trainers.base import SchedulerMixin
+
+# Check if PEFT is available
+try:
+    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training, peft_module_casting_to_bf16
+    is_peft_available = True
+except ImportError:
+    is_peft_available = False
+
+LOG = logging.getLogger("axolotl.core.trainers.kto")
+
+
+class AxolotlKTOTrainer(SchedulerMixin, Trainer):
+    """
+    Extend the base KTOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "kto"]
+
+    def __init__(
+        self,
+        model: Union[PreTrainedModel, nn.Module, str] = None,
+        args: KTOConfig = None,
+        train_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+        processing_class: Optional[
+            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+        ] = None,
+        data_collator: Optional[DataCollator] = None,
+        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        callbacks: Optional[list[TrainerCallback]] = None,
+        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+        peft_config: Optional[dict] = None,
+        compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
+        dataset_tags=None,
+        model_adapter_name: Optional[str] = None,
+        ref_adapter_name: Optional[str] = None,
+    ):
+        self.dataset_tags = dataset_tags
+        self._tag_names = ["trl", "kto"]
+        if hasattr(self, "tag_names"):
+            self._tag_names.extend(self.tag_names)
+
+        if type(args) is TrainingArguments:
+            raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
+
+        if args.model_init_kwargs is None:
+            model_init_kwargs = {}
+        elif not isinstance(model, str):
+            raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
+        else:
+            model_init_kwargs = args.model_init_kwargs
+            torch_dtype = model_init_kwargs.get("torch_dtype")
+            if torch_dtype is not None:
+                # Convert to `torch.dtype` if an str is passed
+                if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                    torch_dtype = getattr(torch, torch_dtype)
+                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
+                    raise ValueError(
+                        f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
+                    )
+                model_init_kwargs["torch_dtype"] = torch_dtype
+
+        if args.ref_model_init_kwargs is None:
+            ref_model_init_kwargs = {}
+        elif not isinstance(ref_model, str):
+            raise ValueError(
+                "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
+            )
+        else:
+            ref_model_init_kwargs = args.ref_model_init_kwargs
+            torch_dtype = ref_model_init_kwargs.get("torch_dtype")
+            if torch_dtype is not None:
+                # Convert to `torch.dtype` if an str is passed
+                if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                    torch_dtype = getattr(torch, torch_dtype)
+                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
+                    raise ValueError(
+                        f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
+                    )
+                ref_model_init_kwargs["torch_dtype"] = torch_dtype
+
+        if isinstance(model, str):
+            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+
+        if isinstance(ref_model, str):
+            ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
+
+        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
+        # has been called in order to properly call autocast if needed.
+        self._peft_has_been_casted_to_bf16 = False
+
+        if not is_peft_available() and peft_config is not None:
+            raise ValueError(
+                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
+            )
+        elif is_peft_available() and peft_config is not None:
+            # if model is a peft model and we have a peft_config, we merge and unload it first
+            if isinstance(model, PeftModel):
+                model = model.merge_and_unload()
+
+            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+                _support_gc_kwargs = hasattr(
+                    args, "gradient_checkpointing_kwargs"
+                ) and "gradient_checkpointing_kwargs" in list(
+                    inspect.signature(prepare_model_for_kbit_training).parameters
+                )
+
+                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+                if _support_gc_kwargs:
+                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
+            elif getattr(args, "gradient_checkpointing", False):
+                # For backward compatibility with older versions of transformers
+                if hasattr(model, "enable_input_require_grads"):
+                    model.enable_input_require_grads()
+                else:
+
+                    def make_inputs_require_grad(module, input, output):
+                        output.requires_grad_(True)
+
+                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+            # get peft model with the given config
+            model = get_peft_model(model, peft_config)
+            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+                peft_module_casting_to_bf16(model)
+                # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
+                self._peft_has_been_casted_to_bf16 = True
+
+        # For models that use gradient_checkpointing, we need to attach a hook that enables input
+        # to explicitly have `requires_grad=True`, otherwise training will either silently
+        # fail or completely fail.
+        elif getattr(args, "gradient_checkpointing", False):
+            # For backward compatibility with older versions of transformers
+            if hasattr(model, "enable_input_require_grads"):
+                model.enable_input_require_grads()
+            else:
+
+                def make_inputs_require_grad(module, input, output):
+                    output.requires_grad_(True)
+
+                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
+            raise ValueError(
+                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+                " Please install `wandb` or `comet-ml` to resolve."
+            )
+
+        if model is not None:
+            self.is_encoder_decoder = model.config.is_encoder_decoder
+        elif args.is_encoder_decoder is None:
+            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
+        else:
+            self.is_encoder_decoder = args.is_encoder_decoder
+
+        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
+        self.model_adapter_name = model_adapter_name
+        self.ref_adapter_name = ref_adapter_name
+
+        if ref_model:
+            self.ref_model = ref_model
+        elif self.is_peft_model or args.precompute_ref_log_probs:
+            # The `model` with adapters turned off will be used as the reference model
+            self.ref_model = None
+        else:
+            self.ref_model = create_reference_model(model)
+
+        if processing_class is None:
+            raise ValueError(
+                "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
+            )
+        if args.max_length is None:
+            warnings.warn(
+                "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
+                " it will be set to `512` by default, but you should do it yourself in the future.",
+                UserWarning,
+            )
+            max_length = 512
+        if args.max_length is not None:
+            max_length = args.max_length
+
+        if args.max_prompt_length is None:
+            warnings.warn(
+                "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
+                " it will be set to `128` by default, but you should do it yourself in the future.",
+                UserWarning,
+            )
+            max_prompt_length = 128
+        if args.max_prompt_length is not None:
+            max_prompt_length = args.max_prompt_length
+
+        max_completion_length = None
+        if args.max_completion_length is None and self.is_encoder_decoder:
+            warnings.warn(
+                "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
+                " it will be set to `128` by default, but you should do it yourself in the future.",
+                UserWarning,
+            )
+            max_completion_length = 128
+        if args.max_completion_length is not None and self.is_encoder_decoder:
+            max_completion_length = args.max_completion_length
+
+        if data_collator is None:
+            data_collator = DPODataCollatorWithPadding(
+                pad_token_id=processing_class.pad_token_id,
+                label_pad_token_id=args.label_pad_token_id,
+                is_encoder_decoder=self.is_encoder_decoder,
+            )
+
+            if args.remove_unused_columns:
+                args.remove_unused_columns = False
+                # warn users
+                warnings.warn(
+                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
+                    " we have set it for you, but you should do it yourself in the future.",
+                    UserWarning,
+                )
+
+            self.use_dpo_data_collator = True
+        else:
+            self.use_dpo_data_collator = False
+
+        # Disable dropout in the model and reference model
+        if args.disable_dropout:
+            disable_dropout_in_model(model)
+            if self.ref_model is not None:
+                disable_dropout_in_model(self.ref_model)
+
+        self.loss_type = args.loss_type
+        self.max_length = max_length
+        self.generate_during_eval = args.generate_during_eval
+        self.label_pad_token_id = args.label_pad_token_id
+        self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
+        self.max_prompt_length = max_prompt_length
+        self.truncation_mode = args.truncation_mode
+        self.max_completion_length = max_completion_length
+        self.processing_class = processing_class
+        self.precompute_ref_log_probs = args.precompute_ref_log_probs
+
+        # Not all losses require a KL calculation
+        self.calculate_KL = True
+        if self.loss_type in ["apo_zero_unpaired"]:
+            self.calculate_KL = False
+
+        # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
+        # keep track of first called to avoid computation of future calls
+        self._precomputed_train_ref_log_probs = False
+        self._precomputed_eval_ref_log_probs = False
+
+        # metric
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+        # KTO parameter
+        self.beta = args.beta
+        self.desirable_weight = args.desirable_weight
+        self.undesirable_weight = args.undesirable_weight
+        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
+        if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
+            warnings.warn(
+                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
+                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
+                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
+                "loss.",
+                UserWarning,
+            )
+
+        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+        # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
+        # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
+        # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
+        # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
+        # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
+        # issued.
+        model.warnings_issued["estimate_tokens"] = True
+
+        # Compute that only on the main process for faster data processing.
+        # see: https://github.com/huggingface/trl/pull/1255
+        with PartialState().local_main_process_first():
+            # Extract the prompt if needed
+            train_dataset = train_dataset.map(
+                maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
+            )
+            # Unpair the dataset if needed
+            train_dataset = maybe_unpair_preference_dataset(
+                train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
+            )
+            # Apply the chat template if needed
+            train_dataset = train_dataset.map(
+                maybe_apply_chat_template,
+                fn_kwargs={"tokenizer": processing_class},
+                num_proc=args.dataset_num_proc,
+                desc="Applying chat template to train dataset",
+            )
+            if eval_dataset is not None:
+                eval_dataset = eval_dataset.map(
+                    maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
+                )
+                eval_dataset = maybe_unpair_preference_dataset(
+                    eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
+                )
+                eval_dataset = eval_dataset.map(
+                    maybe_apply_chat_template,
+                    fn_kwargs={"tokenizer": processing_class},
+                    num_proc=args.dataset_num_proc,
+                    desc="Applying chat template to eval dataset",
+                )
+
+            # Tokenize and prepare the training datasets
+            train_dataset = train_dataset.map(
+                _tokenize,
+                batched=True,
+                fn_kwargs={"tokenizer": self.processing_class},
+                num_proc=args.dataset_num_proc,
+                desc="Tokenizing train dataset",
+            )
+
+            fn_kwargs = {
+                "prefix": "",
+                "is_encoder_decoder": self.is_encoder_decoder,
+                "tokenizer": self.processing_class,
+                "max_length": self.max_length,
+                "truncation_mode": self.truncation_mode,
+                "label_pad_token_id": self.label_pad_token_id,
+                "max_prompt_length": self.max_prompt_length,
+                "max_completion_length": self.max_completion_length,
+            }
+
+            train_dataset = train_dataset.map(
+                _process_tokens,
+                fn_kwargs=fn_kwargs,
+                num_proc=args.dataset_num_proc,
+                desc="Processing tokenized train dataset",
+            )
+
+            # Tokenize and prepare the eval datasets
+            if eval_dataset is not None:
+                eval_dataset = eval_dataset.map(
+                    _tokenize,
+                    fn_kwargs={"tokenizer": self.processing_class},
+                    batched=True,
+                    num_proc=args.dataset_num_proc,
+                    desc="Tokenizing eval dataset",
+                )
+
+                eval_dataset = eval_dataset.map(
+                    _process_tokens,
+                    fn_kwargs=fn_kwargs,
+                    num_proc=args.dataset_num_proc,
+                    desc="Processing tokenized eval dataset",
+                )
+
+            # Get KL datasets if needed
+            if self.calculate_KL:
+                if args.per_device_train_batch_size <= 1:
+                    raise ValueError(
+                        "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
+                    )
+
+                # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
+                # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
+                train_kl_dataset = train_dataset.map(
+                    _get_kl_dataset,
+                    batched=True,
+                    batch_size=args.per_device_train_batch_size,
+                    num_proc=args.dataset_num_proc,
+                    desc="Extracting KL train dataset",
+                )
+
+                fn_kwargs["prefix"] = "KL_"
+                train_kl_dataset = train_kl_dataset.map(
+                    _process_tokens,
+                    fn_kwargs=fn_kwargs,
+                    num_proc=args.dataset_num_proc,
+                    remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
+                    desc="Processing tokenized train KL dataset",
+                )
+
+                # merge the datasets
+                train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
+
+                if eval_dataset is not None:
+                    # Get KL dataset
+                    eval_kl_dataset = eval_dataset.map(
+                        _get_kl_dataset,
+                        batched=True,
+                        batch_size=args.per_device_train_batch_size,
+                        num_proc=args.dataset_num_proc,
+                        desc="Extracting eval KL dataset",
+                    )
+
+                    eval_kl_dataset = eval_kl_dataset.map(
+                        _process_tokens,
+                        fn_kwargs=fn_kwargs,
+                        num_proc=args.dataset_num_proc,
+                        remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
+                        desc="Processing tokenized eval KL dataset",
+                    )
+
+                    # merge the datasets
+                    eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
+
+            # calculate dataset desirability balance
+            num_desirable = max(sum(train_dataset["label"]), 1)
+            num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1)  # "label" is binary
+
+            if num_desirable != num_undesirable:
+                # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
+                des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
+                des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
+                und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
+                und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
+
+                des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
+                und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
+
+                if not (des_weight_in_range or und_weight_in_range):
+                    warnings.warn(
+                        "You have different amounts of desirable/positive and undesirable/negative examples but the "
+                        "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
+                        f"on your data, we recommend EITHER "
+                        f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
+                        f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
+                        "See the documentation on how to optimally set these weights.",
+                        UserWarning,
+                    )
+
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            model_init=model_init,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
+
+        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+        # self.model_accepts_loss_kwargs to False to enable scaling.
+        self.model_accepts_loss_kwargs = False
+
+        # Add tags for models that have been loaded with the correct transformers version
+        if hasattr(self.model, "add_model_tags"):
+            self.model.add_model_tags(self._tag_names)
+
+        if not hasattr(self, "accelerator"):
+            raise AttributeError(
+                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
+            )
+
+        # Deepspeed Zero-3 does not support precompute_ref_log_probs
+        if self.is_deepspeed_enabled:
+            if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
+                raise ValueError(
+                    "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
+                )
+
+        if self.ref_model is None:
+            if not (self.is_peft_model or self.precompute_ref_log_probs):
+                raise ValueError(
+                    "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
+                )
+        else:
+            if self.is_deepspeed_enabled:
+                self.ref_model = self._prepare_deepspeed(self.ref_model)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
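For orientation, a minimal sketch of driving this trainer directly; the model and data are hypothetical, and in practice Axolotl wires it up through `HFRLTrainerBuilder`. KTO expects unpaired preference data with `prompt`/`completion` columns and a binary `label`, and (per the check above) a per-device batch size greater than 1 whenever the KL term is computed. Note the file as committed references several TRL helpers (`create_reference_model`, `disable_dropout_in_model`, `maybe_extract_prompt`, `PartialState`, `defaultdict`, `_tokenize`, `_process_tokens`) that are not among its imports, so the sketch assumes those resolve as they do in upstream `trl.trainer.kto_trainer`:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import KTOConfig

from axolotl.core.trainers.kto import AxolotlKTOTrainer

# hypothetical model and data, for illustration only
tokenizer = AutoTokenizer.from_pretrained("my-org/my-base-model")
model = AutoModelForCausalLM.from_pretrained("my-org/my-base-model")

train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is 2+2?", "What is 3+3?"],
        "completion": ["4", "7"],
        "label": [True, False],  # desirable / undesirable
    }
)

args = KTOConfig(
    output_dir="./kto-out",
    per_device_train_batch_size=2,  # must be > 1 when calculate_KL is enabled
    max_length=512,
    max_prompt_length=128,
)

trainer = AxolotlKTOTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()
```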
@@ -1,7 +1,6 @@
 """
 module for TRL PPO training
 """
-
 import torch
 from tqdm import tqdm
 from trl import PPOTrainer
@@ -1,7 +1,6 @@
 """
 extra axolotl specific training args
 """
-
 from dataclasses import dataclass, field
 from typing import Optional

@@ -1,7 +1,6 @@
 """
 Grokfast plugin for Axolotl
 """
-
 import logging

 from transformers.trainer_callback import TrainerCallback
@@ -1,7 +1,6 @@
 """
 config args for grokfast plugin
 """
-
 from typing import Optional

 from pydantic import BaseModel
@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
     """

     kd_trainer: Optional[bool] = None  # whether to use KD trainer
-    kd_ce_alpha: Optional[float] = (
-        None  # loss coefficient for cross-entropy loss during KD
-    )
+    kd_ce_alpha: Optional[
+        float
+    ] = None  # loss coefficient for cross-entropy loss during KD
     kd_alpha: Optional[float] = None  # loss coefficient for KD loss
     kd_temperature: Optional[float] = None  # temperature for sampling during KD
     kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
-    kd_top_k_before_softmax: Optional[bool] = (
-        None  # whether to sample top k before softmax during KD
-    )
+    kd_top_k_before_softmax: Optional[
+        bool
+    ] = None  # whether to sample top k before softmax during KD
@@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin):
         if "cross_entropy" in liger_fn_sig.parameters:
             kwargs["cross_entropy"] = cfg.liger_cross_entropy
         if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
-            kwargs["fused_linear_cross_entropy"] = (
-                cfg.liger_fused_linear_cross_entropy
-            )
+            kwargs[
+                "fused_linear_cross_entropy"
+            ] = cfg.liger_fused_linear_cross_entropy
         if "rms_norm" in liger_fn_sig.parameters:
            kwargs["rms_norm"] = cfg.liger_rms_norm
         if "layer_norm" in liger_fn_sig.parameters:
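The hunk above only re-wraps one assignment, but the surrounding pattern is worth noting: each kwarg is gated on `liger_fn_sig.parameters`, i.e. the plugin feature-detects which arguments the installed liger-kernel patch function accepts. A generic sketch of that idiom (names are illustrative, not from the plugin):

```python
import inspect


def call_with_supported_kwargs(fn, **candidate_kwargs):
    """Forward only the kwargs that this version of `fn` actually accepts."""
    accepted = inspect.signature(fn).parameters
    supported = {k: v for k, v in candidate_kwargs.items() if k in accepted}
    return fn(**supported)
```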
@@ -1,7 +1,6 @@
 """
 DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
 """
-
 # pylint: disable=duplicate-code

 from typing import List, Optional, Tuple, Union
@@ -1,7 +1,6 @@
 """
 Jamba model with LigerFusedLinearCrossEntropyLoss
 """
-
 # pylint: disable=duplicate-code

 from typing import Optional, Tuple, Union
@@ -1,7 +1,6 @@
 """
 Module for the Plugin for LM Eval Harness
 """
-
 import subprocess  # nosec

 from axolotl.integrations.base import BasePlugin
@@ -1,7 +1,6 @@
 """
 Module for handling lm eval harness input arguments.
 """
-
 from typing import List, Optional

 from pydantic import BaseModel
@@ -1,7 +1,6 @@
 """
 axolotl CLI for running lm_eval tasks
 """
-
 import subprocess  # nosec
 from collections import defaultdict
 from datetime import datetime
@@ -5,7 +5,6 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).

 Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
 """
-
 # pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code

 import torch
@@ -6,7 +6,6 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"

 Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
 """
-
 # pylint: disable=invalid-name

 from typing import Callable
@@ -1,5 +1,4 @@
 """Dequantization utilities for `bitsandbytes` integration."""
-
 # pylint: disable=invalid-name,global-statement

 import ctypes
@@ -5,7 +5,6 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).

 Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
 """
-
 import torch
 import triton
 import triton.language as tl
@@ -1,7 +1,6 @@
 """
 HF Transformers MambaConfig
 """
-
 from transformers import PretrainedConfig


@@ -1,7 +1,6 @@
 """
 Monkeypatch for Vision Llama for FA2 support
 """
-
 # pylint: disable=duplicate-code

 from typing import Optional, Tuple
@@ -221,10 +220,10 @@ def patch_mllama():
         True
     )
     MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
-    MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = (
-        MllamaTextCrossFlashAttention2
-    )
+    MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
+        "flash_attention_2"
+    ] = MllamaTextCrossFlashAttention2
     # fallback to SDPA
-    MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = (
-        MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
-    )
+    MLLAMA_VISION_ATTENTION_CLASSES[
+        "flash_attention_2"
+    ] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
@@ -1,5 +1,4 @@
 """monkey patches for the dataset fetcher to handle batches of packed indexes"""
-
 # pylint: disable=protected-access

 import torch
@@ -12,9 +12,7 @@ import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-)
+from transformers.models.llama.modeling_llama import LlamaAttention
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OriginalLlamaDecoderLayer,
 )
@@ -492,11 +490,9 @@ def flashattn_forward(
         # We have disabled _prepare_decoder_attention_mask in LlamaModel
         # the attention_mask should be the same as the key_padding_mask
         key_padding_mask=attention_mask,
-        query_padding_mask=(
-            attention_mask[:, -query_states.size(1) :]
-            if attention_mask is not None
-            else None
-        ),
+        query_padding_mask=attention_mask[:, -query_states.size(1) :]
+        if attention_mask is not None
+        else None,
     )
     output_unpad = flash_attn_varlen_qkvpacked_func(
         qkv_unpad,
@@ -535,11 +531,9 @@ def flashattn_forward(
         value_states,
         kvpacked=True,
         key_padding_mask=attention_mask,
-        query_padding_mask=(
-            attention_mask[:, -query_states.size(1) :]
-            if attention_mask is not None
-            else None
-        ),
+        query_padding_mask=attention_mask[:, -query_states.size(1) :]
+        if attention_mask is not None
+        else None,
     )
     if q_unpad.dtype != kv_unpad.dtype:
         kv_unpad = kv_unpad.to(q_unpad.dtype)
@@ -1,7 +1,6 @@
 """
 expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
 """
-
 from typing import Optional

 import torch
@@ -1,5 +1,4 @@
 """Flash attention monkey patch for mistral model"""
-
 # pylint: disable=duplicate-code

 import logging
@@ -22,10 +21,7 @@ from transformers.models.mistral.modeling_mistral import (
 from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OriginalMistralDecoderLayer,
 )
-from transformers.models.mistral.modeling_mistral import (
-    apply_rotary_pos_emb,
-    repeat_kv,
-)
+from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv

 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

@@ -247,11 +243,9 @@ def flashattn_forward(
         # We have disabled _prepare_decoder_attention_mask in LlamaModel
         # the attention_mask should be the same as the key_padding_mask
         key_padding_mask=attention_mask,
-        query_padding_mask=(
-            attention_mask[:, -query_states.size(1) :]
-            if attention_mask is not None
-            else None
-        ),
+        query_padding_mask=attention_mask[:, -query_states.size(1) :]
+        if attention_mask is not None
+        else None,
     )
     output_unpad = flash_attn_varlen_qkvpacked_func(
         qkv_unpad,
@@ -292,11 +286,9 @@ def flashattn_forward(
         value_states,
         kvpacked=True,
         key_padding_mask=attention_mask,
-        query_padding_mask=(
-            attention_mask[:, -query_states.size(1) :]
-            if attention_mask is not None
-            else None
-        ),
+        query_padding_mask=attention_mask[:, -query_states.size(1) :]
+        if attention_mask is not None
+        else None,
     )
     if q_unpad.dtype != kv_unpad.dtype:
         kv_unpad = kv_unpad.to(q_unpad.dtype)
@@ -1,7 +1,6 @@
 """
 Patches to support multipack for mixtral
 """
-
 import torch


@@ -1,5 +1,4 @@
 """Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
-
 import glob
 import json
 import logging
@@ -412,10 +411,7 @@ def merge_and_save(
         if shard_path.endswith(".safetensors"):
             in_tensors = st.load_file(str(Path(model_src) / shard_path))
         else:
-            in_tensors = torch.load(
-                Path(model_src) / shard_path,
-                weights_only=True,  # to prevent arbitrary code execution
-            )
+            in_tensors = torch.load(Path(model_src) / shard_path)
         if "state_dict" in in_tensors:
             in_tensors = in_tensors["state_dict"]

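One behavioral note on the hunk above: the new side drops `weights_only=True` (available since PyTorch 1.13) from `torch.load`, which reinstates full unpickling, and with it the arbitrary-code-execution risk the deleted comment warned about when loading untrusted checkpoints. The guarded form, for comparison (hypothetical path):

```python
from pathlib import Path

import torch

shard_path = Path("model_src") / "pytorch_model.bin"  # hypothetical path
# weights_only=True restricts unpickling to tensors/containers and refuses
# arbitrary pickled objects from an untrusted checkpoint
in_tensors = torch.load(shard_path, weights_only=True)
```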
@@ -17,7 +17,7 @@
|
|||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
"""PyTorch StableLM Epoch model."""
|
""" PyTorch StableLM Epoch model. """
|
||||||
import importlib
|
import importlib
|
||||||
import math
|
import math
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
fix for FSDP optimizer save in trainer w 4.47.0
|
fix for FSDP optimizer save in trainer w 4.47.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Shared utils for the monkeypatches
|
Shared utils for the monkeypatches
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Fused MLP layer for incrementally improved training efficiency
|
Fused MLP layer for incrementally improved training efficiency
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers.models.llama.modeling_llama import LlamaMLP
|
from transformers.models.llama.modeling_llama import LlamaMLP
|
||||||
from xformers.ops import SwiGLU
|
from xformers.ops import SwiGLU
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Prompt strategies loader for alpaca instruction datasets with system prompts
|
Prompt strategies loader for alpaca instruction datasets with system prompts
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Generator, Tuple, Union
|
from typing import Generator, Tuple, Union
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Basic completion text
|
Basic completion text
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, Generator, Optional, Tuple
|
from typing import Any, Dict, Generator, Optional, Tuple
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
 """Module containing the classes for Context QA Prompt Tokenization Strategies"""
-
 from typing import Tuple
 
 from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
@@ -1,7 +1,6 @@
 """
 module for DPO style dataset transform strategies
 """
-
 from functools import partial
 
 from ..base import load as load_base
@@ -33,9 +33,9 @@ def default(
                 f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
         sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
         sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
         return sample
@@ -52,9 +52,9 @@ def argilla_chat(
     """
 
     def transform_fn(sample):
-        sample["prompt"] = (
-            f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
-        )
+        sample[
+            "prompt"
+        ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
         sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
         sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
         return sample
@@ -78,9 +78,9 @@ def icr(
                 f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
         sample["chosen"] = f"{sample['chosen']}<|im_end|>"
         sample["rejected"] = f"{sample['rejected']}<|im_end|>"
         return sample
@@ -100,9 +100,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
         sample["chosen"] = f"{sample['chosen']}<|im_end|>"
         sample["rejected"] = f"{sample['rejected']}<|im_end|>"
         return sample
@@ -120,9 +120,9 @@ def prompt_pairs(
                 f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
         sample["chosen"] = f"{sample['chosen']}<|im_end|>"
         sample["rejected"] = f"{sample['rejected']}<|im_end|>"
         return sample
@@ -142,9 +142,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
         sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
         sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
         return sample
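The six hunks above all touch the same ChatML DPO transform module: each loader builds a `transform_fn` that rewrites one dataset row into the `prompt`/`chosen`/`rejected` text fields a DPO trainer consumes, differing only in which source keys they read. A minimal sketch of that pattern (hypothetical helper name; the system-prompt branch is inferred from the context lines above):

```python
def chatml_dpo_transform(sample: dict) -> dict:
    """Rewrite a {system?, prompt, chosen, rejected} row into ChatML text."""
    if sample.get("system"):
        # prepend the optional system turn before the user turn
        sample["prompt"] = (
            f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
            f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
        )
    else:
        sample["prompt"] = (
            f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
        )
    # both candidate responses are closed with the end-of-turn token
    sample["chosen"] = f"{sample['chosen']}<|im_end|>"
    sample["rejected"] = f"{sample['rejected']}<|im_end|>"
    return sample


row = {"prompt": "What is 2 + 2?", "chosen": "4", "rejected": "5"}
print(chatml_dpo_transform(row)["prompt"])
```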
@@ -34,9 +34,9 @@ def default(
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
         sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
         return sample
@@ -53,9 +53,9 @@ def argilla_chat(
     """
 
    def transform_fn(sample):
-        sample["prompt"] = (
-            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-        )
+        sample[
+            "prompt"
+        ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
         sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
         return sample
@@ -79,9 +79,9 @@ def icr(
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
         sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
         return sample
@@ -101,9 +101,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
         sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
         return sample
@@ -121,9 +121,9 @@ def prompt_pairs(
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
         sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
         return sample
@@ -143,9 +143,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
         sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
         return sample
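The llama-3 hunks mirror the ChatML ones with only the chat markup swapped: header tokens replace `<|im_start|>`/`<|im_end|>`, and every turn closes with `<|eot_id|>`. A one-function sketch of the same pattern (hypothetical name, not the repository's exact code):

```python
def llama3_dpo_transform(sample: dict) -> dict:
    """Rewrite a {prompt, chosen, rejected} row into llama-3 chat markup."""
    sample["prompt"] = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{sample['prompt']}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
    sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
    return sample
```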
@@ -1,5 +1,4 @@
 """Module for plain input/output prompt pairs"""
-
 from typing import Generator, Tuple
 
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
@@ -1,5 +1,4 @@
 """Module for inspect jinja templates for the variables they use"""
-
 from typing import Dict, Optional, Set, TypedDict, Union
 
 from jinja2 import Environment, meta, nodes
@@ -1,7 +1,6 @@
 """
 KTO strategies for chatml
 """
-
 # pylint: disable=duplicate-code
 
 
@@ -16,9 +15,9 @@ def argilla(
                 f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
         sample["completion"] = f"{sample['completion']}<|im_end|>"
         return sample
 
@@ -34,9 +33,9 @@ def argilla_chat(
     """
 
    def transform_fn(sample):
-        sample["prompt"] = (
-            f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
-        )
+        sample[
+            "prompt"
+        ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
         sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
         return sample
 
@@ -56,9 +55,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
         sample["completion"] = f"{sample['completion']}<|im_end|>"
         return sample
 
@@ -75,9 +74,9 @@ def prompt_pairs(
                 f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
         sample["completion"] = f"{sample['completion']}<|im_end|>"
         return sample
 
@@ -97,9 +96,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
         sample["completion"] = f"{sample['completion']}<|im_end|>"
         return sample
 
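The KTO hunks follow the same shape as the DPO ones, but a KTO row carries a single `completion` rather than a `chosen`/`rejected` pair, so only two fields are rewritten. A minimal sketch of the ChatML variant (hypothetical helper name, not the repository's exact code):

```python
def chatml_kto_transform(sample: dict) -> dict:
    """Wrap a {prompt, completion} KTO row in ChatML chat markup."""
    sample["prompt"] = (
        f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
    )
    sample["completion"] = f"{sample['completion']}<|im_end|>"
    return sample


# KTO rows typically also carry a boolean desirability label, which the
# transform leaves untouched.
row = {"prompt": "Name a prime number.", "completion": "7", "label": True}
print(chatml_kto_transform(row)["completion"])
```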
@@ -1,7 +1,6 @@
 """
 KTO strategies for llama-3 chat template
 """
-
 # pylint: disable=duplicate-code
 
 
@@ -16,9 +15,9 @@ def argilla(
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         sample["completion"] = f"{sample['completion']}<|eot_id|>"
         return sample
 
@@ -34,9 +33,9 @@ def argilla_chat(
     """
 
    def transform_fn(sample):
-        sample["prompt"] = (
-            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-        )
+        sample[
+            "prompt"
+        ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
         return sample
 
@@ -56,9 +55,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         sample["completion"] = f"{sample['completion']}<|eot_id|>"
         return sample
 
@@ -75,9 +74,9 @@ def prompt_pairs(
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         sample["completion"] = f"{sample['completion']}<|eot_id|>"
         return sample
 
@@ -97,9 +96,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample["prompt"] = (
-                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-            )
+            sample[
+                "prompt"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         sample["completion"] = f"{sample['completion']}<|eot_id|>"
         return sample
 
@@ -1,7 +1,6 @@
 """
 User-defined KTO strategies
 """
-
 # pylint: disable=duplicate-code
 
 
@@ -1,7 +1,6 @@
 """
 Chat dataset wrapping strategy for new internal messages representations
 """
-
 from typing import Any, Callable, Dict, Optional
 
 from axolotl.core.datasets.chat import TokenizedChatDataset
@@ -9,7 +9,6 @@ this one specifies the system prompt with "### System:".
 
 Not suited/tested for multiple-turn conversations without further adjustments.
 """
-
 from typing import Generator, Union
 
 from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy
@@ -1,5 +1,4 @@
 """chatml prompt tokenization strategy for ORPO"""
-
 from typing import Any, Dict, Generator, List, Optional, Tuple
 
 from pydantic import BaseModel
@@ -1,5 +1,4 @@
 """pretraining prompt strategies"""
-
 from typing import Generator
 
 from transformers import BatchEncoding
@@ -406,7 +406,9 @@ def handle_untrained_tokens_fix(
     )
 
 
-def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
+def setup_model_and_trainer(
+    cfg: DictDefault, dataset_meta: TrainDatasetMeta
+) -> tuple[
     HFRLTrainerBuilder | HFCausalTrainerBuilder,
     PeftModel | PreTrainedModel,
     PreTrainedTokenizer,
@@ -40,6 +40,6 @@ def set_pytorch_cuda_alloc_conf():
     torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
     if torch_major == 2 and torch_minor >= 2:
         if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
-            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
-                "expandable_segments:True,roundup_power2_divisions:16"
-            )
+            os.environ[
+                "PYTORCH_CUDA_ALLOC_CONF"
+            ] = "expandable_segments:True,roundup_power2_divisions:16"
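A self-contained sketch of the guard this hunk reformats (same logic; the module context around it is assumed): the default is applied only on PyTorch >= 2.2, which supports `expandable_segments`, and only when the user has not already set the variable. It must also run before the first CUDA allocation to take effect.

```python
import os

import torch


def set_pytorch_cuda_alloc_conf() -> None:
    """Default PYTORCH_CUDA_ALLOC_CONF without clobbering a user-set value."""
    torch_version = torch.__version__.split(".")
    torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
    if torch_major == 2 and torch_minor >= 2:
        if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
            # expandable_segments reduces fragmentation from variable-length
            # batches; roundup_power2_divisions makes freed blocks easier to reuse
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
                "expandable_segments:True,roundup_power2_divisions:16"
            )
```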
@@ -1,5 +1,4 @@
 """Benchmarking and measurement utilities"""
-
 import functools
 
 import torch
@@ -343,9 +343,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
             bench_refs.extend(combined_bench_names[bench_name]["refs"])
             bench_preds.extend(combined_bench_names[bench_name]["preds"])
         if not pd.isna(bench_score):
-            results[f"{bench_split}_bench_accuracy_{bench_name}"] = (
-                bench_score
-            )
+            results[
+                f"{bench_split}_bench_accuracy_{bench_name}"
+            ] = bench_score
             bench_scores.append(bench_score)
         else:
             results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0
@@ -1,5 +1,4 @@
 """MLFlow module for trainer callbacks"""
-
 import logging
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
@@ -1,5 +1,4 @@
 """callback to calculate perplexity as an evaluation metric."""
-
 from typing import Dict, List, Optional
 
 import torch
@@ -1,7 +1,6 @@
 """
 HF Trainer callback for creating pytorch profiling snapshots
 """
-
 from pathlib import Path
 from pickle import dump  # nosec B403
 
@@ -2,7 +2,6 @@
 This module provides functionality for selecting chat templates based on user choices.
 These templates are used for formatting messages in a conversation.
 """
-
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
Some files were not shown because too many files have changed in this diff.