Compare commits
1 Commits
fa3-hopper
...
wait-distr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
459f407e69 |
11
.github/workflows/base.yml
vendored
11
.github/workflows/base.yml
vendored
@@ -47,18 +47,11 @@ jobs:
|
||||
pytorch: 2.7.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
suffix: "-hopper"
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -94,7 +87,7 @@ jobs:
|
||||
context: .
|
||||
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
build-args: |
|
||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||
|
||||
10
.github/workflows/main.yml
vendored
10
.github/workflows/main.yml
vendored
@@ -31,11 +31,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -99,11 +94,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
11
.github/workflows/multi-gpu-e2e.yml
vendored
11
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -32,25 +32,21 @@ jobs:
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
suffix: "-hopper"
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
@@ -72,6 +68,7 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
|
||||
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -295,7 +295,6 @@ jobs:
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
# Run this job first as a gate for running the remainder of the test matrix
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
@@ -342,8 +341,6 @@ jobs:
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
# Only run the remainder of the matrix if the first e2e check passed;
|
||||
# this is to save on wasted compute costs for known failures that get caught in the first run
|
||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||
|
||||
strategy:
|
||||
@@ -368,12 +365,6 @@ jobs:
|
||||
pytorch: 2.7.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -32,11 +32,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
fi
|
||||
|
||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
|
||||
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||
pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||
fi
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=90 * 60,
|
||||
cpu=16.0,
|
||||
cpu=8.0,
|
||||
memory=131072 * N_GPUS,
|
||||
volumes=VOLUME_CONFIG,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG CUDA_VERSION="12.4.1"
|
||||
ARG CUDNN_VERSION=""
|
||||
ARG CUDA_VERSION="11.8.0"
|
||||
ARG CUDNN_VERSION="8"
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
|
||||
@@ -7,16 +7,16 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION A
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="2.5.1"
|
||||
ARG CUDA="124"
|
||||
ARG PYTHON_VERSION="3.10"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG CUDA="118"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
@@ -38,10 +38,6 @@ RUN git lfs install --skip-repo && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
|
||||
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||
pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
|
||||
elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
|
||||
pip3 install flash-attn==2.7.4.post1; \
|
||||
fi
|
||||
|
||||
@@ -633,9 +633,7 @@ weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
adam_beta2:
|
||||
adam_beta3: # only used for CAME Optimizer
|
||||
adam_epsilon:
|
||||
adam_epsilon2: # only used for CAME Optimizer
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
|
||||
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
|
||||
format them.
|
||||
|
||||
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
|
||||
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
|
||||
format):
|
||||
|
||||
```json
|
||||
@@ -120,12 +120,6 @@ axolotl train my_training.yml
|
||||
|
||||
## Common Tasks {#sec-common-tasks}
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
The same yaml file is used for training, inference, and merging.
|
||||
|
||||
:::
|
||||
|
||||
### Testing Your Model {#sec-testing}
|
||||
|
||||
After training, test your model:
|
||||
@@ -134,16 +128,6 @@ After training, test your model:
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
```
|
||||
|
||||
More details can be found in [Inference](inference.qmd).
|
||||
|
||||
### Using a UI {#sec-ui}
|
||||
|
||||
Launch a Gradio interface:
|
||||
|
||||
```bash
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
### Preprocessing Data {#sec-preprocessing}
|
||||
|
||||
For large datasets, preprocess first:
|
||||
@@ -152,22 +136,14 @@ For large datasets, preprocess first:
|
||||
axolotl preprocess my_training.yml
|
||||
```
|
||||
|
||||
Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.
|
||||
### Using a UI {#sec-ui}
|
||||
|
||||
More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
|
||||
|
||||
### Merging LoRA weights {#sec-merging-lora}
|
||||
|
||||
To merge the LoRA weights back into the base model, run:
|
||||
Launch a Gradio interface:
|
||||
|
||||
```bash
|
||||
axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
The merged model will be saved in the `{output_dir}/merged` directory.
|
||||
|
||||
More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
|
||||
|
||||
## Next Steps {#sec-next-steps}
|
||||
|
||||
Now that you have the basics, you might want to:
|
||||
@@ -180,7 +156,6 @@ Now that you have the basics, you might want to:
|
||||
Check our other guides for details on these topics:
|
||||
|
||||
- [Configuration Guide](config.qmd) - Full configuration options
|
||||
- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources
|
||||
- [Dataset Formats](dataset-formats) - Working with different data formats
|
||||
- [Multi-GPU Training](multi-gpu.qmd)
|
||||
- [Multi-Node Training](multi-node.qmd)
|
||||
|
||||
@@ -387,12 +387,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
||||
if self.cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
|
||||
if self.cfg.adam_beta3:
|
||||
training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3
|
||||
if self.cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
|
||||
if self.cfg.adam_epsilon2:
|
||||
training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2
|
||||
if self.cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
|
||||
|
||||
@@ -717,7 +713,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
|
||||
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
|
||||
beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
|
||||
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
|
||||
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
|
||||
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
|
||||
adam_kwargs["betas"] = (beta1, beta2, beta3)
|
||||
@@ -1174,8 +1170,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.eval_dataset:
|
||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||
if self.cfg.adapter and self.peft_config:
|
||||
if self.cfg.rl is not RLType.GRPO:
|
||||
trainer_kwargs["peft_config"] = self.peft_config
|
||||
trainer_kwargs["peft_config"] = self.peft_config
|
||||
if self.cfg.precompute_ref_log_probs is not None:
|
||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||
self.cfg.precompute_ref_log_probs
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
||||
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
@@ -13,7 +14,7 @@ from accelerate.utils import (
|
||||
broadcast_object_list,
|
||||
gather,
|
||||
gather_object,
|
||||
is_peft_available,
|
||||
is_peft_model,
|
||||
)
|
||||
from datasets import Dataset, IterableDataset
|
||||
from torch import nn
|
||||
@@ -29,13 +30,15 @@ from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from transformers.utils import is_peft_available
|
||||
from trl import GRPOTrainer
|
||||
from trl.data_utils import (
|
||||
apply_chat_template,
|
||||
is_conversational,
|
||||
maybe_apply_chat_template,
|
||||
)
|
||||
from trl.extras.profiling import profiling_context
|
||||
from trl.extras.profiling import profiling_context, profiling_decorator
|
||||
from trl.import_utils import is_deepspeed_available
|
||||
from trl.models import unwrap_model_for_generation
|
||||
from trl.trainer.grpo_config import GRPOConfig
|
||||
from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
||||
@@ -49,12 +52,62 @@ if is_peft_available():
|
||||
# pylint: disable=unused-import
|
||||
from peft import PeftConfig
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
@profiling_decorator
|
||||
def _move_model_to_vllm(self):
|
||||
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
||||
gather_if_zero3 = (
|
||||
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
|
||||
)
|
||||
|
||||
if is_peft_model(self.model):
|
||||
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
|
||||
# adapters in a sharded manner is not supported.
|
||||
with gather_if_zero3(list(self.model.parameters())):
|
||||
self.model.merge_adapter()
|
||||
|
||||
# Update vLLM weights while parameters are gathered
|
||||
for name, param in self.model.named_parameters():
|
||||
# When using PEFT, we need to recover the original parameter name and discard some parameters
|
||||
name = (
|
||||
name.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", "")
|
||||
)
|
||||
if self.model.prefix in name:
|
||||
continue
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
if "original_module" in name:
|
||||
continue
|
||||
name = name.replace("modules_to_save.default.", "")
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
|
||||
# Unmerge adapters while parameters are still gathered
|
||||
self.model.unmerge_adapter()
|
||||
# Parameters will automatically be repartitioned when exiting the context
|
||||
else:
|
||||
# For non-PEFT models, simply gather and update each parameter individually.
|
||||
for name, param in self.model.named_parameters():
|
||||
with gather_if_zero3([param]):
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.update_named_param(name, param.data)
|
||||
|
||||
# Reset cache on main process
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.reset_prefix_cache()
|
||||
|
||||
|
||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||
|
||||
@@ -227,19 +227,6 @@ class AxolotlTrainingMixins:
|
||||
},
|
||||
)
|
||||
|
||||
adam_beta3: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
adam_epsilon2: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
|
||||
@@ -20,15 +20,25 @@ from cut_cross_entropy.transformers.utils import (
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.cohere.modeling_cohere import (
|
||||
_CONFIG_FOR_DOC,
|
||||
COHERE_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
|
||||
@@ -17,15 +17,25 @@ from cut_cross_entropy.transformers.utils import (
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
_CONFIG_FOR_DOC,
|
||||
GEMMA_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
|
||||
@@ -20,11 +20,15 @@ from torch import nn
|
||||
from transformers.cache_utils import Cache, HybridCache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
_CONFIG_FOR_DOC,
|
||||
GEMMA3_INPUTS_DOCSTRING,
|
||||
Gemma3CausalLMOutputWithPast,
|
||||
logger,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
@@ -34,6 +38,10 @@ _PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -162,6 +170,10 @@ def cce_forward(
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
|
||||
@@ -19,9 +19,15 @@ from transformers.modeling_outputs import (
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
_CONFIG_FOR_DOC,
|
||||
LLAMA_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
@@ -30,6 +36,10 @@ _PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
|
||||
@@ -16,12 +16,22 @@ from torch import nn
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.llama4.modeling_llama4 import (
|
||||
_CONFIG_FOR_DOC,
|
||||
LLAMA4_INPUTS_DOCSTRING,
|
||||
Llama4CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
@@ -150,6 +160,9 @@ def cce_forward(
|
||||
)
|
||||
|
||||
|
||||
@replace_return_docstrings(
|
||||
output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None, # type: ignore
|
||||
|
||||
@@ -19,11 +19,15 @@ from transformers.models.mistral3.modeling_mistral3 import (
|
||||
Mistral3CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
_CONFIG_FOR_DOC,
|
||||
MISTRAL_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
@@ -31,6 +35,10 @@ _PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
|
||||
@@ -13,10 +13,16 @@ from cut_cross_entropy.transformers.utils import (
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
||||
_CONFIG_FOR_DOC,
|
||||
QWEN2MOE_INPUTS_DOCSTRING,
|
||||
MoeCausalLMOutputWithPast,
|
||||
MoeModelOutputWithPast,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
@@ -25,6 +31,10 @@ _PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
|
||||
@@ -14,12 +14,22 @@ from cut_cross_entropy.transformers.utils import (
|
||||
)
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||
_CONFIG_FOR_DOC,
|
||||
QWEN2_VL_INPUTS_DOCSTRING,
|
||||
Qwen2VLCausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
|
||||
@@ -12,13 +12,20 @@ from cut_cross_entropy.transformers.utils import (
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
||||
_CONFIG_FOR_DOC,
|
||||
QWEN3_MOE_INPUTS_DOCSTRING,
|
||||
KwargsForCausalLM,
|
||||
MoeCausalLMOutputWithPast,
|
||||
MoeModelOutputWithPast,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from transformers.utils.generic import can_return_tuple
|
||||
|
||||
@@ -27,6 +34,10 @@ _PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
@can_return_tuple
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
|
||||
@@ -14,6 +14,10 @@ from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
|
||||
# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
|
||||
# @replace_return_docstrings(
|
||||
# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
# )
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
||||
@@ -13,11 +13,21 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||
from transformers.models.jamba.modeling_jamba import (
|
||||
_CONFIG_FOR_DOC,
|
||||
JAMBA_INPUTS_DOCSTRING,
|
||||
HybridMambaAttentionDynamicCache,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
||||
@@ -7,16 +7,24 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
_CONFIG_FOR_DOC,
|
||||
GEMMA3_INPUTS_DOCSTRING,
|
||||
Gemma3CausalLMOutputWithPast,
|
||||
logger,
|
||||
)
|
||||
from transformers.utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
def new_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
||||
@@ -289,16 +289,18 @@ def save_trained_model(
|
||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
else:
|
||||
if cfg.local_rank == 0:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
trainer.model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
trainer.model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
|
||||
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
||||
# TODO: add integration support so this can be implemented completely within the plugin
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""MLFlow module for trainer callbacks"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -17,11 +16,6 @@ if TYPE_CHECKING:
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
def should_log_artifacts() -> bool:
|
||||
truths = ["TRUE", "1", "YES"]
|
||||
return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
||||
# pylint: disable=duplicate-code
|
||||
"""Callback to save axolotl config to mlflow"""
|
||||
@@ -38,18 +32,13 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
if should_log_artifacts():
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
copyfile(self.axolotl_config_path, temp_file.name)
|
||||
mlflow.log_artifact(temp_file.name, artifact_path="")
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the MLflow artifacts."
|
||||
)
|
||||
else:
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
copyfile(self.axolotl_config_path, temp_file.name)
|
||||
mlflow.log_artifact(temp_file.name, artifact_path="")
|
||||
LOG.info(
|
||||
"Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
|
||||
"The Axolotl config has been saved to the MLflow artifacts."
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
|
||||
|
||||
@@ -72,7 +72,6 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||
data_set = data_set.map(
|
||||
ds_transform_fn,
|
||||
desc="Mapping RL Dataset",
|
||||
num_proc=cfg.dataset_processes,
|
||||
**map_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -484,7 +484,7 @@ def get_dataset_wrapper(
|
||||
}
|
||||
|
||||
LOG.info(
|
||||
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
||||
f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
@@ -629,49 +629,6 @@ class ModelLoader:
|
||||
)
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
use_fa3 = False
|
||||
if self.cfg.use_flash_attention_3 is True:
|
||||
use_fa3 = True
|
||||
elif self.cfg.use_flash_attention_3 == "auto":
|
||||
if torch.cuda.get_device_capability() >= (9, 0):
|
||||
# FA3 is only available on Hopper GPUs and newer
|
||||
use_fa3 = True
|
||||
if not importlib.util.find_spec("flash_attn_interface"):
|
||||
use_fa3 = False
|
||||
if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
|
||||
# this can happen when use_flash_attention_3 is explicity set to True
|
||||
# and flash_attn_interface is not installed
|
||||
raise ModuleNotFoundError(
|
||||
"Please install the flash_attn_interface library to use Flash Attention 3.x"
|
||||
)
|
||||
if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
|
||||
from flash_attn_interface import (
|
||||
flash_attn_varlen_func as flash_attn_varlen_func_v3,
|
||||
)
|
||||
|
||||
def flash_attn_func_v3_wrapper(*args, **kwargs):
|
||||
kwargs.pop("dropout_p", None)
|
||||
if "softmax_scale" in kwargs and len(args) >= 4:
|
||||
# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
|
||||
args = (*args[:3],) + args[4:]
|
||||
return flash_attn_func_v3(*args, **kwargs)[0]
|
||||
|
||||
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
|
||||
kwargs.pop("dropout_p", None)
|
||||
if "softmax_scale" in kwargs and len(args) >= 4:
|
||||
# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
|
||||
args = (*args[:3],) + args[4:]
|
||||
return flash_attn_varlen_func_v3(*args, **kwargs)[0]
|
||||
|
||||
transformers.modeling_flash_attention_utils.flash_attn_func = (
|
||||
flash_attn_func_v3_wrapper
|
||||
)
|
||||
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
|
||||
flash_attn_varlen_func_v3_wrapper
|
||||
)
|
||||
LOG.info("Switched to Flash Attention v3")
|
||||
|
||||
self.patch_attention()
|
||||
|
||||
if self.cfg.sample_packing and self.cfg.s2_attention:
|
||||
@@ -742,7 +699,6 @@ class ModelLoader:
|
||||
|
||||
patch_mllama()
|
||||
|
||||
# TODO deprecate soon
|
||||
if self.model_config.model_type == "btlm":
|
||||
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
|
||||
replace_btlm_attn_with_flash_attn,
|
||||
@@ -750,7 +706,6 @@ class ModelLoader:
|
||||
|
||||
replace_btlm_attn_with_flash_attn(self.cfg.base_model)
|
||||
|
||||
# TODO deprecate soon
|
||||
if (
|
||||
self.model_config.model_type == "stablelm_epoch"
|
||||
and self.cfg.sample_packing
|
||||
|
||||
@@ -233,7 +233,6 @@ class AxolotlInputConfig(
|
||||
flash_attn_fuse_qkv: bool | None = None
|
||||
flash_attn_fuse_mlp: bool | None = None
|
||||
flash_optimum: bool | None = None
|
||||
use_flash_attention_3: Literal["auto"] | bool | None = None
|
||||
|
||||
eager_attention: bool | None = None
|
||||
|
||||
|
||||
@@ -421,7 +421,6 @@ def temp_dir():
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def cleanup_monkeypatches():
|
||||
import transformers.modeling_flash_attention_utils
|
||||
from transformers import Trainer
|
||||
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
|
||||
LlamaAttention,
|
||||
@@ -435,19 +434,6 @@ def cleanup_monkeypatches():
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
original_trainer_training_step = Trainer.training_step
|
||||
original_fa_func = None
|
||||
original_fa_varlen_func = None
|
||||
if (
|
||||
importlib.util.find_spec("flash_attn")
|
||||
and hasattr(transformers.modeling_flash_attention_utils, "flash_attn_func")
|
||||
and hasattr(
|
||||
transformers.modeling_flash_attention_utils, "flash_attn_varlen_func"
|
||||
)
|
||||
):
|
||||
original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
|
||||
original_fa_varlen_func = (
|
||||
transformers.modeling_flash_attention_utils.flash_attn_varlen_func
|
||||
)
|
||||
# monkey patches can happen inside the tests
|
||||
yield
|
||||
# Reset LlamaFlashAttention2 forward
|
||||
@@ -458,11 +444,6 @@ def cleanup_monkeypatches():
|
||||
original_trainer_inner_training_loop
|
||||
)
|
||||
Trainer.training_step = original_trainer_training_step
|
||||
if original_fa_func:
|
||||
transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
|
||||
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
|
||||
original_fa_varlen_func
|
||||
)
|
||||
|
||||
# Reset other known monkeypatches
|
||||
modules_to_reset: list[tuple[str, list[str]]] = [
|
||||
@@ -477,7 +458,6 @@ def cleanup_monkeypatches():
|
||||
("transformers.trainer",),
|
||||
("transformers", ["Trainer"]),
|
||||
("transformers.loss.loss_utils",),
|
||||
("transformers.modeling_flash_attention_utils",),
|
||||
]
|
||||
for module_name_tuple in modules_to_reset:
|
||||
module_name = module_name_tuple[0]
|
||||
|
||||
@@ -166,6 +166,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"""
|
||||
)
|
||||
|
||||
@pytest.mark.skip(reason="flaky test")
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
@@ -230,6 +231,8 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
||||
# "VLLM_USE_V1": "0",
|
||||
}
|
||||
vllm_process = start_vllm(
|
||||
cfg.base_model,
|
||||
@@ -263,6 +266,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
finally:
|
||||
recursive_kill(vllm_process)
|
||||
|
||||
@pytest.mark.skip(reason="flaky test")
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
@@ -321,6 +325,8 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
||||
# "VLLM_USE_V1": "0",
|
||||
}
|
||||
vllm_process = start_vllm(
|
||||
cfg.base_model,
|
||||
|
||||
@@ -101,13 +101,7 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"use_flash_attention_3",
|
||||
[False, "auto"],
|
||||
)
|
||||
def test_lora_ddp_packed(
|
||||
self, temp_dir, gradient_accumulation_steps, use_flash_attention_3
|
||||
):
|
||||
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -144,7 +138,6 @@ class TestMultiGPULlama:
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"use_flash_attention_3": use_flash_attention_3,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ E2E tests for packed training
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
@@ -13,17 +14,18 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_tensorboard
|
||||
from .utils import check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestPackedLlama:
|
||||
class TestPackedLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Packed training of llama models
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_loss_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
|
||||
Reference in New Issue
Block a user