Compare commits

...

19 Commits

Author SHA1 Message Date
Dan Saunders
0bffef25d0 installing axolotl prior to quartodoc build 2025-03-21 16:51:02 +00:00
Dan Saunders
94c00c1d04 pre-commit 2025-03-21 11:23:39 -04:00
Dan Saunders
ddd84d7c65 update pylint 2025-03-21 11:18:59 -04:00
Dan Saunders
42bdf0bd74 update pre-commit version 2025-03-21 11:18:59 -04:00
Dan Saunders
b03d96a228 include quartodoc build step 2025-03-21 11:18:59 -04:00
Dan Saunders
2653f170fc fix accidental change 2025-03-21 11:18:59 -04:00
Dan Saunders
3bfcce9f0a shrinking header sizes 2025-03-21 11:18:59 -04:00
Dan Saunders
8feb746953 fix 2025-03-21 11:18:59 -04:00
Dan Saunders
a563815fe7 pydantic models refactor + add to autodoc + fixes 2025-03-21 11:18:58 -04:00
Dan Saunders
81f2203151 update to reflect recent changes 2025-03-21 11:12:09 -04:00
Dan Saunders
5b7e688fc5 fix broken link 2025-03-21 11:12:09 -04:00
Dan Saunders
5134aa66cd moving reference up near the top of the sidebar 2025-03-21 11:12:09 -04:00
Dan Saunders
ba9a867adb more autodoc progress 2025-03-21 11:12:09 -04:00
Dan Saunders
c618f42c39 Fix 2025-03-21 11:12:09 -04:00
Dan Saunders
fc1f985296 Update docs/.gitignore to exclude auto-generated API documentation files 2025-03-21 11:12:09 -04:00
Dan Saunders
a5e37f183c deletions 2025-03-21 11:12:09 -04:00
Dan Saunders
e6a7bbe9ff quartodoc progress 2025-03-21 11:12:09 -04:00
Dan Saunders
e4fd7aad0b quartodoc integration 2025-03-21 11:12:09 -04:00
Dan Saunders
c907ac173e adding pre-commit auto-update GH action and bumping plugin versions (#2428)
* adding pre-commit auto-update GH action and bumping plugin versions

* running updated pre-commit plugins

* sorry to revert, but pylint complained

* Update .pre-commit-config.yaml

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-03-21 11:02:43 -04:00
159 changed files with 1758 additions and 1238 deletions


@@ -20,9 +20,12 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: '3.11'
-      - name: install dependencies
+      - name: Install dependencies
         run: |
-          python3 -m pip install jupyter
+          python3 -m pip install jupyter quartodoc
+          python3 -m pip install -e .
+      - name: Build autodoc
+        run: quartodoc build
       - name: Publish to GitHub Pages (and render)
         uses: quarto-dev/quarto-actions/publish@v2
         with:


@@ -0,0 +1,49 @@
+name: Pre-commit auto-update
+
+on:
+  schedule:
+    - cron: '0 0 * * 0'  # Run weekly
+  workflow_dispatch:  # Manual kickoff
+
+jobs:
+  auto-update:
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+      pull-requests: write
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+      - name: Update pre-commit hooks
+        id: update
+        run: |
+          pip install pre-commit
+          pre-commit autoupdate
+          if [[ -n $(git status --porcelain) ]]; then
+            echo "changes=true" >> $GITHUB_OUTPUT
+            git diff .pre-commit-config.yaml > pre-commit-update.diff
+          fi
+      - name: Create Pull Request
+        if: steps.update.outputs.changes == 'true'
+        uses: peter-evans/create-pull-request@v6
+        with:
+          token: ${{ secrets.GITHUB_TOKEN }}
+          branch: update/pre-commit-hooks
+          delete-branch: true
+          title: "chore: update pre-commit hooks"
+          commit-message: "chore: update pre-commit hooks"
+          body: |
+            Automated PR to update pre-commit hooks to their latest versions.
+
+            <details>
+            <summary>Changes:</summary>
+
+            ```diff
+            ${{ steps.update.outputs.diff }}
+            ```
+
+            </details>
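The workflow's detection step keys entirely off `git status --porcelain`: non-empty output means `pre-commit autoupdate` touched the config. A minimal local rehearsal of that check, assuming `git` is on `PATH` and using a throwaway repository (file name is illustrative):

```python
# Hypothetical rehearsal of the workflow's change-detection step.
import subprocess
import tempfile
from pathlib import Path

repo = Path(tempfile.mkdtemp())
subprocess.run(["git", "init", "-q", str(repo)], check=True)
(repo / "updated-config.yaml").touch()  # simulate autoupdate touching a file

# --porcelain prints one line per changed/untracked path, so non-empty
# output means the working tree is dirty and a PR should be opened.
status = subprocess.run(
    ["git", "status", "--porcelain"], cwd=repo, capture_output=True, text=True
).stdout
print("changes=true" if status else "changes=false")
```

The same predicate drives the `if:` guard on the Create Pull Request step, so the PR is only opened when an update actually changed the file.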

.gitignore vendored

@@ -181,6 +181,10 @@ prepared-datasets/
 submit.sh
 *.out*
+
+# Quartodoc generated files
+objects.json
+site_libs/
 typings/
 out/


@@ -3,7 +3,7 @@ default_language_version:
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
@@ -11,23 +11,23 @@ repos:
       - id: no-commit-to-branch
         args: ['--branch', 'main']
   - repo: https://github.com/psf/black
-    rev: 23.3.0
+    rev: 25.1.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 6.0.1
     hooks:
       - id: isort
   - repo: https://github.com/PyCQA/flake8
-    rev: 6.1.0
+    rev: 7.1.2
     hooks:
       - id: flake8
   - repo: https://github.com/pylint-dev/pylint
-    rev: c8c96d20cde3552a79858c7456bb1483bf83d633
+    rev: v3.3.6
    hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.3.0
+    rev: v1.15.0
     hooks:
       - id: mypy
         additional_dependencies:
@@ -36,7 +36,7 @@ repos:
           'pydantic>=2.5.3',
         ]
   - repo: https://github.com/PyCQA/bandit
-    rev: 1.7.5
+    rev: 1.8.3
     hooks:
       - id: bandit
         args: [


@@ -97,6 +97,7 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
 - [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
 - [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
 - [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
+- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
 - [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions

 ## 🤝 Getting Help


@@ -1,6 +1,178 @@
 project:
   type: website

+quartodoc:
+  dir: docs/api
+  package: axolotl
+  title: API Reference
+  parser: google
+  sections:
+    - title: Core
+      desc: Core functionality for training
+      contents:
+        - train
+        - evaluate
+        - datasets
+        - convert
+        - prompt_tokenizers
+        - logging_config
+        - core.trainer_builder
+        - core.training_args
+        - core.chat.messages
+        - core.chat.format.chatml
+        - core.chat.format.llama3x
+        - core.chat.format.shared
+        - core.datasets.chat
+        - core.datasets.transforms.chat_builder
+    - title: CLI
+      desc: Command-line interface
+      contents:
+        - cli.main
+        - cli.train
+        - cli.evaluate
+        - cli.args
+        - cli.checks
+        - cli.config
+        - cli.inference
+        - cli.merge_lora
+        - cli.merge_sharded_fsdp_weights
+        - cli.preprocess
+        - cli.sweeps
+        - cli.utils
+        - cli.cloud.base
+        - cli.cloud.modal_
+    - title: Trainers
+      desc: Training implementations
+      contents:
+        - core.trainers.base
+        - core.trainers.trl
+        - core.trainers.dpo.trainer
+        - core.trainers.grpo.trainer
+    - title: Prompt Strategies
+      desc: Prompt formatting strategies
+      contents:
+        - prompt_strategies.base
+        - prompt_strategies.chat_template
+        - prompt_strategies.alpaca_chat
+        - prompt_strategies.alpaca_instruct
+        - prompt_strategies.alpaca_w_system
+        - prompt_strategies.user_defined
+        - prompt_strategies.llama2_chat
+        - prompt_strategies.completion
+        - prompt_strategies.input_output
+        - prompt_strategies.stepwise_supervised
+        - prompt_strategies.metharme
+        - prompt_strategies.orcamini
+        - prompt_strategies.pygmalion
+        - prompt_strategies.messages.chat
+        - prompt_strategies.dpo.chat_template
+        - prompt_strategies.dpo.llama3
+        - prompt_strategies.dpo.chatml
+        - prompt_strategies.dpo.zephyr
+        - prompt_strategies.dpo.user_defined
+        - prompt_strategies.dpo.passthrough
+        - prompt_strategies.kto.llama3
+        - prompt_strategies.kto.chatml
+        - prompt_strategies.kto.user_defined
+        - prompt_strategies.orpo.chat_template
+        - prompt_strategies.bradley_terry.llama3
+    - title: Kernels
+      desc: Low-level performance optimizations
+      contents:
+        - kernels.lora
+        - kernels.geglu
+        - kernels.swiglu
+        - kernels.quantize
+        - kernels.utils
+    - title: MonkeyPatches
+      desc: Runtime patches for model optimizations
+      contents:
+        - monkeypatch.llama_attn_hijack_flash
+        - monkeypatch.llama_attn_hijack_xformers
+        - monkeypatch.mistral_attn_hijack_flash
+        - monkeypatch.multipack
+        - monkeypatch.relora
+        - monkeypatch.llama_expand_mask
+        - monkeypatch.lora_kernels
+        - monkeypatch.utils
+        - monkeypatch.btlm_attn_hijack_flash
+        - monkeypatch.llama_patch_multipack
+        - monkeypatch.stablelm_attn_hijack_flash
+        - monkeypatch.trainer_fsdp_optim
+        - monkeypatch.transformers_fa_utils
+        - monkeypatch.unsloth_
+        - monkeypatch.attention.mllama
+        - monkeypatch.data.batch_dataset_fetcher
+        - monkeypatch.mixtral
+    - title: Utils
+      desc: Utility functions
+      contents:
+        - utils.models
+        - utils.tokenization
+        - utils.chat_templates
+        - utils.lora
+        - utils.lora_embeddings
+        - utils.model_shard_quant
+        - utils.bench
+        - utils.freeze
+        - utils.trainer
+        - utils.schedulers
+        - utils.distributed
+        - utils.dict
+        - utils.optimizers.adopt
+        - utils.data.pretraining
+        - utils.data.sft
+        - utils.gradient_checkpointing.unsloth
+    - title: Schemas
+      desc: Pydantic data models for Axolotl config
+      contents:
+        - utils.schemas.config
+        - utils.schemas.model
+        - utils.schemas.training
+        - utils.schemas.datasets
+        - utils.schemas.peft
+        - utils.schemas.trl
+        - utils.schemas.integrations
+        - utils.schemas.enums
+        - utils.schemas.utils
+    - title: Integrations
+      desc: Third-party integrations and extensions
+      contents:
+        - integrations.base
+        - integrations.cut_cross_entropy.args
+        - integrations.grokfast.optimizer
+        - integrations.kd.trainer
+        - integrations.liger.args
+        - integrations.lm_eval.args
+        - integrations.spectrum.args
+    - title: Common
+      desc: Common utilities and shared functionality
+      contents:
+        - common.architectures
+        - common.const
+        - common.datasets
+    - title: Models
+      desc: Custom model implementations
+      contents:
+        - models.mamba.modeling_mamba
+    - title: Data Processing
+      desc: Data processing utilities
+      contents:
+        - utils.collators.core
+        - utils.collators.batching
+        - utils.collators.mamba
+        - utils.collators.mm_chat
+        - utils.samplers.multipack
+    - title: Callbacks
+      desc: Training callbacks
+      contents:
+        - utils.callbacks.perplexity
+        - utils.callbacks.profiler
+        - utils.callbacks.lisa
+        - utils.callbacks.mlflow_
+        - utils.callbacks.comet_
+
 website:
   title: "Axolotl"
   description: "We make fine-tuning accessible, scalable, and fun"
@@ -35,6 +207,8 @@ website:
         - docs/inference.qmd
         - docs/cli.qmd
         - docs/config.qmd
+        - text: "API Reference"
+          href: docs/api
       - section: "Dataset Formats"
         contents: docs/dataset-formats/*

@@ -80,3 +254,22 @@ format:
     theme: darkly
     css: styles.css
     toc: true
+    # Enable better handling of line breaks in markdown
+    preserve-tabs: true
+    html-math-method: mathjax
+    # Improved markdown processing options
+    md-extensions:
+      - markdown_it
+      - def_list
+      - attr_list
+      - fenced_divs
+      - tables
+      - html_admonition
+      - lineblocks
+      - fancy_lists
+    # Control whitespace handling
+    whitespace: preserve
+    # Process newlines in paragraphs
+    wrap: preserve
+    # Better line break handling
+    preserve-linebreaks: true


@@ -1,6 +1,7 @@
 """
 modal application to run axolotl gpu tests in Modal
 """
+
 # pylint: disable=duplicate-code
 import os


@@ -1,4 +1,5 @@
 """Modal app to run axolotl GPU tests"""
+
 # pylint: disable=duplicate-code
 import os

docs/.gitignore vendored

@@ -1,2 +1,4 @@
 /.quarto/
 _site/
+/api/*.qmd
+/api/*.html


@@ -1,5 +1,5 @@
 ---
-title: "CLI Reference"
+title: "Command Line Interface (CLI)"
 format:
   html:
     toc: true


@@ -6,7 +6,7 @@ description: How datasets are processed
 ## Overview

 Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
-the [dataset format](docs/dataset-formats) and prompt strategies to:
+the [dataset format](dataset-formats) and prompt strategies to:

 - parse the dataset based on the *dataset format*
 - transform the dataset to how you would interact with the model based on the *prompt strategy*
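The two-stage flow the doc describes (parse per *dataset format*, then render per *prompt strategy*) can be sketched as follows; the function names and the alpaca-style row layout here are illustrative, not Axolotl's actual API:

```python
# Stage 1: parse a raw row according to its (hypothetical) dataset format.
def parse_alpaca(row: dict) -> dict:
    return {"instruction": row["instruction"], "output": row["output"]}

# Stage 2: a prompt strategy renders the parsed sample the way the model
# will actually see it during training.
def chat_strategy(sample: dict) -> str:
    return f"user: {sample['instruction']}\nassistant: {sample['output']}"

dataset = [{"instruction": "Say hi", "output": "hi"}]
processed = [chat_strategy(parse_alpaca(row)) for row in dataset]
print(processed[0])  # user: Say hi\nassistant: hi
```

Keeping the two stages separate is what lets one dataset format feed many prompt strategies, and vice versa.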


@@ -2,3 +2,5 @@ pre-commit
 black
 mypy
 types-requests
+quartodoc
+jupyter


@@ -1,6 +1,7 @@
 """
 helper script to parse chat datasets into a usable yaml
 """
+
 import click
 import yaml
 from datasets import load_dataset


@@ -1,4 +1,5 @@
 """Script to output the correct installation command for cut-cross-entropy."""
+
 import importlib.util
 import sys

@@ -17,12 +18,12 @@ if v < V("2.4.0"):
 cce_spec = importlib.util.find_spec("cut_cross_entropy")

-uninstall_prefix = ""
+UNINSTALL_PREFIX = ""
 if cce_spec:
     if not importlib.util.find_spec("cut_cross_entropy.transformers"):
-        uninstall_prefix = "pip uninstall -y cut-cross-entropy && "
+        UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "

 print(
-    UNINSTALL_PREFIX
+    UNINSTALL_PREFIX
     + 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
 )
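The script above gates its output on `importlib.util.find_spec`, which probes importability without importing. A minimal sketch of the same technique using only stdlib modules as stand-ins (`json` and `json.tool` play the roles of the package and its needed submodule; the package names in the strings are hypothetical):

```python
import importlib.util

# find_spec returns a ModuleSpec when the module is importable, else None.
UNINSTALL_PREFIX = ""
if importlib.util.find_spec("json"):
    # Package present: check whether the required submodule is also there.
    if not importlib.util.find_spec("json.tool"):
        UNINSTALL_PREFIX = "pip uninstall -y example-package && "

print(UNINSTALL_PREFIX + "pip install example-package")  # -> pip install example-package
```

The uninstall prefix only fires when the package exists but lacks the submodule, mirroring the "installed without the transformers extra" case the real script handles.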


@@ -1,6 +1,7 @@
 """
 launch axolotl in supported cloud platforms
 """
+
 from pathlib import Path
 from typing import Union


@@ -1,6 +1,7 @@
 """
 base class for cloud platforms from cli
 """
+
 from abc import ABC, abstractmethod


@@ -1,6 +1,7 @@
 """
 Modal Cloud support from CLI
 """
+
 import copy
 import json
 import os


@@ -1,4 +1,5 @@
 """Click CLI definitions for various axolotl commands."""
+
 # pylint: disable=redefined-outer-name
 import logging

@@ -24,7 +25,7 @@ from axolotl.cli.utils import (
 )
 from axolotl.integrations.lm_eval.cli import lm_eval
 from axolotl.utils import set_pytorch_cuda_alloc_conf
-from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
+from axolotl.utils.schemas.config import AxolotlInputConfig

 @click.group()


@@ -5,7 +5,6 @@ import dataclasses
 import hashlib
 import json
 import logging
-import typing
 from functools import wraps
 from pathlib import Path
 from types import NoneType

@@ -24,7 +23,7 @@ configure_logging()
 LOG = logging.getLogger(__name__)

-def strip_optional_type(field_type: type | typing._SpecialForm | None):
+def strip_optional_type(field_type: type | str | None):
     """
     Extracts the non-`None` type from an `Optional` / `Union` type.


@@ -1,6 +1,5 @@
 """Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
-
 import json
 import sys


@@ -1,6 +1,7 @@
 """
 ChatML transformation functions for MessageContents
 """
+
 from typing import Optional

 from ..messages import MessageContents, Messages


@@ -1,6 +1,7 @@
 """
 Llama 3.x chat formatting functions for MessageContents
 """
+
 from typing import Optional

 from ..messages import MessageContents, Messages


@@ -1,6 +1,7 @@
 """
 shared functions for format transforms
 """
+
 from axolotl.core.chat.messages import MessageContents, Messages


@@ -1,6 +1,7 @@
 """
 internal message representations of chat messages
 """
+
 import json
 from enum import Enum
 from typing import Any, Callable, List, Optional, Union


@@ -1,6 +1,7 @@
 """
 chat dataset module
 """
+
 import os
 from typing import Callable, Optional, Union


@@ -1,6 +1,7 @@
 """
 This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
 """
+
 from typing import Any, Mapping, Union


@@ -13,9 +13,7 @@
 # limitations under the License.

 # pylint: disable=too-many-lines
-"""
-Builder for the training args and trainer
-"""
+"""Builder for the training args and trainer"""

 import abc
 import importlib
@@ -85,8 +83,8 @@ from axolotl.utils.collators import (
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
-from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
 from axolotl.utils.models import ensure_dtype
+from axolotl.utils.schemas.enums import CustomSupportedOptimizers

 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
@@ -332,9 +330,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs = {}

         if self.cfg.include_tokens_per_second is not None:
-            training_arguments_kwargs[
-                "include_tokens_per_second"
-            ] = self.cfg.include_tokens_per_second
+            training_arguments_kwargs["include_tokens_per_second"] = (
+                self.cfg.include_tokens_per_second
+            )

         if self.cfg.bf16 == "full":
             training_arguments_kwargs["bf16_full_eval"] = True
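This hunk, and the many like it below, apply one mechanical rewrite from the newer black version: keep the dict subscript on a single line and wrap the long right-hand side in parentheses instead of splitting the key across lines. A tiny sketch (key and value are illustrative) showing the two forms are semantically identical:

```python
cfg_value = 0.9  # hypothetical stand-in for a self.cfg attribute

# Old style (black 23.x): the subscript key is split across lines.
old_kwargs = {}
old_kwargs[
    "adam_beta1"
] = cfg_value

# New style (black 25.x): one-line subscript, parenthesized value.
# (Shown multi-line as black emits it when the real RHS is long.)
new_kwargs = {}
new_kwargs["adam_beta1"] = (
    cfg_value
)

assert old_kwargs == new_kwargs  # purely a formatting change
```

Because the change is behavior-preserving, these hunks are pure churn from bumping the black hook in `.pre-commit-config.yaml`.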
@@ -351,13 +349,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["seed"] = self.cfg.seed

         if self.cfg.gradient_checkpointing:
-            training_arguments_kwargs[
-                "gradient_checkpointing"
-            ] = self.cfg.gradient_checkpointing
+            training_arguments_kwargs["gradient_checkpointing"] = (
+                self.cfg.gradient_checkpointing
+            )
         if self.cfg.gradient_checkpointing_kwargs is not None:
-            training_arguments_kwargs[
-                "gradient_checkpointing_kwargs"
-            ] = self.cfg.gradient_checkpointing_kwargs
+            training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
+                self.cfg.gradient_checkpointing_kwargs
+            )
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:
@@ -373,9 +371,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed

         if self.cfg.lr_quadratic_warmup is not None:
-            training_arguments_kwargs[
-                "lr_quadratic_warmup"
-            ] = self.cfg.lr_quadratic_warmup
+            training_arguments_kwargs["lr_quadratic_warmup"] = (
+                self.cfg.lr_quadratic_warmup
+            )

         if self.cfg.adam_beta1:
             training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
@@ -399,28 +397,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors

         if self.cfg.dataloader_pin_memory is not None:
-            training_arguments_kwargs[
-                "dataloader_pin_memory"
-            ] = self.cfg.dataloader_pin_memory
+            training_arguments_kwargs["dataloader_pin_memory"] = (
+                self.cfg.dataloader_pin_memory
+            )
         if self.cfg.dataloader_num_workers is not None:
-            training_arguments_kwargs[
-                "dataloader_num_workers"
-            ] = self.cfg.dataloader_num_workers
+            training_arguments_kwargs["dataloader_num_workers"] = (
+                self.cfg.dataloader_num_workers
+            )
         if self.cfg.dataloader_prefetch_factor is not None:
-            training_arguments_kwargs[
-                "dataloader_prefetch_factor"
-            ] = self.cfg.dataloader_prefetch_factor
+            training_arguments_kwargs["dataloader_prefetch_factor"] = (
+                self.cfg.dataloader_prefetch_factor
+            )
         if self.cfg.dataloader_drop_last is not None:
-            training_arguments_kwargs[
-                "dataloader_drop_last"
-            ] = self.cfg.dataloader_drop_last
+            training_arguments_kwargs["dataloader_drop_last"] = (
+                self.cfg.dataloader_drop_last
+            )
         elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
             training_arguments_kwargs["dataloader_drop_last"] = True

         if self.cfg.remove_unused_columns is not None:
-            training_arguments_kwargs[
-                "remove_unused_columns"
-            ] = self.cfg.remove_unused_columns
+            training_arguments_kwargs["remove_unused_columns"] = (
+                self.cfg.remove_unused_columns
+            )

         if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
             # no eval set, so don't eval
@@ -452,9 +450,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.do_causal_lm_eval:
             training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval

         if self.cfg.metric_for_best_model:
-            training_arguments_kwargs[
-                "metric_for_best_model"
-            ] = self.cfg.metric_for_best_model
+            training_arguments_kwargs["metric_for_best_model"] = (
+                self.cfg.metric_for_best_model
+            )
         if self.cfg.greater_is_better:
             training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
@@ -467,13 +465,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             )
             training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
             if self.cfg.torch_compile_backend:
-                training_arguments_kwargs[
-                    "torch_compile_backend"
-                ] = self.cfg.torch_compile_backend
+                training_arguments_kwargs["torch_compile_backend"] = (
+                    self.cfg.torch_compile_backend
+                )
             if self.cfg.torch_compile_mode:
-                training_arguments_kwargs[
-                    "torch_compile_mode"
-                ] = self.cfg.torch_compile_mode
+                training_arguments_kwargs["torch_compile_mode"] = (
+                    self.cfg.torch_compile_mode
+                )

         # DDP Config
         if self.cfg.ddp_timeout:
@@ -482,32 +480,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.ddp_bucket_cap_mb:
             training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
         if self.cfg.ddp_broadcast_buffers is not None:
-            training_arguments_kwargs[
-                "ddp_broadcast_buffers"
-            ] = self.cfg.ddp_broadcast_buffers
+            training_arguments_kwargs["ddp_broadcast_buffers"] = (
+                self.cfg.ddp_broadcast_buffers
+            )

         # these are all the "standard" kwargs that are def used
         training_arguments_kwargs["max_steps"] = (
             total_num_steps if self.cfg.max_steps else -1
         )
         training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
-        training_arguments_kwargs[
-            "per_device_train_batch_size"
-        ] = self.cfg.micro_batch_size
+        training_arguments_kwargs["per_device_train_batch_size"] = (
+            self.cfg.micro_batch_size
+        )
         if self.cfg.eval_batch_size:
-            training_arguments_kwargs[
-                "per_device_eval_batch_size"
-            ] = self.cfg.eval_batch_size
+            training_arguments_kwargs["per_device_eval_batch_size"] = (
+                self.cfg.eval_batch_size
+            )
         if self.cfg.auto_find_batch_size is not None:
-            training_arguments_kwargs[
-                "auto_find_batch_size"
-            ] = self.cfg.auto_find_batch_size
+            training_arguments_kwargs["auto_find_batch_size"] = (
+                self.cfg.auto_find_batch_size
+            )
-        training_arguments_kwargs[
-            "gradient_accumulation_steps"
-        ] = self.cfg.gradient_accumulation_steps
+        training_arguments_kwargs["gradient_accumulation_steps"] = (
+            self.cfg.gradient_accumulation_steps
+        )
-        training_arguments_kwargs[
-            "eval_accumulation_steps"
-        ] = self.cfg.gradient_accumulation_steps
+        training_arguments_kwargs["eval_accumulation_steps"] = (
+            self.cfg.gradient_accumulation_steps
+        )
         training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
         training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
         training_arguments_kwargs["output_dir"] = self.cfg.output_dir
@@ -554,9 +552,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"
-            training_arguments_kwargs[
-                "alternate_lr_scheduler_type"
-            ] = self.cfg.lr_scheduler
+            training_arguments_kwargs["alternate_lr_scheduler_type"] = (
+                self.cfg.lr_scheduler
+            )
         else:
             training_arguments_kwargs["lr_scheduler_type"] = (
                 self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
@@ -565,9 +563,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
         training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
-        training_arguments_kwargs[
-            "cosine_constant_lr_ratio"
-        ] = self.cfg.cosine_constant_lr_ratio
+        training_arguments_kwargs["cosine_constant_lr_ratio"] = (
+            self.cfg.cosine_constant_lr_ratio
+        )
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
@@ -580,40 +578,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.eval_sample_packing
             )
         if self.cfg.sample_packing_bin_size is not None:
-            training_arguments_kwargs[
-                "sample_packing_bin_size"
-            ] = self.cfg.sample_packing_bin_size
+            training_arguments_kwargs["sample_packing_bin_size"] = (
+                self.cfg.sample_packing_bin_size
+            )
         if self.cfg.sample_packing_group_size is not None:
-            training_arguments_kwargs[
-                "sample_packing_group_size"
-            ] = self.cfg.sample_packing_group_size
+            training_arguments_kwargs["sample_packing_group_size"] = (
+                self.cfg.sample_packing_group_size
+            )
         if self.cfg.sample_packing_eff_est:
-            training_arguments_kwargs[
-                "sample_packing_efficiency"
-            ] = self.cfg.sample_packing_eff_est
+            training_arguments_kwargs["sample_packing_efficiency"] = (
+                self.cfg.sample_packing_eff_est
+            )

         if self.cfg.relora_steps:
             training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
-            training_arguments_kwargs[
-                "relora_warmup_steps"
-            ] = self.cfg.relora_warmup_steps
+            training_arguments_kwargs["relora_warmup_steps"] = (
+                self.cfg.relora_warmup_steps
+            )
             if self.cfg.relora_anneal_steps:
-                training_arguments_kwargs[
-                    "relora_anneal_steps"
-                ] = self.cfg.relora_anneal_steps
+                training_arguments_kwargs["relora_anneal_steps"] = (
+                    self.cfg.relora_anneal_steps
+                )
             if self.cfg.relora_prune_ratio:
-                training_arguments_kwargs[
-                    "relora_prune_ratio"
-                ] = self.cfg.relora_prune_ratio
+                training_arguments_kwargs["relora_prune_ratio"] = (
+                    self.cfg.relora_prune_ratio
+                )

         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
-            training_arguments_kwargs[
-                "lisa_step_interval"
-            ] = self.cfg.lisa_step_interval
+            training_arguments_kwargs["lisa_step_interval"] = (
+                self.cfg.lisa_step_interval
+            )
-            training_arguments_kwargs[
-                "lisa_layers_attribute"
-            ] = self.cfg.lisa_layers_attribute
+            training_arguments_kwargs["lisa_layers_attribute"] = (
+                self.cfg.lisa_layers_attribute
+            )

         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
@@ -627,9 +625,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )

         if self.cfg.neftune_noise_alpha is not None:
-            training_arguments_kwargs[
-                "neftune_noise_alpha"
-            ] = self.cfg.neftune_noise_alpha
+            training_arguments_kwargs["neftune_noise_alpha"] = (
+                self.cfg.neftune_noise_alpha
+            )

         trainer_kwargs = {}
@@ -731,23 +729,23 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             importlib.import_module("torchdistx")

         if self.cfg.optim_target_modules:
-            training_arguments_kwargs[
-                "optim_target_modules"
-            ] = self.cfg.optim_target_modules
+            training_arguments_kwargs["optim_target_modules"] = (
+                self.cfg.optim_target_modules
+            )

         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
         training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
-        training_arguments_kwargs[
-            "loraplus_lr_embedding"
-        ] = self.cfg.loraplus_lr_embedding
+        training_arguments_kwargs["loraplus_lr_embedding"] = (
+            self.cfg.loraplus_lr_embedding
+        )
         training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups

         if self.cfg.accelerator_config:
-            training_arguments_kwargs[
-                "accelerator_config"
-            ] = self.cfg.accelerator_config
+            training_arguments_kwargs["accelerator_config"] = (
+                self.cfg.accelerator_config
+            )

         if self.cfg.kd_ce_alpha is not None:
             training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
@@ -756,13 +754,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.kd_temperature is not None:
             training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
         if self.cfg.kd_zscore_base_temp is not None:
-            training_arguments_kwargs[
-                "kd_zscore_base_temp"
-            ] = self.cfg.kd_zscore_base_temp
+            training_arguments_kwargs["kd_zscore_base_temp"] = (
+                self.cfg.kd_zscore_base_temp
+            )
         if self.cfg.kd_top_k_before_softmax is not None:
-            training_arguments_kwargs[
-                "kd_top_k_before_softmax"
-            ] = self.cfg.kd_top_k_before_softmax
+            training_arguments_kwargs["kd_top_k_before_softmax"] = (
+                self.cfg.kd_top_k_before_softmax
+            )

         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
@@ -972,32 +970,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
         if self.cfg.remove_unused_columns is not None:
-            training_args_kwargs[
-                "remove_unused_columns"
-            ] = self.cfg.remove_unused_columns
+            training_args_kwargs["remove_unused_columns"] = (
+                self.cfg.remove_unused_columns
+            )
         else:
             training_args_kwargs["remove_unused_columns"] = False

         if self.cfg.dataloader_pin_memory is not None:
-            training_args_kwargs[
-                "dataloader_pin_memory"
-            ] = self.cfg.dataloader_pin_memory
+            training_args_kwargs["dataloader_pin_memory"] = (
+                self.cfg.dataloader_pin_memory
+            )
         if self.cfg.dataloader_num_workers is not None:
-            training_args_kwargs[
-                "dataloader_num_workers"
-            ] = self.cfg.dataloader_num_workers
+            training_args_kwargs["dataloader_num_workers"] = (
+                self.cfg.dataloader_num_workers
+            )
         if self.cfg.dataloader_prefetch_factor is not None:
-            training_args_kwargs[
-                "dataloader_prefetch_factor"
-            ] = self.cfg.dataloader_prefetch_factor
+            training_args_kwargs["dataloader_prefetch_factor"] = (
+                self.cfg.dataloader_prefetch_factor
+            )
         if self.cfg.gradient_checkpointing:
-            training_args_kwargs[
-                "gradient_checkpointing"
-            ] = self.cfg.gradient_checkpointing
+            training_args_kwargs["gradient_checkpointing"] = (
+                self.cfg.gradient_checkpointing
+            )
             if self.cfg.gradient_checkpointing_kwargs is not None:
-                training_args_kwargs[
-                    "gradient_checkpointing_kwargs"
-                ] = self.cfg.gradient_checkpointing_kwargs
+                training_args_kwargs["gradient_checkpointing_kwargs"] = (
+                    self.cfg.gradient_checkpointing_kwargs
+                )
             else:
                 training_args_kwargs["gradient_checkpointing_kwargs"] = {
                     "use_reentrant": False
@@ -1071,9 +1069,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.dpo_use_weighting is not None:
             training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
         if self.cfg.dpo_use_logits_to_keep is not None:
-            training_args_kwargs[
-                "use_logits_to_keep"
-            ] = self.cfg.dpo_use_logits_to_keep
+            training_args_kwargs["use_logits_to_keep"] = (
+                self.cfg.dpo_use_logits_to_keep
+            )

         for blocklist_key in blocklist_args_kwargs:
             if blocklist_key in training_args_kwargs:
@@ -1108,9 +1106,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.adapter and self.peft_config:
             dpo_trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:
-            dpo_trainer_kwargs[
-                "precompute_ref_log_probs"
-            ] = self.cfg.precompute_ref_log_probs
+            dpo_trainer_kwargs["precompute_ref_log_probs"] = (
+                self.cfg.precompute_ref_log_probs
+            )
         if self.cfg.rl == "grpo":
             trainer_cls = GRPOStrategy.get_trainer_class()
             trainer_cls_args = [self.model]
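Nearly every hunk in this file is the same mechanical change from the pre-commit/black plugin bump: instead of exploding the subscript of an over-long assignment across lines, the formatter now keeps the subscript on one line and wraps the assigned value in parentheses. A minimal sketch (with a hypothetical config dict standing in for `self.cfg`) shows the two spellings are behaviorally identical:

```python
# Hypothetical stand-ins for self.cfg and the kwargs dict in the diff above.
cfg = {"relora_anneal_steps": 50}
training_arguments_kwargs = {}

# Old formatter output: the subscript is split across lines.
training_arguments_kwargs[
    "relora_anneal_steps"
] = cfg["relora_anneal_steps"]
old_style = dict(training_arguments_kwargs)

# New formatter output: the value is wrapped in parentheses instead.
training_arguments_kwargs["relora_anneal_steps"] = (
    cfg["relora_anneal_steps"]
)

# Both spellings compile to the same assignment; only layout differs.
assert old_style == training_arguments_kwargs
```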

View File

@@ -462,9 +462,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             "pin_memory": self.args.dataloader_pin_memory,
         }
         if self.args.dataloader_prefetch_factor:
-            dataloader_params[
-                "prefetch_factor"
-            ] = self.args.dataloader_prefetch_factor
+            dataloader_params["prefetch_factor"] = (
+                self.args.dataloader_prefetch_factor
+            )

         sampler = self._get_train_sampler()
         if isinstance(sampler, BatchSampler):
@@ -509,9 +509,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             "pin_memory": self.args.dataloader_pin_memory,
         }
         if self.args.dataloader_prefetch_factor:
-            dataloader_params[
-                "prefetch_factor"
-            ] = self.args.dataloader_prefetch_factor
+            dataloader_params["prefetch_factor"] = (
+                self.args.dataloader_prefetch_factor
+            )

         if isinstance(eval_sampler, BatchSampler):
             dataloader_params["batch_sampler"] = eval_sampler

View File

@@ -1,6 +1,7 @@
 """
 DPO Specific Strategy for training
 """
+
 from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer

View File

@@ -1,6 +1,7 @@
 """
 Axolotl specific DPO args
 """
+
 from dataclasses import dataclass

 from trl import DPOConfig

View File

@@ -1,6 +1,7 @@
 """
 DPO trainer for axolotl
 """
+
 import gc
 from functools import wraps
 from typing import Any, Dict, Union

View File

@@ -9,7 +9,7 @@ import logging
 from trl.trainer.grpo_trainer import RewardFunc

 from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
-from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
+from axolotl.utils.schemas.trl import TRLConfig

 LOG = logging.getLogger("axolotl")
@@ -45,9 +45,9 @@ class GRPOStrategy:
             )
         if trl.vllm_gpu_memory_utilization:
-            grpo_args_kwargs[
-                "vllm_gpu_memory_utilization"
-            ] = trl.vllm_gpu_memory_utilization
+            grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
+                trl.vllm_gpu_memory_utilization
+            )
         if trl.vllm_max_model_len:
             grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
@@ -86,9 +86,9 @@ class GRPOStrategy:
     def set_trainer_kwargs(cls, cfg):
         trainer_kwargs = {}
         if cfg.trl and cfg.trl.reward_processing_classes:
-            trainer_kwargs[
-                "reward_processing_classes"
-            ] = cfg.trl.reward_processing_classes
+            trainer_kwargs["reward_processing_classes"] = (
+                cfg.trl.reward_processing_classes
+            )
         return trainer_kwargs

     @classmethod

View File

@@ -1,6 +1,7 @@
 """
 Axolotl Specific Training Args
 """
+
 from dataclasses import dataclass

 from trl import GRPOConfig

View File

@@ -1,6 +1,7 @@
 """
 Axolotl GRPO trainer
 """
+
 from accelerate.utils import is_peft_model
 from accelerate.utils.other import is_compiled_module
 from transformers import PreTrainedModel

View File

@@ -1,6 +1,7 @@
 """
 module for TRL PPO training
 """
+
 import torch
 from tqdm import tqdm
 from trl import PPOTrainer

View File

@@ -1,6 +1,7 @@
 """
 extra axolotl specific training args
 """
+
 from dataclasses import dataclass, field
 from typing import Optional

View File

@@ -8,6 +8,8 @@ from typing import Dict, Optional
 import torch
 from accelerate.logging import get_logger
+from datasets import Dataset
+from transformers.trainer import Trainer

 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
@@ -25,18 +27,18 @@ LOG = get_logger("axolotl.evaluate")
 def evaluate_dataset(
-    trainer, dataset, dataset_type: str, flash_optimum: bool = False
+    trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
 ) -> Optional[Dict[str, float]]:
-    """Helper function to evaluate a single dataset safely.
+    """Helper function to evaluate a single dataset.

     Args:
-        trainer: The trainer instance
-        dataset: Dataset to evaluate
-        dataset_type: Type of dataset ('train' or 'eval')
-        flash_optimum: Whether to use flash optimum
+        trainer: The trainer instance.
+        dataset: Dataset to evaluate.
+        dataset_type: Type of dataset ('train' or 'eval').
+        flash_optimum: Whether to use flash optimum.

     Returns:
-        Dictionary of metrics or None if dataset is None
+        Dictionary of metrics or None if dataset is None.
     """
     if dataset is None:
         return None
@@ -63,17 +65,14 @@ def evaluate_dataset(
 def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
     """
-    Evaluate a model on training and validation datasets
+    Evaluate a model on training and validation datasets.

     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
         dataset_meta: Dataset metadata containing training and evaluation datasets.

     Returns:
-        Tuple containing:
-            - The model (either PeftModel or PreTrainedModel)
-            - The tokenizer
-            - Dictionary of evaluation metrics
+        Dictionary mapping metric names to their values.
     """
     # pylint: disable=duplicate-code
     # Enable expandable segments for cuda allocation to improve VRAM usage

View File

@@ -11,19 +11,17 @@
 # the License.
 """
-module to handle merging the plugins' input arguments with the base configurations.
-this was moved here to prevent circular imports
+Module to handle merging the plugins' input arguments with the base configurations.
+This was moved here to prevent circular imports.
 """
+
 from typing import Any, Dict, List

-from axolotl.utils.config.models.input.v0_4_1 import (
+from axolotl.utils.schemas.config import (
     AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
 )
-from axolotl.utils.config.models.input.v0_4_1 import (
-    AxolotlInputConfig as AxolotlInputConfigBase,
-)
+from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase


 def merge_input_args():

View File

@@ -1,6 +1,7 @@
 """
 Grokfast plugin for Axolotl
 """
+
 import logging

 from transformers.trainer_callback import TrainerCallback

View File

@@ -1,6 +1,7 @@
 """
 config args for grokfast plugin
 """
+
 from typing import Optional

 from pydantic import BaseModel

View File

@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
     """

     kd_trainer: Optional[bool] = None  # whether to use KD trainer
-    kd_ce_alpha: Optional[
-        float
-    ] = None  # loss coefficient for cross-entropy loss during KD
+    kd_ce_alpha: Optional[float] = (
+        None  # loss coefficient for cross-entropy loss during KD
+    )
     kd_alpha: Optional[float] = None  # loss coefficient for KD loss
     kd_temperature: Optional[float] = None  # temperature for sampling during KD
     kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
-    kd_top_k_before_softmax: Optional[
-        bool
-    ] = None  # whether to sample top k before softmax during KD
+    kd_top_k_before_softmax: Optional[bool] = (
+        None  # whether to sample top k before softmax during KD
+    )

View File

@@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin):
         if "cross_entropy" in liger_fn_sig.parameters:
             kwargs["cross_entropy"] = cfg.liger_cross_entropy
         if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
-            kwargs[
-                "fused_linear_cross_entropy"
-            ] = cfg.liger_fused_linear_cross_entropy
+            kwargs["fused_linear_cross_entropy"] = (
+                cfg.liger_fused_linear_cross_entropy
+            )
         if "rms_norm" in liger_fn_sig.parameters:
             kwargs["rms_norm"] = cfg.liger_rms_norm
         if "layer_norm" in liger_fn_sig.parameters:

View File

@@ -1,6 +1,7 @@
 """
 DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
 """
+
 # pylint: disable=duplicate-code

 from typing import List, Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
 """
 Jamba model with LigerFusedLinearCrossEntropyLoss
 """
+
 # pylint: disable=duplicate-code

 from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
 """
 Module for the Plugin for LM Eval Harness
 """
+
 import subprocess  # nosec

 from axolotl.integrations.base import BasePlugin

View File

@@ -1,6 +1,7 @@
 """
 Module for handling lm eval harness input arguments.
 """
+
 from typing import List, Optional

 from pydantic import BaseModel

View File

@@ -1,6 +1,7 @@
 """
 axolotl CLI for running lm_eval tasks
 """
+
 import subprocess  # nosec
 from collections import defaultdict
 from datetime import datetime

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
 Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
 """
+
 # pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code

 import torch

View File

@@ -6,6 +6,7 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
 Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
 """
+
 # pylint: disable=invalid-name

 from typing import Callable

View File

@@ -1,4 +1,5 @@
 """Dequantization utilities for `bitsandbytes` integration."""
+
 # pylint: disable=invalid-name,global-statement

 import ctypes

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
 Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
 """
+
 import torch
 import triton
 import triton.language as tl

View File

@@ -1,6 +1,7 @@
 """
 HF Transformers MambaConfig
 """
+
 from transformers import PretrainedConfig

View File

@@ -1,6 +1,7 @@
 """
 Monkeypatch for Vision Llama for FA2 support
 """
+
 # pylint: disable=duplicate-code

 from typing import Optional, Tuple
@@ -220,10 +221,10 @@ def patch_mllama():
             True
         )
     MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
-    MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
-        "flash_attention_2"
-    ] = MllamaTextCrossFlashAttention2
+    MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = (
+        MllamaTextCrossFlashAttention2
+    )
     # fallback to SDPA
-    MLLAMA_VISION_ATTENTION_CLASSES[
-        "flash_attention_2"
-    ] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
+    MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = (
+        MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
+    )

View File

@@ -1,4 +1,5 @@
 """monkey patches for the dataset fetcher to handle batches of packed indexes"""
+
 # pylint: disable=protected-access

 import torch

View File

@@ -12,7 +12,9 @@ import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+)
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OriginalLlamaDecoderLayer,
 )
@@ -490,9 +492,11 @@ def flashattn_forward(
             # We have disabled _prepare_decoder_attention_mask in LlamaModel
             # the attention_mask should be the same as the key_padding_mask
             key_padding_mask=attention_mask,
-            query_padding_mask=attention_mask[:, -query_states.size(1) :]
-            if attention_mask is not None
-            else None,
+            query_padding_mask=(
+                attention_mask[:, -query_states.size(1) :]
+                if attention_mask is not None
+                else None
+            ),
         )
         output_unpad = flash_attn_varlen_qkvpacked_func(
             qkv_unpad,
@@ -531,9 +535,11 @@ def flashattn_forward(
             value_states,
             kvpacked=True,
             key_padding_mask=attention_mask,
-            query_padding_mask=attention_mask[:, -query_states.size(1) :]
-            if attention_mask is not None
-            else None,
+            query_padding_mask=(
+                attention_mask[:, -query_states.size(1) :]
+                if attention_mask is not None
+                else None
+            ),
         )
         if q_unpad.dtype != kv_unpad.dtype:
             kv_unpad = kv_unpad.to(q_unpad.dtype)

View File

@@ -1,6 +1,7 @@
 """
 expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
 """
+
 from typing import Optional

 import torch

View File

@@ -1,4 +1,5 @@
 """Flash attention monkey patch for mistral model"""
+
 # pylint: disable=duplicate-code

 import logging
@@ -21,7 +22,10 @@ from transformers.models.mistral.modeling_mistral import (
 from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OriginalMistralDecoderLayer,
 )
-from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
+from transformers.models.mistral.modeling_mistral import (
+    apply_rotary_pos_emb,
+    repeat_kv,
+)

 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
@@ -243,9 +247,11 @@ def flashattn_forward(
             # We have disabled _prepare_decoder_attention_mask in LlamaModel
             # the attention_mask should be the same as the key_padding_mask
             key_padding_mask=attention_mask,
-            query_padding_mask=attention_mask[:, -query_states.size(1) :]
-            if attention_mask is not None
-            else None,
+            query_padding_mask=(
+                attention_mask[:, -query_states.size(1) :]
+                if attention_mask is not None
+                else None
+            ),
         )
         output_unpad = flash_attn_varlen_qkvpacked_func(
             qkv_unpad,
@@ -286,9 +292,11 @@ def flashattn_forward(
             value_states,
             kvpacked=True,
             key_padding_mask=attention_mask,
-            query_padding_mask=attention_mask[:, -query_states.size(1) :]
-            if attention_mask is not None
-            else None,
+            query_padding_mask=(
+                attention_mask[:, -query_states.size(1) :]
+                if attention_mask is not None
+                else None
+            ),
         )
         if q_unpad.dtype != kv_unpad.dtype:
             kv_unpad = kv_unpad.to(q_unpad.dtype)

View File

@@ -1,6 +1,7 @@
 """
 Patches to support multipack for mixtral
 """
+
 import torch

View File

@@ -1,4 +1,5 @@
 """Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
+
 import glob
 import json
 import logging
@@ -411,7 +412,10 @@ def merge_and_save(
             if shard_path.endswith(".safetensors"):
                 in_tensors = st.load_file(str(Path(model_src) / shard_path))
             else:
-                in_tensors = torch.load(Path(model_src) / shard_path)
+                in_tensors = torch.load(
+                    Path(model_src) / shard_path,
+                    weights_only=True,  # to prevent arbitrary code execution
+                )
                 if "state_dict" in in_tensors:
                     in_tensors = in_tensors["state_dict"]
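The one behavioral change in this hunk is passing `weights_only=True` to `torch.load`, which restricts unpickling to tensors and primitive containers so a malicious checkpoint cannot execute code through a pickled object's `__reduce__`. A small self-contained round trip (hypothetical tensors and an in-memory buffer standing in for a shard path) sketches the pattern:

```python
import io

import torch

# Save a checkpoint shaped like the ReLoRA shards handled above.
buf = io.BytesIO()
torch.save({"state_dict": {"w": torch.ones(2, 2)}}, buf)
buf.seek(0)

# weights_only=True refuses arbitrary pickled objects; only tensors and
# plain containers are reconstructed -- safer for untrusted files.
in_tensors = torch.load(buf, weights_only=True)
if "state_dict" in in_tensors:
    in_tensors = in_tensors["state_dict"]

print(sorted(in_tensors))  # -> ['w']
```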

View File

@@ -17,7 +17,7 @@
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
 # pylint: disable=duplicate-code
-""" PyTorch StableLM Epoch model. """
+"""PyTorch StableLM Epoch model."""
 import importlib
 import math
 from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
 """
 fix for FSDP optimizer save in trainer w 4.47.0
 """
+
 import inspect
 import logging

View File

@@ -1,6 +1,7 @@
 """
 Shared utils for the monkeypatches
 """
+
 import re
 from typing import Optional, Tuple

View File

@@ -1,6 +1,7 @@
 """
 Fused MLP layer for incrementally improved training efficiency
 """
+
 import torch
 from transformers.models.llama.modeling_llama import LlamaMLP
 from xformers.ops import SwiGLU

View File

@@ -1,6 +1,7 @@
 """
 Prompt strategies loader for alpaca instruction datasets with system prompts
 """
+
 from typing import Generator, Tuple, Union

 from axolotl.prompt_tokenizers import PromptTokenizingStrategy

View File

@@ -13,7 +13,7 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
 from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
 from axolotl.utils.chat_templates import get_chat_template_from_config
-from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
+from axolotl.utils.schemas.datasets import DatasetConfig

 # Configure the logger
 LOG = logging.getLogger("axolotl")

View File

@@ -1,6 +1,7 @@
 """
 Basic completion text
 """
+
 from collections import defaultdict
 from typing import Any, Dict, Generator, Optional, Tuple

View File

@@ -1,4 +1,5 @@
 """Module containing the classes for Context QA Prompt Tokenization Strategies"""
+
 from typing import Tuple

 from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy

View File

@@ -1,6 +1,7 @@
 """
 module for DPO style dataset transform strategies
 """
+
 from functools import partial

 from ..base import load as load_base

View File

@@ -3,7 +3,7 @@ DPO prompt strategies for using tokenizer chat templates.
 """

 from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
-from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
+from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic


 def default(

View File

@@ -33,9 +33,9 @@ def default(
             f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
         )
     else:
-        sample[
-            "prompt"
-        ] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
+        sample["prompt"] = (
+            f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
+        )
     sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
     sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
     return sample
@@ -52,9 +52,9 @@ def argilla_chat(
     """

     def transform_fn(sample):
-        sample[
-            "prompt"
-        ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["prompt"] = (
+            f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        )
         sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
         sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
         return sample
@@ -78,9 +78,9 @@ def icr(
             f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
         )
     else:
-        sample[
-            "prompt"
-        ] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["prompt"] = (
+            f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
+        )
     sample["chosen"] = f"{sample['chosen']}<|im_end|>"
     sample["rejected"] = f"{sample['rejected']}<|im_end|>"
     return sample
@@ -100,9 +100,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
             f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
         )
     else:
-        sample[
-            "prompt"
-        ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["prompt"] = (
+            f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+        )
     sample["chosen"] = f"{sample['chosen']}<|im_end|>"
     sample["rejected"] = f"{sample['rejected']}<|im_end|>"
     return sample
@@ -120,9 +120,9 @@ def prompt_pairs(
             f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
         )
     else:
-        sample[
-            "prompt"
-        ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["prompt"] = (
+            f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        )
     sample["chosen"] = f"{sample['chosen']}<|im_end|>"
     sample["rejected"] = f"{sample['rejected']}<|im_end|>"
     return sample
@@ -142,9 +142,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            sample["prompt"] = (
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
         sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
         sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
         return sample


@@ -34,9 +34,9 @@ def default(
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            sample["prompt"] = (
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
         sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
         sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
         return sample
@@ -53,9 +53,9 @@ def argilla_chat(
     """

     def transform_fn(sample):
-        sample[
-            "prompt"
-        ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["prompt"] = (
+            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        )
         sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
         sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
         return sample
@@ -79,9 +79,9 @@ def icr(
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            sample["prompt"] = (
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
         sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
         sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
         return sample
@@ -101,9 +101,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            sample["prompt"] = (
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
         sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
         sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
         return sample
@@ -121,9 +121,9 @@ def prompt_pairs(
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            sample["prompt"] = (
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
         sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
         sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
         return sample
@@ -143,9 +143,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            sample["prompt"] = (
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
         sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
         sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
         return sample


@@ -1,4 +1,5 @@
 """Module for plain input/output prompt pairs"""
+
 from typing import Generator, Tuple

 from axolotl.prompt_tokenizers import PromptTokenizingStrategy


@@ -1,4 +1,5 @@
 """Module for inspect jinja templates for the variables they use"""
+
 from typing import Dict, Optional, Set, TypedDict, Union

 from jinja2 import Environment, meta, nodes


@@ -1,6 +1,7 @@
 """
 KTO strategies for chatml
 """
+
 # pylint: disable=duplicate-code
@@ -15,9 +16,9 @@ def argilla(
                 f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+            sample["prompt"] = (
+                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+            )
         sample["completion"] = f"{sample['completion']}<|im_end|>"
         return sample
@@ -33,9 +34,9 @@ def argilla_chat(
     """

     def transform_fn(sample):
-        sample[
-            "prompt"
-        ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["prompt"] = (
+            f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        )
         sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
         return sample
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+            sample["prompt"] = (
+                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+            )
         sample["completion"] = f"{sample['completion']}<|im_end|>"
         return sample
@@ -74,9 +75,9 @@ def prompt_pairs(
                 f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            sample["prompt"] = (
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
         sample["completion"] = f"{sample['completion']}<|im_end|>"
         return sample
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            sample["prompt"] = (
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
         sample["completion"] = f"{sample['completion']}<|im_end|>"
         return sample


@@ -1,6 +1,7 @@
 """
 KTO strategies for llama-3 chat template
 """
+
 # pylint: disable=duplicate-code
@@ -15,9 +16,9 @@ def argilla(
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            sample["prompt"] = (
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
         sample["completion"] = f"{sample['completion']}<|eot_id|>"
         return sample
@@ -33,9 +34,9 @@ def argilla_chat(
     """

     def transform_fn(sample):
-        sample[
-            "prompt"
-        ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["prompt"] = (
+            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        )
         sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
         return sample
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            sample["prompt"] = (
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
         sample["completion"] = f"{sample['completion']}<|eot_id|>"
         return sample
@@ -74,9 +75,9 @@ def prompt_pairs(
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            sample["prompt"] = (
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
         sample["completion"] = f"{sample['completion']}<|eot_id|>"
         return sample
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
                 f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
-            sample[
-                "prompt"
-            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            sample["prompt"] = (
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
         sample["completion"] = f"{sample['completion']}<|eot_id|>"
         return sample


@@ -1,6 +1,7 @@
 """
 User-defined KTO strategies
 """
+
 # pylint: disable=duplicate-code


@@ -1,6 +1,7 @@
 """
 Chat dataset wrapping strategy for new internal messages representations
 """
+
 from typing import Any, Callable, Dict, Optional

 from axolotl.core.datasets.chat import TokenizedChatDataset


@@ -9,6 +9,7 @@ this one specifies the system prompt with "### System:".
 Not suited/tested for multiple-turn conversations without further adjustments.
 """
+
 from typing import Generator, Union

 from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy


@@ -1,4 +1,5 @@
 """chatml prompt tokenization strategy for ORPO"""
+
 from typing import Any, Dict, Generator, List, Optional, Tuple

 from pydantic import BaseModel


@@ -1,4 +1,5 @@
 """pretraining prompt strategies"""
+
 from typing import Generator

 from transformers import BatchEncoding


@@ -406,9 +406,7 @@ def handle_untrained_tokens_fix(
 )
-def setup_model_and_trainer(
-    cfg: DictDefault, dataset_meta: TrainDatasetMeta
-) -> tuple[
+def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
     HFRLTrainerBuilder | HFCausalTrainerBuilder,
     PeftModel | PreTrainedModel,
     PreTrainedTokenizer,


@@ -40,6 +40,6 @@ def set_pytorch_cuda_alloc_conf():
     torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
     if torch_major == 2 and torch_minor >= 2:
         if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
-            os.environ[
-                "PYTORCH_CUDA_ALLOC_CONF"
-            ] = "expandable_segments:True,roundup_power2_divisions:16"
+            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
+                "expandable_segments:True,roundup_power2_divisions:16"
+            )

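The allocator hunk above only rewraps an assignment; the guarded logic itself is unchanged. A minimal standalone sketch of that logic (the function name and the version-string parameter are hypothetical, added here for testability):

```python
import os


def set_cuda_alloc_conf_sketch(torch_version: str) -> None:
    """Sketch of the guarded env-var setup shown in the hunk above."""
    parts = torch_version.split(".")
    torch_major, torch_minor = int(parts[0]), int(parts[1])
    # Only torch 2.2+ understands expandable_segments, and any value the
    # user already exported must be left untouched.
    if torch_major == 2 and torch_minor >= 2:
        if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
                "expandable_segments:True,roundup_power2_divisions:16"
            )
```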

@@ -1,4 +1,5 @@
 """Benchmarking and measurement utilities"""
+
 import functools

 import torch


@@ -33,7 +33,6 @@ from trl.models import unwrap_model_for_generation
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.callbacks.perplexity import Perplexity
-from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 from axolotl.utils.distributed import (
     barrier,
     broadcast_dict,
@@ -43,6 +42,7 @@ from axolotl.utils.distributed import (
     is_main_process,
     zero_first,
 )
+from axolotl.utils.schemas.config import AxolotlInputConfig

 if TYPE_CHECKING:
     from axolotl.core.trainer_builder import AxolotlTrainingArguments
@@ -343,9 +343,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
             bench_refs.extend(combined_bench_names[bench_name]["refs"])
             bench_preds.extend(combined_bench_names[bench_name]["preds"])
         if not pd.isna(bench_score):
-            results[
-                f"{bench_split}_bench_accuracy_{bench_name}"
-            ] = bench_score
+            results[f"{bench_split}_bench_accuracy_{bench_name}"] = (
+                bench_score
+            )
             bench_scores.append(bench_score)
         else:
             results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0


@@ -1,4 +1,5 @@
 """MLFlow module for trainer callbacks"""
+
 import logging
 from shutil import copyfile
 from tempfile import NamedTemporaryFile


@@ -1,4 +1,5 @@
 """callback to calculate perplexity as an evaluation metric."""
+
 from typing import Dict, List, Optional

 import torch


@@ -1,6 +1,7 @@
 """
 HF Trainer callback for creating pytorch profiling snapshots
 """
+
 from pathlib import Path
 from pickle import dump  # nosec B403


@@ -2,6 +2,7 @@
 This module provides functionality for selecting chat templates based on user choices.
 These templates are used for formatting messages in a conversation.
 """
+
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Optional


@@ -1,6 +1,7 @@
 """
 shared axolotl collators for multipack, mamba, multimodal
 """
+
 from .batching import (  # noqa: F401
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,


@@ -1,4 +1,5 @@
 """
 basic shared collator constants
 """
+
 IGNORE_INDEX = -100


@@ -1,6 +1,7 @@
 """
 collators for Mamba
 """
+
 from dataclasses import dataclass
 from typing import Dict, Sequence


@@ -12,15 +12,13 @@ from transformers.utils.import_utils import is_torch_npu_available
 from axolotl.integrations.base import PluginManager
 from axolotl.integrations.config import merge_input_args
 from axolotl.utils.bench import log_gpu_memory_usage
-from axolotl.utils.config.models.input.v0_4_1 import (
-    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
-)
-from axolotl.utils.config.models.input.v0_4_1 import (
-    AxolotlInputConfig as AxolotlInputConfigBase,
-)
-from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model_config
+from axolotl.utils.schemas.config import (
+    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
+)
+from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
+from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset

 LOG = logging.getLogger("axolotl")


@@ -1,6 +1,7 @@
 """
 Data processing modules
 """
+
 from axolotl.utils.data.pretraining import (  # noqa: F401
     encode_pretraining,
     wrap_pretraining_dataset,


@@ -1,6 +1,7 @@
 """
 utility helpers for distributed checks
 """
+
 import os
 import pickle  # nosec
 from contextlib import contextmanager


@@ -1,10 +1,13 @@
 """
 utils to get GPU info for the current environment
 """
+
 from accelerate.utils.environment import (
     check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
 )
-from accelerate.utils.environment import get_gpu_info
+from accelerate.utils.environment import (
+    get_gpu_info,
+)


 def check_cuda_p2p_ib_support():


@@ -1,6 +1,7 @@
 """
 module to freeze/unfreeze parameters by name
 """
+
 import logging
 import re
 from typing import Callable, List, Tuple, Union


@@ -1,4 +1,5 @@
 """custom checkpointing utils"""
+
 from axolotl.utils.gradient_checkpointing.unsloth import (
     Unsloth_Offloaded_Gradient_Checkpointer,
 )


@@ -1,6 +1,7 @@
 """
 module to handle loading model on cpu/meta device for FSDP
 """
+
 import os
 import time
 from typing import List, Optional, Type, Union
@@ -45,14 +46,14 @@ def _replace_linear(
     if isinstance(module, torch.nn.Linear) and name not in skip_modules:
         if issubclass(linear_replacement, Linear4bit):
-            model._modules[  # pylint: disable=protected-access
-                name
-            ] = linear_replacement(
-                module.in_features,
-                module.out_features,
-                module.bias is not None,
-                **kwargs,
-            )
+            model._modules[name] = (  # pylint: disable=protected-access
+                linear_replacement(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    **kwargs,
+                )
+            )
         else:
             raise ValueError(
                 f"Unsupported linear replacement: {type(linear_replacement)}"


@@ -741,9 +741,9 @@ class ModelLoader:
                 )
             else:
                 if self.cfg.gptq_disable_exllama is not None:
-                    self.model_config.quantization_config[
-                        "disable_exllama"
-                    ] = self.cfg.gptq_disable_exllama
+                    self.model_config.quantization_config["disable_exllama"] = (
+                        self.cfg.gptq_disable_exllama
+                    )
                 self.model_kwargs["quantization_config"] = GPTQConfig(
                     **self.model_config.quantization_config
                 )

Some files were not shown because too many files have changed in this diff.
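Nearly every hunk above applies the same mechanical rewrite from the updated pre-commit formatter: a subscript assignment whose target was split across the brackets is rejoined onto one line, and the long right-hand value is wrapped in parentheses instead. A minimal sketch of the two equivalent forms (the sample dict here is illustrative, not taken from the diff):

```python
sample = {"question": "Why is the sky blue?"}

# Old formatting: the subscript target is split across lines.
sample[
    "prompt"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
before = sample["prompt"]

# New formatting: the target stays on one line and the value is
# parenthesized instead. Both forms assign the identical string.
sample["prompt"] = (
    f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)

assert sample["prompt"] == before
```

Because only line-wrapping changes, none of these hunks alter runtime behavior.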