Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
1a229b0901 add colab callback to fix inference post train 2025-05-05 16:40:01 -04:00
2 changed files with 32 additions and 0 deletions

View File

@@ -21,6 +21,7 @@ import importlib.util
import inspect
import logging
import math
import os
import sys
from abc import abstractmethod
from pathlib import Path
@@ -72,6 +73,7 @@ from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
@@ -293,6 +295,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks

View File

@@ -868,3 +868,29 @@ class GCCallback(TrainerCallback):
):
torch.cuda.empty_cache()
gc.collect()
def colab_inference_post_train_callback(trainer: Trainer):
    """Build a callback class that readies the trained model for inference
    on Google Colab.

    Args:
        trainer: the HF ``Trainer`` whose model is reconfigured when
            training ends (captured via closure).

    Returns:
        A ``TrainerCallback`` subclass; instantiate it with the axolotl
        config (``cfg``) before registering it on the trainer.
    """

    class ColabCallback(TrainerCallback):
        """Callback to prep model for inference on Google Colab"""

        def __init__(self, cfg):
            # COLAB_* env vars (the registration condition) are also set on
            # CPU-only Colab runtimes, where torch.cuda.get_device_name(0)
            # would raise — guard on CUDA availability.
            self.gpu_name = (
                torch.cuda.get_device_name(0) if torch.cuda.is_available() else ""
            )
            self.cfg = cfg

        def on_train_end(
            self, args, state, control, **kwargs
        ):  # pylint: disable=unused-argument
            """
            handle T4 gpu, we need to convert attention to eager for inference
            """
            if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
                # xformers attention is train-only on T4; fall back to eager
                # attention so generate() works after training.
                trainer.model.config._attn_implementation = (  # pylint: disable=protected-access
                    "eager"
                )
            trainer.model.gradient_checkpointing_disable()
            trainer.model.config.use_cache = True
            # single eval() at the end covers all branches (original called
            # it twice on the T4 path, which was redundant)
            trainer.model.eval()

    return ColabCallback