Compare commits
14 Commits
revert-290
...
quantize-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5a51852af1 | ||
|
|
170322a1f0 | ||
|
|
5f5ae76213 | ||
|
|
a798975b7c | ||
|
|
d23f972602 | ||
|
|
8e41317250 | ||
|
|
9f2bb188a4 | ||
|
|
9dde9e1b71 | ||
|
|
f2474ef941 | ||
|
|
8a4bcacdb2 | ||
|
|
d2c3d5a954 | ||
|
|
36cbe13d18 | ||
|
|
2c408b5c5e | ||
|
|
942005f526 |
16
.coderabbit.yaml
Normal file
16
.coderabbit.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||||
|
language: "en-US"
|
||||||
|
early_access: false
|
||||||
|
reviews:
|
||||||
|
profile: "chill"
|
||||||
|
request_changes_workflow: false
|
||||||
|
high_level_summary: true
|
||||||
|
review_status: true
|
||||||
|
collapse_walkthrough: true
|
||||||
|
poem: false
|
||||||
|
sequence_diagrams: false
|
||||||
|
auto_review:
|
||||||
|
enabled: true
|
||||||
|
drafts: false
|
||||||
|
chat:
|
||||||
|
auto_reply: true
|
||||||
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@@ -87,7 +87,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -98,6 +97,7 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
|
|||||||
53
.github/workflows/tests-nightly.yml
vendored
53
.github/workflows/tests-nightly.yml
vendored
@@ -92,7 +92,7 @@ jobs:
|
|||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 60
|
timeout-minutes: 120
|
||||||
needs: [pre-commit, pytest]
|
needs: [pre-commit, pytest]
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
@@ -106,6 +106,13 @@ jobs:
|
|||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
|
- cuda: 126
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.7.1
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras:
|
||||||
|
nightly_build: "true"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -116,7 +123,7 @@ jobs:
|
|||||||
- name: Install Modal
|
- name: Install Modal
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install modal==0.71.8 jinja2
|
pip install modal==1.0.2 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
@@ -130,3 +137,45 @@ jobs:
|
|||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.e2e_tests
|
modal run cicd.e2e_tests
|
||||||
|
docker-e2e-multigpu-tests:
|
||||||
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
|
runs-on: [self-hosted, modal]
|
||||||
|
timeout-minutes: 120
|
||||||
|
needs: [pre-commit, pytest, docker-e2e-tests]
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- cuda: 126
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.7.1
|
||||||
|
num_gpus: 2
|
||||||
|
axolotl_extras:
|
||||||
|
nightly_build: "true"
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Install Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
- name: Install Modal
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install modal==1.0.2 jinja2
|
||||||
|
- name: Update env vars
|
||||||
|
run: |
|
||||||
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
|
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||||
|
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||||
|
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
|
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
|
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||||
|
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||||
|
- name: Run tests job on Modal
|
||||||
|
run: |
|
||||||
|
modal run cicd.multigpu
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ coverage:
|
|||||||
only_pulls: true
|
only_pulls: true
|
||||||
flags: null
|
flags: null
|
||||||
paths: null
|
paths: null
|
||||||
|
informational: true
|
||||||
patch:
|
patch:
|
||||||
default:
|
default:
|
||||||
# basic
|
# basic
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ huggingface_hub>=0.33.0
|
|||||||
peft==0.16.0
|
peft==0.16.0
|
||||||
transformers==4.53.2
|
transformers==4.53.2
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.8.1
|
accelerate==1.9.0
|
||||||
datasets==4.0.0
|
datasets==4.0.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.19.1
|
trl==0.19.1
|
||||||
@@ -26,7 +26,7 @@ hf_transfer
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==5.23.3
|
gradio==5.23.3
|
||||||
|
|
||||||
modal==0.70.5
|
modal==1.0.2
|
||||||
pydantic==2.10.6
|
pydantic==2.10.6
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ def do_quantize(
|
|||||||
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
||||||
)
|
)
|
||||||
|
|
||||||
model_path = cli_args.get("model_path") or cfg.output_dir
|
model_path = cli_args.get("base_model") or cfg.output_dir
|
||||||
if weight_dtype := cli_args.get("weight_dtype"):
|
if weight_dtype := cli_args.get("weight_dtype"):
|
||||||
weight_dtype = TorchIntDType[weight_dtype]
|
weight_dtype = TorchIntDType[weight_dtype]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
chat dataset module
|
chat dataset module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
@@ -41,14 +40,10 @@ class TokenizedChatDataset(Dataset):
|
|||||||
)
|
)
|
||||||
return ex.tokenized(model_transform)
|
return ex.tokenized(model_transform)
|
||||||
|
|
||||||
process_or_cpu_count: int = (
|
|
||||||
process_count or os.cpu_count() # type: ignore[assignment]
|
|
||||||
)
|
|
||||||
num_proc = min(32, process_or_cpu_count)
|
|
||||||
features = data.features.keys()
|
features = data.features.keys()
|
||||||
tokenized_data = data.map(
|
tokenized_data = data.map(
|
||||||
map_fn,
|
map_fn,
|
||||||
num_proc=num_proc,
|
num_proc=process_count,
|
||||||
keep_in_memory=keep_in_memory,
|
keep_in_memory=keep_in_memory,
|
||||||
remove_columns=features,
|
remove_columns=features,
|
||||||
desc="Tokenizing Chats",
|
desc="Tokenizing Chats",
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ class GRPOStrategy:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_blocklist_args_kwargs(cls) -> list[str]:
|
def get_blocklist_args_kwargs(cls) -> list[str]:
|
||||||
return ["dataset_num_proc", "max_length"]
|
return ["dataset_num_proc", "max_length", "include_tokens_per_second"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
"""Module containing Dataset functionality"""
|
"""Module containing Dataset functionality"""
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
|
|
||||||
@@ -46,7 +44,6 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
|
|
||||||
def process(self, dataset):
|
def process(self, dataset):
|
||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
|
||||||
|
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
@@ -59,13 +56,13 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
):
|
):
|
||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
self.prompt_tokenizer.filter_rows,
|
self.prompt_tokenizer.filter_rows,
|
||||||
num_proc=num_proc,
|
num_proc=self.process_count,
|
||||||
desc="Strategy Filtering Rows",
|
desc="Strategy Filtering Rows",
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
num_proc=num_proc,
|
num_proc=self.process_count,
|
||||||
remove_columns=features,
|
remove_columns=features,
|
||||||
keep_in_memory=self.keep_in_memory,
|
keep_in_memory=self.keep_in_memory,
|
||||||
desc="Tokenizing Prompts",
|
desc="Tokenizing Prompts",
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -19,11 +19,13 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
|
|||||||
from Apple's ML team.
|
from Apple's ML team.
|
||||||
"""
|
"""
|
||||||
import importlib
|
import importlib
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.utils import get_pytorch_version
|
from axolotl.utils import get_pytorch_version
|
||||||
|
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
@@ -32,7 +34,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -84,6 +86,7 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
"""Apply cut cross entropy before model loading if enabled."""
|
"""Apply cut cross entropy before model loading if enabled."""
|
||||||
if cfg.cut_cross_entropy:
|
if cfg.cut_cross_entropy:
|
||||||
self._check_requirements()
|
self._check_requirements()
|
||||||
|
self.patch_llama_like(cfg.model_config_type)
|
||||||
|
|
||||||
from cut_cross_entropy.transformers.patch import cce_patch
|
from cut_cross_entropy.transformers.patch import cce_patch
|
||||||
|
|
||||||
@@ -93,3 +96,48 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
|
|
||||||
# The patch checks model_type internally
|
# The patch checks model_type internally
|
||||||
cce_patch(cfg.model_config_type)
|
cce_patch(cfg.model_config_type)
|
||||||
|
|
||||||
|
def patch_llama_like(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Generic patch for model architectures with causal lm similar to llama
|
||||||
|
"""
|
||||||
|
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
||||||
|
|
||||||
|
def patch_generic(
|
||||||
|
maybe_model, patch_options, model_type: str
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
import cut_cross_entropy.transformers.llama
|
||||||
|
from cut_cross_entropy.transformers.llama import cce_forward
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Dynamically import the module and CausalLM class
|
||||||
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
|
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||||
|
module = __import__(
|
||||||
|
module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]
|
||||||
|
)
|
||||||
|
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
||||||
|
|
||||||
|
cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access
|
||||||
|
patch_options
|
||||||
|
)
|
||||||
|
|
||||||
|
model_cls.forward = cce_forward
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
except (ImportError, AttributeError) as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not import ForCausalLM class for model_type: {model_type}. "
|
||||||
|
f"Error: {str(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if model_type not in PATCH_FNS:
|
||||||
|
LOG.warning_once(
|
||||||
|
"Setting up generic cce patch for model type: %s", model_type
|
||||||
|
)
|
||||||
|
LOG.warning_once(
|
||||||
|
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
|
||||||
|
)
|
||||||
|
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
|
||||||
|
|||||||
@@ -41,3 +41,13 @@ class CutCrossEntropyArgs(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_chunked_cross_entropy_not_set(cls, data):
|
||||||
|
if data.get("chunked_cross_entropy"):
|
||||||
|
raise ValueError(
|
||||||
|
"Cut Cross Entropy does not support chunked cross entropy. "
|
||||||
|
"Please set `chunked_cross_entropy` to `False` or disable Cut Cross Entropy."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ except ImportError:
|
|||||||
TransformersKwargs,
|
TransformersKwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
||||||
|
|
||||||
|
|
||||||
def kldiv_forward_llama_like(
|
def kldiv_forward_llama_like(
|
||||||
self,
|
self,
|
||||||
@@ -97,7 +99,7 @@ def kldiv_forward_llama_like(
|
|||||||
def apply_kernel(model_type):
|
def apply_kernel(model_type):
|
||||||
# Dynamically import the module and attention class
|
# Dynamically import the module and attention class
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
|
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
|
||||||
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
||||||
model_cls.forward = kldiv_forward_llama_like
|
model_cls.forward = kldiv_forward_llama_like
|
||||||
|
|||||||
@@ -18,170 +18,10 @@ Module for the Plugin for LIGER integraton with Axolotl.
|
|||||||
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
||||||
It is designed to be performant, correct, and light-weight.
|
It is designed to be performant, correct, and light-weight.
|
||||||
"""
|
"""
|
||||||
import inspect
|
from .args import LigerArgs
|
||||||
import sys
|
from .plugin import LigerPlugin
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
__all__ = [
|
||||||
from axolotl.utils.logging import get_logger
|
"LigerArgs",
|
||||||
|
"LigerPlugin",
|
||||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
]
|
||||||
from .utils import patch_with_compile_disable
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LigerPlugin(BasePlugin):
|
|
||||||
"""
|
|
||||||
Plugin for LIGER integraton with Axolotl.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_input_args(self):
|
|
||||||
return "axolotl.integrations.liger.LigerArgs"
|
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
|
||||||
if cfg.torch_compile:
|
|
||||||
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
|
|
||||||
import liger_kernel.ops.fused_linear_cross_entropy
|
|
||||||
|
|
||||||
patch_with_compile_disable(
|
|
||||||
liger_kernel.ops.fused_linear_cross_entropy,
|
|
||||||
"fused_linear_cross_entropy_forward",
|
|
||||||
)
|
|
||||||
patch_with_compile_disable(
|
|
||||||
liger_kernel.ops.fused_linear_cross_entropy,
|
|
||||||
"fused_linear_cross_entropy_backward",
|
|
||||||
)
|
|
||||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
|
||||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
|
||||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
|
||||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
||||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
|
||||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
|
||||||
|
|
||||||
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
|
||||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
|
||||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
|
||||||
kwargs = {}
|
|
||||||
if "rope" in liger_fn_sig.parameters:
|
|
||||||
kwargs["rope"] = cfg.liger_rope
|
|
||||||
if "cross_entropy" in liger_fn_sig.parameters:
|
|
||||||
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
|
||||||
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
|
||||||
kwargs["fused_linear_cross_entropy"] = (
|
|
||||||
cfg.liger_fused_linear_cross_entropy
|
|
||||||
)
|
|
||||||
if "rms_norm" in liger_fn_sig.parameters:
|
|
||||||
kwargs["rms_norm"] = cfg.liger_rms_norm
|
|
||||||
if "layer_norm" in liger_fn_sig.parameters:
|
|
||||||
kwargs["layer_norm"] = cfg.liger_layer_norm
|
|
||||||
if "geglu" in liger_fn_sig.parameters:
|
|
||||||
kwargs["geglu"] = cfg.liger_glu_activation
|
|
||||||
elif "swiglu" in liger_fn_sig.parameters:
|
|
||||||
kwargs["swiglu"] = cfg.liger_glu_activation
|
|
||||||
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
|
|
||||||
apply_liger_fn(**kwargs)
|
|
||||||
elif cfg.model_config_type == "jamba":
|
|
||||||
from transformers.models.jamba import modeling_jamba
|
|
||||||
|
|
||||||
from .models.jamba import lce_forward as jamba_lce_forward
|
|
||||||
|
|
||||||
if cfg.liger_rope:
|
|
||||||
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
||||||
if cfg.liger_rms_norm:
|
|
||||||
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
|
||||||
if cfg.liger_glu_activation:
|
|
||||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
|
||||||
if cfg.liger_layer_norm:
|
|
||||||
modeling_jamba.nn.LayerNorm = LigerLayerNorm
|
|
||||||
if cfg.liger_cross_entropy:
|
|
||||||
from transformers.loss.loss_utils import nn
|
|
||||||
|
|
||||||
nn.functional.cross_entropy = liger_cross_entropy
|
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
|
||||||
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
|
||||||
elif cfg.model_config_type == "deepseek_v2":
|
|
||||||
from accelerate import init_empty_weights
|
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
|
|
||||||
with init_empty_weights():
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
|
|
||||||
)
|
|
||||||
modeling_mod = sys.modules[model.__class__.__module__]
|
|
||||||
|
|
||||||
from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
|
|
||||||
|
|
||||||
if cfg.liger_rope:
|
|
||||||
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
|
||||||
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
|
||||||
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
|
|
||||||
if cfg.liger_glu_activation:
|
|
||||||
LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
|
|
||||||
if cfg.liger_rms_norm:
|
|
||||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
|
||||||
if cfg.liger_glu_activation:
|
|
||||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
|
||||||
if cfg.liger_layer_norm:
|
|
||||||
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
|
|
||||||
if cfg.liger_cross_entropy:
|
|
||||||
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
|
||||||
# nn.CrossEntropyLoss in the forward method.
|
|
||||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
|
||||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
|
||||||
elif cfg.model_config_type == "llama4":
|
|
||||||
from axolotl.integrations.liger.models.llama4 import (
|
|
||||||
apply_liger_kernel_to_llama4,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_liger_kernel_to_llama4(
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
glu_activation=cfg.liger_glu_activation,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
layer_norm=cfg.liger_layer_norm,
|
|
||||||
)
|
|
||||||
elif cfg.model_config_type == "qwen3":
|
|
||||||
from axolotl.integrations.liger.models.qwen3 import (
|
|
||||||
apply_liger_kernel_to_qwen3,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_liger_kernel_to_qwen3(
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
glu_activation=cfg.liger_glu_activation,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
layer_norm=cfg.liger_layer_norm,
|
|
||||||
)
|
|
||||||
elif cfg.model_config_type == "qwen3_moe":
|
|
||||||
from axolotl.integrations.liger.models.qwen3_moe import (
|
|
||||||
apply_liger_kernel_to_qwen3_moe,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_liger_kernel_to_qwen3_moe(
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
glu_activation=cfg.liger_glu_activation,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
layer_norm=cfg.liger_layer_norm,
|
|
||||||
)
|
|
||||||
elif cfg.model_config_type == "granitemoe":
|
|
||||||
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
|
||||||
|
|
||||||
apply_liger_kernel_to_granite(
|
|
||||||
rope=cfg.liger_rope,
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
swiglu=cfg.liger_glu_activation,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
LOG.warning(
|
|
||||||
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
|
||||||
)
|
|
||||||
|
|||||||
189
src/axolotl/integrations/liger/models/base.py
Normal file
189
src/axolotl/integrations/liger/models/base.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
Generic FLCE patch for untested models similar to Llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||||
|
from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection
|
||||||
|
from liger_kernel.utils import PEFT_AVAILABLE
|
||||||
|
from peft.utils import ModulesToSaveWrapper
|
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
|
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
||||||
|
|
||||||
|
|
||||||
|
def lce_forward(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
skip_logits: Optional[bool] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
*args,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
kept_hidden_states = hidden_states[:, slice_indices, :]
|
||||||
|
|
||||||
|
shift_labels = kwargs.pop("shift_labels", None)
|
||||||
|
logits = None
|
||||||
|
loss = None
|
||||||
|
|
||||||
|
# if in training mode, don't materialize logits
|
||||||
|
if skip_logits and labels is None and shift_labels is None:
|
||||||
|
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
||||||
|
|
||||||
|
if skip_logits is None:
|
||||||
|
# By default, if in training mode, don't materialize logits
|
||||||
|
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
||||||
|
|
||||||
|
if skip_logits:
|
||||||
|
loss = lce_maybe_trainable_lm_head(
|
||||||
|
self,
|
||||||
|
hidden_states=kept_hidden_states,
|
||||||
|
hidden_size=self.config.hidden_size,
|
||||||
|
labels=labels,
|
||||||
|
shift_labels=shift_labels,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(kept_hidden_states)
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def lce_maybe_trainable_lm_head(
|
||||||
|
self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs
|
||||||
|
):
|
||||||
|
lm_head = self.lm_head
|
||||||
|
|
||||||
|
# Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
|
||||||
|
# i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
|
||||||
|
# from the unwrapped module.
|
||||||
|
# See https://huggingface.co/docs/peft/package_reference/lora for reference.
|
||||||
|
if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
|
||||||
|
lm_head = lm_head.modules_to_save.default
|
||||||
|
|
||||||
|
# If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
|
||||||
|
# reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
|
||||||
|
# so the module entire parameters are summoned and kept in memory during the kernel execution.
|
||||||
|
if isinstance(lm_head, FullyShardedDataParallel):
|
||||||
|
return _FSDPForwardRedirection()(
|
||||||
|
lm_head,
|
||||||
|
_liger_for_causal_lm_loss,
|
||||||
|
lm_head.module,
|
||||||
|
hidden_states,
|
||||||
|
hidden_size,
|
||||||
|
labels,
|
||||||
|
shift_labels,
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# FSDP is not used so we can read the lm_head weights and call the kernel directly
|
||||||
|
return _liger_for_causal_lm_loss(
|
||||||
|
lm_head=self.lm_head,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
labels=labels,
|
||||||
|
shift_labels=shift_labels,
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _liger_for_causal_lm_loss(
|
||||||
|
lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs
|
||||||
|
):
|
||||||
|
return LigerForCausalLMLoss(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
lm_head_weight=lm_head.weight,
|
||||||
|
labels=labels,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
shift_labels=shift_labels,
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_lce_forward(
|
||||||
|
model_type,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
# Dynamically import the module and MLP class
|
||||||
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
|
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||||
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
|
||||||
|
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
||||||
|
|
||||||
|
model_cls.forward = lce_forward
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
except (ImportError, AttributeError) as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not import ForCausalLM class for model_type: {model_type}. "
|
||||||
|
f"Error: {str(e)}"
|
||||||
|
) from e
|
||||||
182
src/axolotl/integrations/liger/plugin.py
Normal file
182
src/axolotl/integrations/liger/plugin.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
"""
|
||||||
|
Liger-Kernel Plugin for Axolotl
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
from .models.base import patch_lce_forward
|
||||||
|
from .utils import patch_with_compile_disable
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LigerPlugin(BasePlugin):
|
||||||
|
"""
|
||||||
|
Plugin for LIGER integraton with Axolotl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_input_args(self):
|
||||||
|
return "axolotl.integrations.liger.LigerArgs"
|
||||||
|
|
||||||
|
def pre_model_load(self, cfg):
|
||||||
|
if cfg.torch_compile:
|
||||||
|
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
|
||||||
|
import liger_kernel.ops.fused_linear_cross_entropy
|
||||||
|
|
||||||
|
patch_with_compile_disable(
|
||||||
|
liger_kernel.ops.fused_linear_cross_entropy,
|
||||||
|
"fused_linear_cross_entropy_forward",
|
||||||
|
)
|
||||||
|
patch_with_compile_disable(
|
||||||
|
liger_kernel.ops.fused_linear_cross_entropy,
|
||||||
|
"fused_linear_cross_entropy_backward",
|
||||||
|
)
|
||||||
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||||
|
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||||
|
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||||
|
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||||
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||||
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||||
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||||
|
|
||||||
|
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||||
|
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||||
|
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||||
|
kwargs = {}
|
||||||
|
if "rope" in liger_fn_sig.parameters:
|
||||||
|
kwargs["rope"] = cfg.liger_rope
|
||||||
|
if "cross_entropy" in liger_fn_sig.parameters:
|
||||||
|
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
||||||
|
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
||||||
|
kwargs["fused_linear_cross_entropy"] = (
|
||||||
|
cfg.liger_fused_linear_cross_entropy
|
||||||
|
)
|
||||||
|
if "rms_norm" in liger_fn_sig.parameters:
|
||||||
|
kwargs["rms_norm"] = cfg.liger_rms_norm
|
||||||
|
if "layer_norm" in liger_fn_sig.parameters:
|
||||||
|
kwargs["layer_norm"] = cfg.liger_layer_norm
|
||||||
|
if "geglu" in liger_fn_sig.parameters:
|
||||||
|
kwargs["geglu"] = cfg.liger_glu_activation
|
||||||
|
elif "swiglu" in liger_fn_sig.parameters:
|
||||||
|
kwargs["swiglu"] = cfg.liger_glu_activation
|
||||||
|
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
|
||||||
|
apply_liger_fn(**kwargs)
|
||||||
|
elif cfg.model_config_type == "jamba":
|
||||||
|
from transformers.models.jamba import modeling_jamba
|
||||||
|
|
||||||
|
from .models.jamba import lce_forward as jamba_lce_forward
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_jamba.nn.LayerNorm = LigerLayerNorm
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
from transformers.loss.loss_utils import nn
|
||||||
|
|
||||||
|
nn.functional.cross_entropy = liger_cross_entropy
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
||||||
|
elif cfg.model_config_type == "deepseek_v2":
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
with init_empty_weights():
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
|
||||||
|
)
|
||||||
|
modeling_mod = sys.modules[model.__class__.__module__]
|
||||||
|
|
||||||
|
from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
||||||
|
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
||||||
|
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
LOG.warning("liger_layer_norm is not supported for DeepseekV2.")
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
||||||
|
# nn.CrossEntropyLoss in the forward method.
|
||||||
|
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||||
|
elif cfg.model_config_type == "llama4":
|
||||||
|
from axolotl.integrations.liger.models.llama4 import (
|
||||||
|
apply_liger_kernel_to_llama4,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_liger_kernel_to_llama4(
|
||||||
|
cross_entropy=cfg.liger_cross_entropy,
|
||||||
|
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||||
|
glu_activation=cfg.liger_glu_activation,
|
||||||
|
rms_norm=cfg.liger_rms_norm,
|
||||||
|
layer_norm=cfg.liger_layer_norm,
|
||||||
|
)
|
||||||
|
elif cfg.model_config_type == "qwen3":
|
||||||
|
from axolotl.integrations.liger.models.qwen3 import (
|
||||||
|
apply_liger_kernel_to_qwen3,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_liger_kernel_to_qwen3(
|
||||||
|
cross_entropy=cfg.liger_cross_entropy,
|
||||||
|
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||||
|
glu_activation=cfg.liger_glu_activation,
|
||||||
|
rms_norm=cfg.liger_rms_norm,
|
||||||
|
layer_norm=cfg.liger_layer_norm,
|
||||||
|
)
|
||||||
|
elif cfg.model_config_type == "qwen3_moe":
|
||||||
|
from axolotl.integrations.liger.models.qwen3_moe import (
|
||||||
|
apply_liger_kernel_to_qwen3_moe,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_liger_kernel_to_qwen3_moe(
|
||||||
|
cross_entropy=cfg.liger_cross_entropy,
|
||||||
|
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||||
|
glu_activation=cfg.liger_glu_activation,
|
||||||
|
rms_norm=cfg.liger_rms_norm,
|
||||||
|
layer_norm=cfg.liger_layer_norm,
|
||||||
|
)
|
||||||
|
elif cfg.model_config_type == "granitemoe":
|
||||||
|
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
||||||
|
|
||||||
|
apply_liger_kernel_to_granite(
|
||||||
|
rope=cfg.liger_rope,
|
||||||
|
cross_entropy=cfg.liger_cross_entropy,
|
||||||
|
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||||
|
rms_norm=cfg.liger_rms_norm,
|
||||||
|
swiglu=cfg.liger_glu_activation,
|
||||||
|
)
|
||||||
|
elif cfg.liger_fused_linear_cross_entropy:
|
||||||
|
try:
|
||||||
|
patch_lce_forward(cfg.model_config_type)
|
||||||
|
LOG.warning_once(
|
||||||
|
f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}"
|
||||||
|
)
|
||||||
|
LOG.warning_once(
|
||||||
|
f"Liger + {cfg.model_config_type} generic FLCE support is experimental and may not work as expected."
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
LOG.warning(
|
||||||
|
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
||||||
|
)
|
||||||
0
src/axolotl/loaders/adapters/__init__.py
Normal file
0
src/axolotl/loaders/adapters/__init__.py
Normal file
@@ -272,7 +272,11 @@ class PatchManager:
|
|||||||
if self.cfg.tiled_mlp:
|
if self.cfg.tiled_mlp:
|
||||||
from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
|
from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
|
||||||
|
|
||||||
patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)
|
patch_tiled_mlp(
|
||||||
|
model_type,
|
||||||
|
use_original_mlp=self.cfg.tiled_mlp_use_original_mlp,
|
||||||
|
cfg_num_shards=self.cfg.tiled_mlp_num_shards,
|
||||||
|
)
|
||||||
|
|
||||||
def _patch_attention(self):
|
def _patch_attention(self):
|
||||||
"""Apply attention-specific patches based on model type."""
|
"""Apply attention-specific patches based on model type."""
|
||||||
|
|||||||
@@ -188,7 +188,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
# Qwen base only has single token, so we need to set the special tokens
|
# Qwen base only has single token, so we need to set the special tokens
|
||||||
if cfg.is_qwen_derived_model:
|
# the following check is for Qwen1 base models
|
||||||
|
if cfg.is_qwen_derived_model and hasattr(tokenizer, "eod_id"):
|
||||||
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
||||||
for attr_name in token_ids:
|
for attr_name in token_ids:
|
||||||
if getattr(tokenizer, attr_name) is None:
|
if getattr(tokenizer, attr_name) is None:
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
|||||||
"loggers": {
|
"loggers": {
|
||||||
"axolotl": {
|
"axolotl": {
|
||||||
"handlers": ["color_console"],
|
"handlers": ["color_console"],
|
||||||
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL),
|
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
|
||||||
"propagate": False,
|
"propagate": False,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from axolotl.kernels.lora import (
|
|||||||
apply_lora_qkv,
|
apply_lora_qkv,
|
||||||
)
|
)
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -150,12 +151,15 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|||||||
|
|
||||||
return MllamaTextSelfAttention
|
return MllamaTextSelfAttention
|
||||||
|
|
||||||
|
if model_type == "llama4":
|
||||||
|
from transformers.models.llama4.modeling_llama4 import Llama4TextAttention
|
||||||
|
|
||||||
|
return Llama4TextAttention
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Dynamically import the module and attention class
|
# Dynamically import the module and attention class
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
model_cls_prefix = "".join(
|
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||||
[part.capitalize() for part in model_type.split("_")]
|
|
||||||
)
|
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
|
||||||
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
|
||||||
|
|
||||||
|
|
||||||
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
|
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
|
||||||
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
|
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
|
||||||
@@ -13,9 +15,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
|
|||||||
try:
|
try:
|
||||||
# Dynamically import the module and MLP class
|
# Dynamically import the module and MLP class
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
model_cls_prefix = "".join(
|
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
|
||||||
[part.capitalize() for part in model_type.split("_")]
|
|
||||||
)
|
|
||||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
||||||
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
|
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
|
||||||
|
|
||||||
@@ -45,11 +45,12 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
|
|||||||
else:
|
else:
|
||||||
num_shards = cfg_num_shards
|
num_shards = cfg_num_shards
|
||||||
|
|
||||||
compute_params = [
|
if not self._compute_params: # pylint: disable=protected-access
|
||||||
self.down_proj.weight,
|
self._compute_params = [ # pylint: disable=protected-access
|
||||||
self.gate_proj.weight,
|
p for p in self.parameters() if p.requires_grad
|
||||||
self.up_proj.weight,
|
]
|
||||||
]
|
|
||||||
|
compute_params = self._compute_params # pylint: disable=protected-access
|
||||||
|
|
||||||
down_res = TiledMLP.apply(
|
down_res = TiledMLP.apply(
|
||||||
mlp_forward,
|
mlp_forward,
|
||||||
@@ -61,6 +62,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
|
|||||||
return down_res
|
return down_res
|
||||||
|
|
||||||
mlp_cls.forward = tiled_mlp_forward
|
mlp_cls.forward = tiled_mlp_forward
|
||||||
|
mlp_cls._compute_params = [] # pylint: disable=protected-access
|
||||||
except (ImportError, AttributeError) as e:
|
except (ImportError, AttributeError) as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Could not import MLP class for model_type: {model_type}. "
|
f"Could not import MLP class for model_type: {model_type}. "
|
||||||
|
|||||||
@@ -798,7 +798,7 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
if is_main_process():
|
if state.is_world_process_zero:
|
||||||
try:
|
try:
|
||||||
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
|
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
|
||||||
with NamedTemporaryFile(
|
with NamedTemporaryFile(
|
||||||
|
|||||||
23
src/axolotl/utils/callbacks/models.py
Normal file
23
src/axolotl/utils/callbacks/models.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""Helper functions for model classes"""
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
|
|
||||||
|
|
||||||
|
def get_causal_lm_model_cls_prefix(model_type: str) -> Tuple[str, str]:
|
||||||
|
if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
|
causal_lm_cls = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
|
||||||
|
causal_lm_cls_prefix = causal_lm_cls
|
||||||
|
for suffix in [
|
||||||
|
"ForCausalLM",
|
||||||
|
"ForConditionalGeneration",
|
||||||
|
"LMHeadModel",
|
||||||
|
"GenerationDecoder",
|
||||||
|
]:
|
||||||
|
causal_lm_cls_prefix = causal_lm_cls_prefix.replace(suffix, "")
|
||||||
|
return causal_lm_cls_prefix, causal_lm_cls
|
||||||
|
causal_lm_cls_prefix = "".join(
|
||||||
|
[part.capitalize() for part in model_type.split("_")]
|
||||||
|
)
|
||||||
|
return causal_lm_cls_prefix, f"{causal_lm_cls_prefix}ForCausalLM"
|
||||||
@@ -148,8 +148,6 @@ def normalize_config(cfg):
|
|||||||
f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations."
|
f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations."
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
|
||||||
|
|
||||||
if not cfg.base_model_config:
|
if not cfg.base_model_config:
|
||||||
cfg.base_model_config = cfg.base_model
|
cfg.base_model_config = cfg.base_model
|
||||||
|
|
||||||
|
|||||||
@@ -410,9 +410,8 @@ def save_preprocessed_dataset(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
|
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
|
||||||
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
|
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
|
||||||
|
num_workers = cfg.dataset_processes
|
||||||
if isinstance(dataset, IterableDataset):
|
if isinstance(dataset, IterableDataset):
|
||||||
num_workers = cfg.dataset_processes
|
|
||||||
|
|
||||||
ds_from_iter = Dataset.from_generator(
|
ds_from_iter = Dataset.from_generator(
|
||||||
functools.partial(_generate_from_iterable_dataset, dataset),
|
functools.partial(_generate_from_iterable_dataset, dataset),
|
||||||
features=dataset.features,
|
features=dataset.features,
|
||||||
@@ -423,10 +422,20 @@ def save_preprocessed_dataset(
|
|||||||
"num_workers": [num_workers] * num_workers,
|
"num_workers": [num_workers] * num_workers,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
ds_from_iter.save_to_disk(
|
||||||
|
str(prepared_ds_path),
|
||||||
|
num_proc=num_workers,
|
||||||
|
max_shard_size=None,
|
||||||
|
num_shards=cfg.num_dataset_shards_to_save,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
os.makedirs(prepared_ds_path, exist_ok=True)
|
os.makedirs(prepared_ds_path, exist_ok=True)
|
||||||
dataset.save_to_disk(str(prepared_ds_path))
|
dataset.save_to_disk(
|
||||||
|
str(prepared_ds_path),
|
||||||
|
num_proc=num_workers,
|
||||||
|
max_shard_size=None,
|
||||||
|
num_shards=cfg.num_dataset_shards_to_save,
|
||||||
|
)
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Pushing merged prepared dataset to Huggingface hub at "
|
"Pushing merged prepared dataset to Huggingface hub at "
|
||||||
@@ -460,13 +469,13 @@ def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset |
|
|||||||
):
|
):
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Loading prepared dataset from disk at {prepared_ds_path}...",
|
f"Loading prepared dataset from disk at {prepared_ds_path}...",
|
||||||
main_process_only=False,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
return load_from_disk(str(prepared_ds_path))
|
return load_from_disk(str(prepared_ds_path))
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Unable to find prepared dataset in {prepared_ds_path}",
|
f"Unable to find prepared dataset in {prepared_ds_path}",
|
||||||
main_process_only=False,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from torchao.quantization.quant_api import (
|
|||||||
UIntXWeightOnlyConfig,
|
UIntXWeightOnlyConfig,
|
||||||
_is_linear,
|
_is_linear,
|
||||||
)
|
)
|
||||||
|
from transformers import TorchAoConfig
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import TorchIntDType
|
from axolotl.utils.schemas.enums import TorchIntDType
|
||||||
|
|
||||||
@@ -149,7 +150,9 @@ def quantize_model_for_ptq(
|
|||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
)
|
)
|
||||||
quantize_(model, linear_ptq_config)
|
quantize_(model, linear_ptq_config)
|
||||||
|
quantization_config = TorchAoConfig(linear_ptq_config)
|
||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
|
quantization_config.include_input_output_embeddings = True
|
||||||
embedding_quantize_config = get_ptq_config(
|
embedding_quantize_config = get_ptq_config(
|
||||||
weight_dtype=weight_dtype,
|
weight_dtype=weight_dtype,
|
||||||
activation_dtype=None,
|
activation_dtype=None,
|
||||||
@@ -160,6 +163,7 @@ def quantize_model_for_ptq(
|
|||||||
embedding_quantize_config,
|
embedding_quantize_config,
|
||||||
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
||||||
)
|
)
|
||||||
|
model.config.quantization_config = quantization_config
|
||||||
|
|
||||||
|
|
||||||
def convert_qat_model_for_ptq(
|
def convert_qat_model_for_ptq(
|
||||||
|
|||||||
@@ -193,6 +193,12 @@ class AxolotlInputConfig(
|
|||||||
json_schema_extra={"description": "Index of shard to use for whole dataset"},
|
json_schema_extra={"description": "Index of shard to use for whole dataset"},
|
||||||
)
|
)
|
||||||
skip_prepare_dataset: bool | None = False
|
skip_prepare_dataset: bool | None = False
|
||||||
|
num_dataset_shards_to_save: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of shards to save the prepared dataset"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
pretraining_dataset: (
|
pretraining_dataset: (
|
||||||
Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
|
Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
|
||||||
@@ -203,11 +209,12 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
dataset_processes: int | None = Field(
|
dataset_processes: int | None = Field(
|
||||||
default=min(
|
default=None,
|
||||||
int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()
|
|
||||||
), # type: ignore[type-var]
|
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set."
|
"description": (
|
||||||
|
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
|
||||||
|
"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
dataset_exact_deduplication: bool | None = Field(
|
dataset_exact_deduplication: bool | None = Field(
|
||||||
@@ -576,6 +583,13 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tiled_mlp_use_original_mlp: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
llama4_linearized_experts: bool | None = None
|
llama4_linearized_experts: bool | None = None
|
||||||
|
|
||||||
deepspeed: str | dict[str, Any] | None = Field(
|
deepspeed: str | dict[str, Any] | None = Field(
|
||||||
@@ -1192,3 +1206,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
data["dataloader_prefetch_factor"] = 256
|
data["dataloader_prefetch_factor"] = 256
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def default_dataset_processes(cls, data):
|
||||||
|
if data.get("dataset_processes") is None:
|
||||||
|
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
|
||||||
|
data["dataset_processes"] = int(axolotl_dataset_processes)
|
||||||
|
elif runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
|
||||||
|
data["dataset_processes"] = int(runpod_cpu_count)
|
||||||
|
else:
|
||||||
|
data["dataset_processes"] = os.cpu_count()
|
||||||
|
|
||||||
|
return data
|
||||||
|
|||||||
@@ -1066,23 +1066,23 @@ class ModelCompatibilityValidationMixin:
|
|||||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def check_offload_grad_checkpointing(self):
|
|
||||||
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
|
|
||||||
LOG.warning(
|
|
||||||
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
|
|
||||||
)
|
|
||||||
self.gradient_checkpointing = "offload"
|
|
||||||
return self
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_gradient_checkpointing_w_offload(self):
|
def check_gradient_checkpointing_w_offload(self):
|
||||||
if self.gradient_checkpointing == "offload":
|
if self.gradient_checkpointing == "offload":
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true`"
|
"`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or `activation_offloading: legacy`"
|
||||||
)
|
)
|
||||||
self.gradient_checkpointing = True
|
self.gradient_checkpointing = True
|
||||||
self.activation_offloading = True
|
if self.adapter and "lora" in self.adapter:
|
||||||
|
LOG.warning(
|
||||||
|
"offloading with CUDA streams is not supported for LoRA adapters, using the `activation_offloading: legacy` implementation."
|
||||||
|
)
|
||||||
|
self.activation_offloading = "legacy"
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
"`offload` uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
|
||||||
|
)
|
||||||
|
self.activation_offloading = True
|
||||||
if self.gradient_checkpointing == "offload_disk":
|
if self.gradient_checkpointing == "offload_disk":
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
|
"`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
|
||||||
@@ -1091,6 +1091,19 @@ class ModelCompatibilityValidationMixin:
|
|||||||
self.activation_offloading = "disk"
|
self.activation_offloading = "disk"
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_activation_offloading_w_lora(self):
|
||||||
|
if (
|
||||||
|
self.activation_offloading is True
|
||||||
|
and self.adapter
|
||||||
|
and "lora" in self.adapter
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
"activation_offloading with CUDA streams is not supported for LoRA adapters. Setting `activation_offloading: legacy`"
|
||||||
|
)
|
||||||
|
self.activation_offloading = "legacy"
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_activation_offloading_wo_gc(self):
|
def check_activation_offloading_wo_gc(self):
|
||||||
if self.activation_offloading and not self.gradient_checkpointing:
|
if self.activation_offloading and not self.gradient_checkpointing:
|
||||||
|
|||||||
91
tests/utils/schemas/validation/test_activation_offloading.py
Normal file
91
tests/utils/schemas/validation/test_activation_offloading.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
"""Test for config validation for activation offloading."""
|
||||||
|
|
||||||
|
from axolotl.utils.config import validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class TestActivationOffloading:
|
||||||
|
"""
|
||||||
|
Test cases for activation offloading schema validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_gc_converts_offload_wo_lora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing="offload",
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading is True
|
||||||
|
|
||||||
|
def test_gc_converts_offload_w_lora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing="offload",
|
||||||
|
adapter="lora",
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
|
def test_gc_converts_offload_w_qlora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing="offload",
|
||||||
|
adapter="qlora",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
|
def test_ac_impl_changes_w_lora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
activation_offloading=True,
|
||||||
|
adapter="lora",
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
|
def test_ac_impl_changes_w_qlora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
activation_offloading=True,
|
||||||
|
adapter="qlora",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
|
def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
activation_offloading=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading is True
|
||||||
Reference in New Issue
Block a user