diff --git a/tests/e2e/integrations/test_llm_compressor.py b/tests/e2e/integrations/test_llm_compressor.py index f43d4a938..e1d4fc763 100644 --- a/tests/e2e/integrations/test_llm_compressor.py +++ b/tests/e2e/integrations/test_llm_compressor.py @@ -5,6 +5,7 @@ E2E smoke tests for LLMCompressorPlugin integration from pathlib import Path import pytest +from llmcompressor import active_session from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets @@ -89,9 +90,12 @@ class TestLLMCompressorIntegration: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, dataset_meta=dataset_meta) - check_model_output_exists(temp_dir, cfg) - _check_llmcompressor_model_outputs(temp_dir, save_compressed) + try: + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + _check_llmcompressor_model_outputs(temp_dir, save_compressed) + finally: + active_session().reset() def _check_llmcompressor_model_outputs(temp_dir, save_compressed):