Compare commits

..

6 Commits

Author SHA1 Message Date
Wing Lian
39ab9626f1 add transformers module to cleanup 2024-12-08 14:52:54 -05:00
Wing Lian
26bd81cec0 re-enable tests w change in patching 2024-12-08 14:52:09 -05:00
Wing Lian
1302e31049 Transformers version flexibility and FSDP optimizer patch (#2155)
* allow flexibility in transformers version for FSDP

* more flexibility with dev versions of 4.47.0.dev0

* add patch for fsdp

* fix typo

* correct fn name

* stray character

* fix patch

* reset Trainer too

* also reset Trainer.training_step

* allow tests/patched to run more than one process on e2e runner

* skip tests/patched in e2e for now since it's run in regular pytest
2024-12-08 14:50:40 -05:00
Wing Lian
be5f554a62 bump autoawq to 0.2.7.post3 (#2150) 2024-12-07 22:24:09 -05:00
Wing Lian
22319182ab fix for auto_map check when using remote code and multipack for models like deepseek (#2151) [skip ci] 2024-12-07 22:23:52 -05:00
Wing Lian
440aab8a6f add --version support to axolotl cli (#2152) [skip ci] 2024-12-07 22:23:33 -05:00
13 changed files with 231 additions and 113 deletions

View File

@@ -22,38 +22,36 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
# - cuda: "121" - cuda: "121"
# cuda_version: 12.1.1 cuda_version: 12.1.1
# cudnn_version: 8 cudnn_version: 8
# python_version: "3.10" python_version: "3.10"
# pytorch: 2.3.1 pytorch: 2.3.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# from_base_img: "" - cuda: "121"
# from_base_tag: "" cuda_version: 12.1.1
# - cuda: "121" cudnn_version: 8
# cuda_version: 12.1.1 python_version: "3.11"
# cudnn_version: 8 pytorch: 2.3.1
# python_version: "3.11" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# pytorch: 2.3.1 - cuda: "124"
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" cuda_version: 12.4.1
# from_base_img: "" cudnn_version: ""
# from_base_tag: "" python_version: "3.10"
# - cuda: "124" pytorch: 2.4.1
# cuda_version: 12.4.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# cudnn_version: "" - cuda: "124"
# python_version: "3.11" cuda_version: 12.4.1
# pytorch: 2.4.1 cudnn_version: ""
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" python_version: "3.11"
# from_base_img: "" pytorch: 2.4.1
# from_base_tag: "" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124" - cuda: "124"
cuda_version: 12.4.1 cuda_version: 12.4.1
cudnn_version: "" cudnn_version: ""
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
from_base_img: nvcr.io/nvidia/pytorch
from_base_tag: 24.10-py3
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -63,7 +61,7 @@ jobs:
with: with:
images: | images: |
winglian/axolotl-base winglian/axolotl-base
# axolotlai/axolotl-base axolotlai/axolotl-base
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
with: with:
@@ -76,8 +74,7 @@ jobs:
with: with:
context: . context: .
file: ./docker/Dockerfile-base file: ./docker/Dockerfile-base
push: true push: ${{ github.event_name != 'pull_request' }}
# push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
build-args: | build-args: |
@@ -87,5 +84,3 @@ jobs:
PYTHON_VERSION=${{ matrix.python_version }} PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }} PYTORCH_VERSION=${{ matrix.pytorch }}
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }} TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
BASE_IMAGE=${{ matrix.from_base_img || '' }}
BASE_TAG=${{ matrix.from_base_tag || '' }}

View File

@@ -148,64 +148,63 @@ jobs:
run: | run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
# docker-e2e-tests-1st: docker-e2e-tests-1st:
# if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }} if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# # this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
# runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]
# timeout-minutes: 90 timeout-minutes: 90
# needs: [pre-commit, pytest, pytest-sdist] needs: [pre-commit, pytest, pytest-sdist]
#
# strategy: strategy:
# fail-fast: false fail-fast: false
# matrix: matrix:
# include: include:
# - cuda: 124 - cuda: 124
# cuda_version: 12.4.1 cuda_version: 12.4.1
# python_version: "3.11" python_version: "3.11"
# pytorch: 2.4.1 pytorch: 2.4.1
# num_gpus: 1 num_gpus: 1
# axolotl_extras: axolotl_extras:
# steps: steps:
# - name: Checkout - name: Checkout
# uses: actions/checkout@v4 uses: actions/checkout@v4
# - name: Install Python - name: Install Python
# uses: actions/setup-python@v5 uses: actions/setup-python@v5
# with: with:
# python-version: "3.10" python-version: "3.10"
# - name: Install Modal - name: Install Modal
# run: | run: |
# python -m pip install --upgrade pip python -m pip install --upgrade pip
# pip install modal==0.63.64 jinja2 pip install modal==0.63.64 jinja2
# - name: Update env vars - name: Update env vars
# run: | run: |
# echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
# echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
# echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
# echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
# echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
# echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
# - name: Run tests job on Modal - name: Run tests job on Modal
# run: | run: |
# modal run cicd.tests modal run cicd.tests
docker-e2e-tests: docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud' if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]
timeout-minutes: 90 timeout-minutes: 90
# needs: [pre-commit, pytest, docker-e2e-tests-1st] needs: [pre-commit, pytest, docker-e2e-tests-1st]
needs: [pre-commit, pytest]
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
# - cuda: 121 - cuda: 121
# cuda_version: 12.1.1 cuda_version: 12.1.1
# python_version: "3.10" python_version: "3.10"
# pytorch: 2.3.1 pytorch: 2.3.1
# num_gpus: 1 num_gpus: 1
# axolotl_extras: mamba-ssm axolotl_extras: mamba-ssm
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
@@ -225,7 +224,7 @@ jobs:
pip install modal==0.63.64 jinja2 pip install modal==0.63.64 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=pr-2139-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV

View File

@@ -1,4 +1,4 @@
FROM winglian/axolotl-base:{{ BASE_TAG }} FROM axolotlai/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}" ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"

View File

@@ -2,6 +2,6 @@
set -e set -e
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/ pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -1,6 +1,5 @@
ARG BASE_IMAGE=axolotlai/axolotl-base
ARG BASE_TAG=main-base ARG BASE_TAG=main-base
FROM $BASE_IMAGE:$BASE_TAG FROM axolotlai/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS="" ARG AXOLOTL_EXTRAS=""

View File

@@ -3,10 +3,7 @@ ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04" ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4 ARG MAX_JOBS=4
ARG BASE_IMAGE=nvidia/cuda FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG DEFAULT_TAG=${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu${UBUNTU_VERSION}
ARG BASE_TAG=""
FROM ${BASE_IMAGE:-nvidia/cuda}:${BASE_TAG:-${DEFAULT_TAG}} AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}" ENV PATH="/root/miniconda3/bin:${PATH}"
@@ -19,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \ && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \ && mkdir /root/.conda \

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.14.0 peft==0.14.0
transformers==4.47.0 transformers>=4.46.3
tokenizers>=0.20.1 tokenizers>=0.20.1
bitsandbytes==0.45.0 bitsandbytes==0.45.0
accelerate==1.2.0 accelerate==1.2.0
@@ -31,7 +31,7 @@ art
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
python-dotenv==1.0.1 python-dotenv==1.0.1
autoawq==0.2.7.post2 autoawq==0.2.7.post3
triton>=2.3.0 triton>=2.3.0
liger-kernel==0.4.2 liger-kernel==0.4.2

View File

@@ -5,6 +5,7 @@ from typing import Optional
import click import click
import axolotl
from axolotl.cli.utils import ( from axolotl.cli.utils import (
add_options_from_config, add_options_from_config,
add_options_from_dataclass, add_options_from_dataclass,
@@ -16,6 +17,7 @@ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@click.group() @click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli(): def cli():
"""Axolotl CLI - Train and fine-tune large language models""" """Axolotl CLI - Train and fine-tune large language models"""

View File

@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch import torch
import transformers import transformers
from datasets import Dataset from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
@@ -973,7 +974,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super().log(logs, start_time)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
def store_metrics( def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1165,9 +1172,13 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
) return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
@@ -1185,9 +1196,13 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
) return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
@@ -1232,9 +1247,13 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
) return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
@@ -1252,9 +1271,13 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items(): for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
) return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
@@ -1266,9 +1289,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method # TODO remove once trl supports the updated to the Trainer.log method
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
logs, start_time return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
) logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):

View File

@@ -0,0 +1,80 @@
"""
fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
ORIGINAL_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
"""
PATCHED_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_fsdp():
"""
monkeypatch for fixing the training loop for fsdp with optimizer save
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -5,8 +5,7 @@ see https://github.com/huggingface/transformers/pull/35128
import inspect import inspect
import logging import logging
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM, Trainer
from transformers.trainer import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code from axolotl.monkeypatch.unsloth_ import detab_code

View File

@@ -380,6 +380,13 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg) plugin_manager.pre_model_load(self.cfg)
if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp,
)
patch_training_loop_for_fsdp()
if self.cfg.gradient_checkpointing == "unsloth": if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
@@ -406,10 +413,14 @@ class ModelLoader:
and self.cfg.flash_attention and self.cfg.flash_attention
and self.cfg.sample_packing and self.cfg.sample_packing
): ):
has_remote_code = ( if "auto_map" in self.model_config:
"auto_map" in self.model_config try:
and "AutoModelForCausalLM" in self.model_config["auto_map"] auto_map_config = self.model_config["auto_map"]
) except TypeError:
auto_map_config = self.model_config.auto_map
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False: if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code has_remote_code = self.cfg.trust_remote_code

View File

@@ -119,18 +119,28 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches(): def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import LlamaFlashAttention2 from transformers.models.llama.modeling_llama import LlamaFlashAttention2
original_fa2_forward = LlamaFlashAttention2.forward original_fa2_forward = LlamaFlashAttention2.forward
original_trainer_inner_training_loop = (
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
# monkey patches can happen inside the tests # monkey patches can happen inside the tests
yield yield
# Reset LlamaFlashAttention2 forward # Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward LlamaFlashAttention2.forward = original_fa2_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
# Reset other known monkeypatches # Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [ modules_to_reset: list[tuple[str, list[str]]] = [
("transformers",),
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]), ("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.trainer",), ("transformers.trainer", ["Trainer"]),
("transformers.loss.loss_utils",), ("transformers.loss.loss_utils",),
] ]
for module_name_tuple in modules_to_reset: for module_name_tuple in modules_to_reset: