From 1a229b09017ef119c989cece363ac62b0f336928 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 5 May 2025 16:39:33 -0400
Subject: [PATCH] add colab callback to fix inference post train

---
 src/axolotl/core/trainer_builder.py     |  6 ++++++
 src/axolotl/utils/callbacks/__init__.py | 26 ++++++++++++++++++++++++++
 2 files changed, 32 insertions(+)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 31ee3cccf..4df700d90 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -21,6 +21,7 @@ import importlib.util
 import inspect
 import logging
 import math
+import os
 import sys
 from abc import abstractmethod
 from pathlib import Path
@@ -72,6 +73,7 @@ from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
     causal_lm_bench_eval_callback_factory,
+    colab_inference_post_train_callback,
     log_prediction_callback_factory,
 )
 from axolotl.utils.callbacks.lisa import lisa_callback_factory
@@ -293,6 +295,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             callbacks.append(lisa_callback_factory(trainer))
 
+        if any("COLAB_" in key for key in os.environ):
+            ColabCallback = colab_inference_post_train_callback(trainer)
+            callbacks.append(ColabCallback(self.cfg))
+
         callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
         return callbacks
 
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 21b14d986..6afa8a51c 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -868,3 +868,29 @@ class GCCallback(TrainerCallback):
     ):
         torch.cuda.empty_cache()
         gc.collect()
+
+
+def colab_inference_post_train_callback(trainer: Trainer):
+    class ColabCallback(TrainerCallback):
+        """Callback to prep model for inference on Google Colab"""
+
+        def __init__(self, cfg):
+            self.gpu_name = torch.cuda.get_device_name(0)
+            self.cfg = cfg
+
+        def on_train_end(
+            self, args, state, control, **kwargs
+        ):  # pylint: disable=unused-argument
+            """
+            handle T4 gpu, we need to convert attention to eager for inference
+            """
+            if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
+                trainer.model.eval()
+                trainer.model.config._attn_implementation = (  # pylint: disable=protected-access
+                    "eager"
+                )
+                trainer.model.gradient_checkpointing_disable()
+                trainer.model.config.use_cache = True
+                trainer.model.eval()
+
+    return ColabCallback