make mlflow optional (#1317)
* make mlflow optional
* fix xformers: don't patch swiglu if xformers is not working; fix the check for xformers swiglu
* fix install of xformers with extra index url for docker builds
* fix docker build arg quoting
This commit is contained in:
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@@ -18,6 +18,7 @@ jobs:
|
|||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
||||||
is_latest: true
|
is_latest: true
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
@@ -54,6 +55,7 @@ jobs:
|
|||||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
|
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||||
file: ./docker/Dockerfile
|
file: ./docker/Dockerfile
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: |
|
tags: |
|
||||||
|
|||||||
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
@@ -70,6 +70,7 @@ jobs:
|
|||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.2
|
pytorch: 2.1.2
|
||||||
|
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -87,11 +88,13 @@ jobs:
|
|||||||
# Set up build arguments
|
# Set up build arguments
|
||||||
BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
|
BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
|
||||||
CUDA="${{ matrix.cuda }}"
|
CUDA="${{ matrix.cuda }}"
|
||||||
|
AXOLOTL_ARGS="${{ matrix.axolotl_args }}"
|
||||||
PYTORCH_VERSION="${{ matrix.pytorch }}"
|
PYTORCH_VERSION="${{ matrix.pytorch }}"
|
||||||
# Build the Docker image
|
# Build the Docker image
|
||||||
docker build . \
|
docker build . \
|
||||||
--file ./docker/Dockerfile-tests \
|
--file ./docker/Dockerfile-tests \
|
||||||
--build-arg BASE_TAG=$BASE_TAG \
|
--build-arg BASE_TAG=$BASE_TAG \
|
||||||
|
--build-arg AXOLOTL_ARGS="$AXOLOTL_ARGS" \
|
||||||
--build-arg CUDA=$CUDA \
|
--build-arg CUDA=$CUDA \
|
||||||
--build-arg GITHUB_REF=$GITHUB_REF \
|
--build-arg GITHUB_REF=$GITHUB_REF \
|
||||||
--build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
|
--build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ FROM winglian/axolotl-base:$BASE_TAG
|
|||||||
|
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ARG AXOLOTL_EXTRAS=""
|
ARG AXOLOTL_EXTRAS=""
|
||||||
|
ARG AXOLOTL_ARGS=""
|
||||||
ARG CUDA="118"
|
ARG CUDA="118"
|
||||||
ENV BNB_CUDA_VERSION=$CUDA
|
ENV BNB_CUDA_VERSION=$CUDA
|
||||||
ARG PYTORCH_VERSION="2.0.1"
|
ARG PYTORCH_VERSION="2.0.1"
|
||||||
@@ -20,9 +21,9 @@ WORKDIR /workspace/axolotl
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ FROM winglian/axolotl-base:$BASE_TAG
|
|||||||
|
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ARG AXOLOTL_EXTRAS=""
|
ARG AXOLOTL_EXTRAS=""
|
||||||
|
ARG AXOLOTL_ARGS=""
|
||||||
ARG CUDA="118"
|
ARG CUDA="118"
|
||||||
ENV BNB_CUDA_VERSION=$CUDA
|
ENV BNB_CUDA_VERSION=$CUDA
|
||||||
ARG PYTORCH_VERSION="2.0.1"
|
ARG PYTORCH_VERSION="2.0.1"
|
||||||
@@ -24,9 +25,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ hf_transfer
|
|||||||
colorama
|
colorama
|
||||||
numba
|
numba
|
||||||
numpy>=1.24.4
|
numpy>=1.24.4
|
||||||
mlflow
|
|
||||||
# qlora things
|
# qlora things
|
||||||
evaluate==0.4.1
|
evaluate==0.4.1
|
||||||
scipy
|
scipy
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -82,5 +82,8 @@ setup(
|
|||||||
"auto-gptq": [
|
"auto-gptq": [
|
||||||
"auto-gptq==0.5.1",
|
"auto-gptq==0.5.1",
|
||||||
],
|
],
|
||||||
|
"mlflow": [
|
||||||
|
"mlflow",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ Builder for the training args and trainer
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
import importlib
|
import importlib
|
||||||
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
@@ -34,7 +35,6 @@ from axolotl.utils.callbacks import (
|
|||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
LossWatchDogCallback,
|
LossWatchDogCallback,
|
||||||
SaveAxolotlConfigtoMlflowCallback,
|
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
@@ -62,6 +62,10 @@ except ImportError:
|
|||||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||||
|
|
||||||
|
|
||||||
|
def is_mlflow_available():
|
||||||
|
return importlib.util.find_spec("mlflow") is not None
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||||
if isinstance(tag_names, str):
|
if isinstance(tag_names, str):
|
||||||
tag_names = [tag_names]
|
tag_names = [tag_names]
|
||||||
@@ -648,7 +652,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
if self.cfg.use_mlflow:
|
if self.cfg.use_mlflow and is_mlflow_available():
|
||||||
|
from axolotl.utils.callbacks.mlflow_ import (
|
||||||
|
SaveAxolotlConfigtoMlflowCallback,
|
||||||
|
)
|
||||||
|
|
||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -44,6 +44,18 @@ except ImportError:
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
def is_xformers_swiglu_available() -> bool:
|
||||||
|
from xformers.ops.common import get_xformers_operator
|
||||||
|
|
||||||
|
try:
|
||||||
|
get_xformers_operator("swiglu_packedw")()
|
||||||
|
return True
|
||||||
|
except RuntimeError as exc:
|
||||||
|
if "No such operator xformers::swiglu_packedw " in str(exc):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_mlp_with_swiglu(model):
|
def replace_llama_mlp_with_swiglu(model):
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, LlamaMLP):
|
if isinstance(module, LlamaMLP):
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from tempfile import NamedTemporaryFile
|
|||||||
from typing import TYPE_CHECKING, Dict, List
|
from typing import TYPE_CHECKING, Dict, List
|
||||||
|
|
||||||
import evaluate
|
import evaluate
|
||||||
import mlflow
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@@ -42,8 +41,8 @@ from axolotl.utils.distributed import (
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
|
|
||||||
|
|
||||||
class EvalFirstStepCallback(
|
class EvalFirstStepCallback(
|
||||||
@@ -756,31 +755,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
|
||||||
"""Callback to save axolotl config to mlflow"""
|
|
||||||
|
|
||||||
def __init__(self, axolotl_config_path):
|
|
||||||
self.axolotl_config_path = axolotl_config_path
|
|
||||||
|
|
||||||
def on_train_begin(
|
|
||||||
self,
|
|
||||||
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
|
||||||
state: TrainerState, # pylint: disable=unused-argument
|
|
||||||
control: TrainerControl,
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
if is_main_process():
|
|
||||||
try:
|
|
||||||
with NamedTemporaryFile(
|
|
||||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
|
||||||
) as temp_file:
|
|
||||||
copyfile(self.axolotl_config_path, temp_file.name)
|
|
||||||
mlflow.log_artifact(temp_file.name, artifact_path="")
|
|
||||||
LOG.info(
|
|
||||||
"The Axolotl config has been saved to the MLflow artifacts."
|
|
||||||
)
|
|
||||||
except (FileNotFoundError, ConnectionError) as err:
|
|
||||||
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
|
|
||||||
return control
|
|
||||||
44
src/axolotl/utils/callbacks/mlflow_.py
Normal file
44
src/axolotl/utils/callbacks/mlflow_.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""MLFlow module for trainer callbacks"""
|
||||||
|
import logging
|
||||||
|
from shutil import copyfile
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
|
|
||||||
|
|
||||||
|
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
"""Callback to save axolotl config to mlflow"""
|
||||||
|
|
||||||
|
def __init__(self, axolotl_config_path):
|
||||||
|
self.axolotl_config_path = axolotl_config_path
|
||||||
|
|
||||||
|
def on_train_begin(
|
||||||
|
self,
|
||||||
|
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
|
||||||
|
state: TrainerState, # pylint: disable=unused-argument
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
if is_main_process():
|
||||||
|
try:
|
||||||
|
with NamedTemporaryFile(
|
||||||
|
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||||
|
) as temp_file:
|
||||||
|
copyfile(self.axolotl_config_path, temp_file.name)
|
||||||
|
mlflow.log_artifact(temp_file.name, artifact_path="")
|
||||||
|
LOG.info(
|
||||||
|
"The Axolotl config has been saved to the MLflow artifacts."
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
|
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
|
||||||
|
return control
|
||||||
@@ -512,11 +512,12 @@ def load_model(
|
|||||||
|
|
||||||
if cfg.flash_attention and not inference:
|
if cfg.flash_attention and not inference:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
is_xformers_swiglu_available,
|
||||||
replace_llama_mlp_with_swiglu,
|
replace_llama_mlp_with_swiglu,
|
||||||
replace_llama_qkv_with_fused,
|
replace_llama_qkv_with_fused,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.flash_attn_fuse_mlp:
|
if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||||
LOG.info("patching with SwiGLU")
|
LOG.info("patching with SwiGLU")
|
||||||
replace_llama_mlp_with_swiglu(model)
|
replace_llama_mlp_with_swiglu(model)
|
||||||
|
|
||||||
|
|||||||
@@ -57,9 +57,9 @@ class TestFusedLlama(unittest.TestCase):
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"max_steps": 20,
|
"max_steps": 10,
|
||||||
"save_steps": 10,
|
"save_steps": 5,
|
||||||
"eval_steps": 10,
|
"eval_steps": 5,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if is_torch_bf16_gpu_available():
|
if is_torch_bf16_gpu_available():
|
||||||
|
|||||||
Reference in New Issue
Block a user