Compare commits
10 Commits
5e9fa33f3d
...
docker-bas
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3afc91fba9 | ||
|
|
0689419d25 | ||
|
|
e64c32c0bd | ||
|
|
ec819dde3b | ||
|
|
fdf4bb5087 | ||
|
|
f67d16268c | ||
|
|
684b543aa1 | ||
|
|
5bef19064b | ||
|
|
743ba62bd5 | ||
|
|
f9a7748bd8 |
57
.github/workflows/base.yml
vendored
57
.github/workflows/base.yml
vendored
@@ -22,36 +22,38 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: "121"
|
# - cuda: "121"
|
||||||
cuda_version: 12.1.1
|
# cuda_version: 12.1.1
|
||||||
cudnn_version: 8
|
# cudnn_version: 8
|
||||||
python_version: "3.10"
|
# python_version: "3.10"
|
||||||
pytorch: 2.3.1
|
# pytorch: 2.3.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
- cuda: "121"
|
# from_base_img: ""
|
||||||
cuda_version: 12.1.1
|
# from_base_tag: ""
|
||||||
cudnn_version: 8
|
# - cuda: "121"
|
||||||
python_version: "3.11"
|
# cuda_version: 12.1.1
|
||||||
pytorch: 2.3.1
|
# cudnn_version: 8
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
# python_version: "3.11"
|
||||||
- cuda: "124"
|
# pytorch: 2.3.1
|
||||||
cuda_version: 12.4.1
|
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
cudnn_version: ""
|
# from_base_img: ""
|
||||||
python_version: "3.10"
|
# from_base_tag: ""
|
||||||
pytorch: 2.4.1
|
# - cuda: "124"
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
# cuda_version: 12.4.1
|
||||||
- cuda: "124"
|
# cudnn_version: ""
|
||||||
cuda_version: 12.4.1
|
# python_version: "3.11"
|
||||||
cudnn_version: ""
|
# pytorch: 2.4.1
|
||||||
python_version: "3.11"
|
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
pytorch: 2.4.1
|
# from_base_img: ""
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
# from_base_tag: ""
|
||||||
- cuda: "124"
|
- cuda: "124"
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
from_base_img: nvcr.io/nvidia/pytorch
|
||||||
|
from_base_tag: 24.10-py3
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -61,7 +63,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
images: |
|
images: |
|
||||||
winglian/axolotl-base
|
winglian/axolotl-base
|
||||||
axolotlai/axolotl-base
|
# axolotlai/axolotl-base
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v2
|
||||||
with:
|
with:
|
||||||
@@ -74,7 +76,8 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ./docker/Dockerfile-base
|
file: ./docker/Dockerfile-base
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: true
|
||||||
|
# push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
build-args: |
|
build-args: |
|
||||||
@@ -84,3 +87,5 @@ jobs:
|
|||||||
PYTHON_VERSION=${{ matrix.python_version }}
|
PYTHON_VERSION=${{ matrix.python_version }}
|
||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
||||||
|
BASE_IMAGE=${{ matrix.from_base_img || '' }}
|
||||||
|
BASE_TAG=${{ matrix.from_base_tag || '' }}
|
||||||
|
|||||||
95
.github/workflows/tests.yml
vendored
95
.github/workflows/tests.yml
vendored
@@ -148,63 +148,64 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||||
|
|
||||||
docker-e2e-tests-1st:
|
# docker-e2e-tests-1st:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
# if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# # this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
# runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 90
|
# timeout-minutes: 90
|
||||||
needs: [pre-commit, pytest, pytest-sdist]
|
# needs: [pre-commit, pytest, pytest-sdist]
|
||||||
|
#
|
||||||
strategy:
|
# strategy:
|
||||||
fail-fast: false
|
# fail-fast: false
|
||||||
matrix:
|
# matrix:
|
||||||
include:
|
# include:
|
||||||
- cuda: 124
|
# - cuda: 124
|
||||||
cuda_version: 12.4.1
|
# cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
# python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
# pytorch: 2.4.1
|
||||||
num_gpus: 1
|
# num_gpus: 1
|
||||||
axolotl_extras:
|
# axolotl_extras:
|
||||||
steps:
|
# steps:
|
||||||
- name: Checkout
|
# - name: Checkout
|
||||||
uses: actions/checkout@v4
|
# uses: actions/checkout@v4
|
||||||
- name: Install Python
|
# - name: Install Python
|
||||||
uses: actions/setup-python@v5
|
# uses: actions/setup-python@v5
|
||||||
with:
|
# with:
|
||||||
python-version: "3.10"
|
# python-version: "3.10"
|
||||||
- name: Install Modal
|
# - name: Install Modal
|
||||||
run: |
|
# run: |
|
||||||
python -m pip install --upgrade pip
|
# python -m pip install --upgrade pip
|
||||||
pip install modal==0.63.64 jinja2
|
# pip install modal==0.63.64 jinja2
|
||||||
- name: Update env vars
|
# - name: Update env vars
|
||||||
run: |
|
# run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
# echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
# echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
# echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
# echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
# echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
# echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
# - name: Run tests job on Modal
|
||||||
run: |
|
# run: |
|
||||||
modal run cicd.tests
|
# modal run cicd.tests
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 90
|
timeout-minutes: 90
|
||||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
# needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||||
|
needs: [pre-commit, pytest]
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 121
|
# - cuda: 121
|
||||||
cuda_version: 12.1.1
|
# cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
# python_version: "3.10"
|
||||||
pytorch: 2.3.1
|
# pytorch: 2.3.1
|
||||||
num_gpus: 1
|
# num_gpus: 1
|
||||||
axolotl_extras: mamba-ssm
|
# axolotl_extras: mamba-ssm
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -224,7 +225,7 @@ jobs:
|
|||||||
pip install modal==0.63.64 jinja2
|
pip install modal==0.63.64 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=pr-2139-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
FROM winglian/axolotl-base:{{ BASE_TAG }}
|
||||||
|
|
||||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
|
ARG BASE_IMAGE=axolotlai/axolotl-base
|
||||||
ARG BASE_TAG=main-base
|
ARG BASE_TAG=main-base
|
||||||
FROM axolotlai/axolotl-base:$BASE_TAG
|
FROM $BASE_IMAGE:$BASE_TAG
|
||||||
|
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ARG AXOLOTL_EXTRAS=""
|
ARG AXOLOTL_EXTRAS=""
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ ARG CUDNN_VERSION="8"
|
|||||||
ARG UBUNTU_VERSION="22.04"
|
ARG UBUNTU_VERSION="22.04"
|
||||||
ARG MAX_JOBS=4
|
ARG MAX_JOBS=4
|
||||||
|
|
||||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
ARG BASE_IMAGE=nvidia/cuda
|
||||||
|
ARG DEFAULT_TAG=${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||||
|
ARG BASE_TAG=""
|
||||||
|
FROM ${BASE_IMAGE:-nvidia/cuda}:${BASE_TAG:-${DEFAULT_TAG}} AS base-builder
|
||||||
|
|
||||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
|
|||||||
FROM axolotlai/axolotl:$BASE_TAG
|
FROM axolotlai/axolotl:$BASE_TAG
|
||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
|
|||||||
FROM axolotlai/axolotl:$BASE_TAG
|
FROM axolotlai/axolotl:$BASE_TAG
|
||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.46.3
|
transformers==4.47.0
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
bitsandbytes==0.45.0
|
bitsandbytes==0.45.0
|
||||||
accelerate==1.1.0
|
accelerate==1.2.0
|
||||||
datasets==3.1.0
|
datasets==3.1.0
|
||||||
deepspeed==0.15.4
|
deepspeed==0.15.4
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
@@ -42,7 +42,7 @@ s3fs>=2024.5.0
|
|||||||
gcsfs>=2024.5.0
|
gcsfs>=2024.5.0
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl==0.12.0
|
trl==0.12.1
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|
||||||
|
|||||||
@@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
"compute_capability": gpu_version,
|
"compute_capability": gpu_version,
|
||||||
},
|
},
|
||||||
env_capabilities={
|
env_capabilities={
|
||||||
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
|
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -957,13 +957,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float]) -> None:
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Log `logs` on the various objects watching training, including stored metrics.
|
Log `logs` on the various objects watching training, including stored metrics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logs (`Dict[str, float]`):
|
logs (`Dict[str, float]`):
|
||||||
The values to log.
|
The values to log.
|
||||||
|
start_time (`Optional[float]`):
|
||||||
|
The start of training.
|
||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
@@ -971,7 +973,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
return super().log(logs)
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||||
@@ -1155,6 +1157,18 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the updated to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1163,6 +1177,18 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the updated to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1171,6 +1197,45 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the updated to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# train metrics should have no prefix, eval should have 'eval_'
|
||||||
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
|
# accumulate average metrics from sums and lengths
|
||||||
|
for split in ["chosen", "rejected"]:
|
||||||
|
if f"count/{split}" in self._stored_metrics[train_eval]:
|
||||||
|
count_sum = (
|
||||||
|
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
|
||||||
|
.sum()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
for metric in ["rewards", "logps", "logits"]:
|
||||||
|
logs[f"{prefix}{metric}/{split}"] = (
|
||||||
|
torch.Tensor(
|
||||||
|
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
||||||
|
)
|
||||||
|
.sum()
|
||||||
|
.item()
|
||||||
|
/ count_sum
|
||||||
|
)
|
||||||
|
# delete obsolete metric
|
||||||
|
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
||||||
|
del self._stored_metrics[train_eval][f"count/{split}"]
|
||||||
|
# calculate reward margin
|
||||||
|
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
||||||
|
logs[f"{prefix}rewards/margins"] = (
|
||||||
|
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
||||||
|
)
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1179,6 +1244,18 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the updated to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1187,6 +1264,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the updated to the Trainer.log method
|
||||||
|
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
|
|||||||
207
src/axolotl/monkeypatch/trainer_grad_accum.py
Normal file
207
src/axolotl/monkeypatch/trainer_grad_accum.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""
|
||||||
|
fix for FSDP gradient accumulation
|
||||||
|
see https://github.com/huggingface/transformers/pull/35128
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from transformers import LlamaForCausalLM
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
||||||
|
|
||||||
|
ORIGINAL_CONTEXT_CODE = """
|
||||||
|
with self.compute_loss_context_manager():
|
||||||
|
if self.model_accepts_loss_kwargs:
|
||||||
|
loss = self.compute_loss(model, inputs)
|
||||||
|
else:
|
||||||
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_CONTEXT_CODE = """
|
||||||
|
with self.compute_loss_context_manager():
|
||||||
|
if self.model_accepts_loss_kwargs:
|
||||||
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
|
else:
|
||||||
|
loss = self.compute_loss(model, inputs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
ORIGINAL_LLAMA_FCLM_CODE = """
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_LLAMA_FCLM_CODE = """
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
||||||
|
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_step_code() -> str:
|
||||||
|
training_step = inspect.getsource(
|
||||||
|
Trainer.training_step # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
return training_step
|
||||||
|
|
||||||
|
|
||||||
|
def check_training_step_is_patchable() -> bool:
|
||||||
|
training_step = get_training_step_code()
|
||||||
|
training_step, _ = detab_code(training_step)
|
||||||
|
return ORIGINAL_CONTEXT_CODE in training_step
|
||||||
|
|
||||||
|
|
||||||
|
def patch_training_step_for_ga():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the training loop for gradient accumulation
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_step = get_training_step_code()
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
Trainer._original_training_step = training_step # pylint: disable=protected-access
|
||||||
|
training_step, _ = detab_code(training_step)
|
||||||
|
if ORIGINAL_CONTEXT_CODE not in training_step:
|
||||||
|
return
|
||||||
|
# assert (
|
||||||
|
# ORIGINAL_CONTEXT_CODE in training_step
|
||||||
|
# ), "Original training_step code not found"
|
||||||
|
|
||||||
|
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
|
||||||
|
training_step = training_step.replace(
|
||||||
|
"def training_step(",
|
||||||
|
"def _fixed_training_step(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.trainer
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.trainer):
|
||||||
|
if item in training_step:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.trainer import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching training_step")
|
||||||
|
Trainer.training_step = ( # pylint: disable=protected-access
|
||||||
|
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_forward_code() -> str:
|
||||||
|
forward = inspect.getsource(
|
||||||
|
LlamaForCausalLM.forward # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
return forward
|
||||||
|
|
||||||
|
|
||||||
|
def check_forward_is_patchable() -> bool:
|
||||||
|
forward = get_model_forward_code()
|
||||||
|
forward, _ = detab_code(forward)
|
||||||
|
return ORIGINAL_LLAMA_FCLM_CODE in forward
|
||||||
|
|
||||||
|
|
||||||
|
def patch_forward_for_ga():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the training loop for gradient accumulation
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
forward = get_model_forward_code()
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
||||||
|
forward, _ = detab_code(forward)
|
||||||
|
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
|
||||||
|
return
|
||||||
|
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
|
||||||
|
|
||||||
|
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
|
||||||
|
forward = forward.replace(
|
||||||
|
"def forward(",
|
||||||
|
"def _fixed_forward(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.models.llama.modeling_llama):
|
||||||
|
if item in forward:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.models.llama.modeling_llama import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching forward")
|
||||||
|
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
|
||||||
|
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
@@ -9,10 +9,7 @@ import torch
|
|||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from peft import PeftModelForCausalLM
|
from peft import PeftModelForCausalLM
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||||
LlamaFlashAttention2,
|
|
||||||
LlamaForCausalLM,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||||
|
|
||||||
@@ -55,11 +52,6 @@ def original_apply_o(self, hidden_states):
|
|||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
def get_forward_code() -> str:
|
|
||||||
forward = inspect.getsource(LlamaForCausalLM.forward)
|
|
||||||
return forward
|
|
||||||
|
|
||||||
|
|
||||||
def get_self_attn_code() -> str:
|
def get_self_attn_code() -> str:
|
||||||
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
||||||
return forward
|
return forward
|
||||||
@@ -102,12 +94,22 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
|||||||
|
|
||||||
|
|
||||||
def detab_code(code: str) -> Tuple[str, str]:
|
def detab_code(code: str) -> Tuple[str, str]:
|
||||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
try:
|
||||||
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||||
|
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
||||||
|
except AttributeError:
|
||||||
|
return code, ""
|
||||||
return code, spaces
|
return code, spaces
|
||||||
|
|
||||||
|
|
||||||
|
self_attn_lora_patched = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
def patch_self_attn_lora():
|
def patch_self_attn_lora():
|
||||||
|
global self_attn_lora_patched # pylint: disable=global-statement
|
||||||
|
if self_attn_lora_patched:
|
||||||
|
# prevent patching multiple times
|
||||||
|
return
|
||||||
self_attn_forward = get_self_attn_code()
|
self_attn_forward = get_self_attn_code()
|
||||||
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
|
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
|
||||||
self_attn_forward
|
self_attn_forward
|
||||||
@@ -139,6 +141,7 @@ def patch_self_attn_lora():
|
|||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
self_attn_lora_patched = True
|
||||||
LOG.info("patching unsloth attn lora", main_process_only=True)
|
LOG.info("patching unsloth attn lora", main_process_only=True)
|
||||||
LlamaFlashAttention2.forward = (
|
LlamaFlashAttention2.forward = (
|
||||||
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ def normalize_config(cfg):
|
|||||||
cfg.is_llama_derived_model = (
|
cfg.is_llama_derived_model = (
|
||||||
(
|
(
|
||||||
hasattr(model_config, "model_type")
|
hasattr(model_config, "model_type")
|
||||||
and model_config.model_type == ["llama", "mllama_text_model"]
|
and model_config.model_type in ["llama", "mllama_text_model"]
|
||||||
)
|
)
|
||||||
or cfg.is_llama_derived_model
|
or cfg.is_llama_derived_model
|
||||||
or "llama" in cfg.base_model.lower()
|
or "llama" in cfg.base_model.lower()
|
||||||
|
|||||||
@@ -386,6 +386,15 @@ class ModelLoader:
|
|||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "llama":
|
||||||
|
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||||
|
patch_forward_for_ga,
|
||||||
|
patch_training_step_for_ga,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_forward_for_ga()
|
||||||
|
patch_training_step_for_ga()
|
||||||
|
|
||||||
if self.cfg.sample_packing and self.cfg.s2_attention:
|
if self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Received `sample_packing=true` and `s2_attention=true`; however, \
|
"Received `sample_packing=true` and `s2_attention=true`; however, \
|
||||||
|
|||||||
@@ -2,7 +2,9 @@
|
|||||||
shared pytest fixtures
|
shared pytest fixtures
|
||||||
"""
|
"""
|
||||||
import functools
|
import functools
|
||||||
|
import importlib
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -113,3 +115,30 @@ def temp_dir():
|
|||||||
yield _temp_dir
|
yield _temp_dir
|
||||||
# Clean up the directory after the test
|
# Clean up the directory after the test
|
||||||
shutil.rmtree(_temp_dir)
|
shutil.rmtree(_temp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
|
def cleanup_monkeypatches():
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||||
|
|
||||||
|
original_fa2_forward = LlamaFlashAttention2.forward
|
||||||
|
# monkey patches can happen inside the tests
|
||||||
|
yield
|
||||||
|
# Reset LlamaFlashAttention2 forward
|
||||||
|
LlamaFlashAttention2.forward = original_fa2_forward
|
||||||
|
|
||||||
|
# Reset other known monkeypatches
|
||||||
|
modules_to_reset: list[tuple[str, list[str]]] = [
|
||||||
|
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
|
||||||
|
("transformers.trainer",),
|
||||||
|
("transformers.loss.loss_utils",),
|
||||||
|
]
|
||||||
|
for module_name_tuple in modules_to_reset:
|
||||||
|
module_name = module_name_tuple[0]
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
sys.modules[module_name] = module
|
||||||
|
importlib.reload(sys.modules[module_name])
|
||||||
|
if len(module_name_tuple) > 1:
|
||||||
|
module_globals = module_name_tuple[1]
|
||||||
|
for module_global in module_globals:
|
||||||
|
globals().pop(module_global, None)
|
||||||
|
|||||||
@@ -36,6 +36,9 @@ class TestUnslothQLoRA:
|
|||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"sample_packing": sample_packing,
|
"sample_packing": sample_packing,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
|
"unsloth_lora_mlp": True,
|
||||||
|
"unsloth_lora_qkv": True,
|
||||||
|
"unsloth_lora_o": True,
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 16,
|
"lora_r": 16,
|
||||||
@@ -82,6 +85,9 @@ class TestUnslothQLoRA:
|
|||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
|
"unsloth_lora_mlp": True,
|
||||||
|
"unsloth_lora_qkv": True,
|
||||||
|
"unsloth_lora_o": True,
|
||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
@@ -133,6 +139,9 @@ class TestUnslothQLoRA:
|
|||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
|
"unsloth_lora_mlp": True,
|
||||||
|
"unsloth_lora_qkv": True,
|
||||||
|
"unsloth_lora_o": True,
|
||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
|
|||||||
25
tests/patched/test_llama_trainer_ga.py
Normal file
25
tests/patched/test_llama_trainer_ga.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
""""Test module for checking whether the Hugging Face Transformers is working as expected."""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||||
|
check_forward_is_patchable,
|
||||||
|
check_training_step_is_patchable,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainerGAIntegration(unittest.TestCase):
|
||||||
|
"""llama monkeypatch integration tests."""
|
||||||
|
|
||||||
|
def test_train_step_patchable(self):
|
||||||
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
|
self.assertTrue(
|
||||||
|
check_training_step_is_patchable(),
|
||||||
|
"HF transformers Trainer.training_step has changed and isn't patchable",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_model_forward_patchable(self):
|
||||||
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
|
self.assertTrue(
|
||||||
|
check_forward_is_patchable(),
|
||||||
|
"HF transformers LlamaForCausalLM.forward has changed and isn't patchable",
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user