make mlflow optional (#1317)

* make mlflow optional

* fix xformers

don't patch swiglu if xformers is not working
fix the check for xformers swiglu

* fix install of xformers with extra index url for docker builds

* fix docker build arg quoting
This commit is contained in:
Wing Lian
2024-02-26 11:41:33 -05:00
committed by GitHub
parent 5cf226e177
commit 5894f0e57e
12 changed files with 86 additions and 41 deletions

View File

@@ -18,6 +18,7 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.1.2
axolotl_extras: axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
is_latest: true is_latest: true
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
@@ -54,6 +55,7 @@ jobs:
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }} PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
file: ./docker/Dockerfile file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: | tags: |

View File

@@ -70,6 +70,7 @@ jobs:
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.1.2
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.10" python_version: "3.10"
@@ -87,11 +88,13 @@ jobs:
# Set up build arguments # Set up build arguments
BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
CUDA="${{ matrix.cuda }}" CUDA="${{ matrix.cuda }}"
AXOLOTL_ARGS="${{ matrix.axolotl_args }}"
PYTORCH_VERSION="${{ matrix.pytorch }}" PYTORCH_VERSION="${{ matrix.pytorch }}"
# Build the Docker image # Build the Docker image
docker build . \ docker build . \
--file ./docker/Dockerfile-tests \ --file ./docker/Dockerfile-tests \
--build-arg BASE_TAG=$BASE_TAG \ --build-arg BASE_TAG=$BASE_TAG \
--build-arg AXOLOTL_ARGS="$AXOLOTL_ARGS" \
--build-arg CUDA=$CUDA \ --build-arg CUDA=$CUDA \
--build-arg GITHUB_REF=$GITHUB_REF \ --build-arg GITHUB_REF=$GITHUB_REF \
--build-arg PYTORCH_VERSION=$PYTORCH_VERSION \ --build-arg PYTORCH_VERSION=$PYTORCH_VERSION \

View File

@@ -3,6 +3,7 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS="" ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118" ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1" ARG PYTORCH_VERSION="2.0.1"
@@ -20,9 +21,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \ pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi fi
# So we can test the Docker image # So we can test the Docker image

View File

@@ -3,6 +3,7 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS="" ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118" ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1" ARG PYTORCH_VERSION="2.0.1"
@@ -24,9 +25,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \ pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi fi
# So we can test the Docker image # So we can test the Docker image

View File

@@ -21,7 +21,6 @@ hf_transfer
colorama colorama
numba numba
numpy>=1.24.4 numpy>=1.24.4
mlflow
# qlora things # qlora things
evaluate==0.4.1 evaluate==0.4.1
scipy scipy

View File

@@ -82,5 +82,8 @@ setup(
"auto-gptq": [ "auto-gptq": [
"auto-gptq==0.5.1", "auto-gptq==0.5.1",
], ],
"mlflow": [
"mlflow",
],
}, },
) )

View File

@@ -5,6 +5,7 @@ Builder for the training args and trainer
import abc import abc
import importlib import importlib
import importlib.util
import logging import logging
import math import math
import sys import sys
@@ -34,7 +35,6 @@ from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GPUStatsCallback, GPUStatsCallback,
LossWatchDogCallback, LossWatchDogCallback,
SaveAxolotlConfigtoMlflowCallback,
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
@@ -62,6 +62,10 @@ except ImportError:
LOG = logging.getLogger("axolotl.core.trainer_builder") LOG = logging.getLogger("axolotl.core.trainer_builder")
def is_mlflow_available():
    """Tell whether the optional ``mlflow`` package can be located.

    The import system is only asked for a module spec; a missing spec means
    the package is not installed in the current environment.
    """
    mlflow_spec = importlib.util.find_spec("mlflow")
    return mlflow_spec is not None
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str): if isinstance(tag_names, str):
tag_names = [tag_names] tag_names = [tag_names]
@@ -648,7 +652,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
callbacks.append( callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
) )
if self.cfg.use_mlflow: if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append( callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
) )

View File

@@ -44,6 +44,18 @@ except ImportError:
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def is_xformers_swiglu_available() -> bool:
    """Report whether the xformers ``swiglu_packedw`` operator can be used.

    The operator is looked up through xformers and invoked without arguments
    purely as a probe.

    Returns:
        ``False`` when xformers cannot be imported, or when the probe raises a
        ``RuntimeError`` saying the operator does not exist; ``True``
        otherwise.
    """
    try:
        # pylint: disable=import-outside-toplevel
        from xformers.ops.common import get_xformers_operator
    except ImportError:
        # An availability check must answer "no" rather than crash when the
        # optional xformers dependency is missing or broken at import time.
        return False

    try:
        get_xformers_operator("swiglu_packedw")()
    except RuntimeError as exc:
        # The trailing space in the needle is deliberate: it matches the
        # operator name exactly rather than any longer name sharing the prefix.
        if "No such operator xformers::swiglu_packedw " in str(exc):
            return False
        # Any other RuntimeError is treated as "operator exists" -- presumably
        # it is only complaining about the missing call arguments.
    return True
def replace_llama_mlp_with_swiglu(model): def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, LlamaMLP): if isinstance(module, LlamaMLP):

View File

@@ -9,7 +9,6 @@ from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Dict, List from typing import TYPE_CHECKING, Dict, List
import evaluate import evaluate
import mlflow
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
@@ -42,8 +41,8 @@ from axolotl.utils.distributed import (
if TYPE_CHECKING: if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100 IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks")
class EvalFirstStepCallback( class EvalFirstStepCallback(
@@ -756,31 +755,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
except (FileNotFoundError, ConnectionError) as err: except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}") LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control return control
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
    """Callback to save axolotl config to mlflow"""

    def __init__(self, axolotl_config_path):
        # Path of the config file that is copied and logged when training begins.
        self.axolotl_config_path = axolotl_config_path

    def on_train_begin(
        self,
        args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Log a copy of the axolotl config as an MLflow artifact.

        Only the main process does any work. The config is copied to a
        temporary ``axolotl_config_*.yml`` file, logged with an empty
        ``artifact_path``, and the temporary copy is removed afterwards.
        ``FileNotFoundError`` and ``ConnectionError`` are logged as warnings
        instead of interrupting training.

        Returns:
            The ``control`` object that was passed in, unchanged.
        """
        if is_main_process():
            import os  # pylint: disable=import-outside-toplevel

            temp_path = None
            try:
                # delete=False so the file can be reopened by name while the
                # handle is still open; it is removed explicitly below.
                with NamedTemporaryFile(
                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
                ) as temp_file:
                    temp_path = temp_file.name
                    copyfile(self.axolotl_config_path, temp_path)
                    mlflow.log_artifact(temp_path, artifact_path="")
                    LOG.info(
                        "The Axolotl config has been saved to the MLflow artifacts."
                    )
            except (FileNotFoundError, ConnectionError) as err:
                LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
            finally:
                # Without this the temporary copy would be left behind on
                # every training run.
                if temp_path is not None:
                    try:
                        os.remove(temp_path)
                    except OSError as err:
                        LOG.debug(f"Could not remove temporary config copy: {err}")
        return control

View File

@@ -0,0 +1,44 @@
"""MLFlow module for trainer callbacks"""
import logging
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
import mlflow
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
    # pylint: disable=duplicate-code
    """Callback to save axolotl config to mlflow"""

    def __init__(self, axolotl_config_path):
        # Path of the config file that is copied and logged when training begins.
        self.axolotl_config_path = axolotl_config_path

    def on_train_begin(
        self,
        args: "AxolotlTrainingArguments",  # pylint: disable=unused-argument
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Log a copy of the axolotl config as an MLflow artifact.

        Only the main process does any work. The config is copied to a
        temporary ``axolotl_config_*.yml`` file, logged with an empty
        ``artifact_path``, and the temporary copy is removed afterwards.
        ``FileNotFoundError`` and ``ConnectionError`` are logged as warnings
        instead of interrupting training.

        Returns:
            The ``control`` object that was passed in, unchanged.
        """
        if is_main_process():
            import os  # pylint: disable=import-outside-toplevel

            temp_path = None
            try:
                # delete=False so the file can be reopened by name while the
                # handle is still open; it is removed explicitly below.
                with NamedTemporaryFile(
                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
                ) as temp_file:
                    temp_path = temp_file.name
                    copyfile(self.axolotl_config_path, temp_path)
                    mlflow.log_artifact(temp_path, artifact_path="")
                    LOG.info(
                        "The Axolotl config has been saved to the MLflow artifacts."
                    )
            except (FileNotFoundError, ConnectionError) as err:
                LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
            finally:
                # Without this the temporary copy would be left behind on
                # every training run.
                if temp_path is not None:
                    try:
                        os.remove(temp_path)
                    except OSError as err:
                        LOG.debug(f"Could not remove temporary config copy: {err}")
        return control

View File

@@ -512,11 +512,12 @@ def load_model(
if cfg.flash_attention and not inference: if cfg.flash_attention and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import ( from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu, replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused, replace_llama_qkv_with_fused,
) )
if cfg.flash_attn_fuse_mlp: if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
LOG.info("patching with SwiGLU") LOG.info("patching with SwiGLU")
replace_llama_mlp_with_swiglu(model) replace_llama_mlp_with_swiglu(model)

View File

@@ -57,9 +57,9 @@ class TestFusedLlama(unittest.TestCase):
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"max_steps": 20, "max_steps": 10,
"save_steps": 10, "save_steps": 5,
"eval_steps": 10, "eval_steps": 5,
} }
) )
if is_torch_bf16_gpu_available(): if is_torch_bf16_gpu_available():