Files
axolotl/tests/e2e/utils.py
Wing Lian 130ef7c51a Various fixes for VLMs (#3063)
* fix to not use batch feature indexing

* more vlm fixes

* use AutoModelForImageTextToText

* add example yaml and need num2words for chat template

* improve handling of adding image tokens to conversation

* add lfm2-vl support

* update the lfm readme

* fix markdown and add rtol for loss checks

* feat: add smolvlm2 processing strat

* fix: check for causal-conv1d in lfm models

* feat: add docs for lfm2

* feat: add new models and tips to docs

* feat: add smolvlm2 docs and remove extra dep

* chore: update docs

* feat: add video instructions

* chore: cleanup

* chore: comments

* fix: typo

* feat: add usage stats

* chore: refactor

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-15 10:52:57 -04:00

193 lines
5.2 KiB
Python

"""
helper utils for tests
"""
import os
import shutil
import tempfile
import unittest
from functools import wraps
from pathlib import Path
import torch
from packaging import version
from tbparse import SummaryReader
from axolotl.utils.dict import DictDefault
def with_temp_dir(test_func):
@wraps(test_func)
def wrapper(*args, **kwargs):
# Create a temporary directory
temp_dir = tempfile.mkdtemp()
try:
# Pass the temporary directory to the test function
test_func(*args, temp_dir=temp_dir, **kwargs)
finally:
# Clean up the directory after the test
shutil.rmtree(temp_dir)
return wrapper
def most_recent_subdir(path):
base_path = Path(path)
subdirectories = [d for d in base_path.iterdir() if d.is_dir()]
if not subdirectories:
return None
subdir = max(subdirectories, key=os.path.getctime)
return subdir
def require_torch_2_4_1(test_case):
"""
Decorator marking a test that requires torch >= 2.5.1
"""
def is_min_2_4_1():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.4.1")
return unittest.skipUnless(is_min_2_4_1(), "test requires torch>=2.4.1")(test_case)
def require_torch_2_5_1(test_case):
"""
Decorator marking a test that requires torch >= 2.5.1
"""
def is_min_2_5_1():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.5.1")
return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)
def require_torch_2_6_0(test_case):
"""
Decorator marking a test that requires torch >= 2.6.0
"""
def is_min_2_6_0():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.6.0")
return unittest.skipUnless(is_min_2_6_0(), "test requires torch>=2.6.0")(test_case)
def require_torch_2_7_0(test_case):
"""
Decorator marking a test that requires torch >= 2.7.0
"""
def is_min_2_7_0():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.7.0")
return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
def require_torch_lt_2_6_0(test_case):
"""
Decorator marking a test that requires torch < 2.6.0
"""
def is_max_2_6_0():
torch_version = version.parse(torch.__version__)
return torch_version < version.parse("2.6.0")
return unittest.skipUnless(is_max_2_6_0(), "test requires torch<2.6.0")(test_case)
def require_vllm(test_case):
"""
Decorator marking a test that requires a vllm to be installed
"""
def is_vllm_installed():
try:
import vllm # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
return unittest.skipUnless(
is_vllm_installed(), "test requires vllm to be installed"
)(test_case)
def require_llmcompressor(test_case):
"""
Decorator marking a test that requires a llmcompressor to be installed
"""
def is_llmcompressor_installed():
try:
import llmcompressor # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
return unittest.skipUnless(
is_llmcompressor_installed(), "test requires llmcompressor to be installed"
)(test_case)
def is_hopper():
compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0)
def require_hopper(test_case):
return unittest.skipUnless(is_hopper(), "test requires h100/hopper GPU")(test_case)
def check_tensorboard(
temp_run_dir: str,
tag: str,
lt_val: float,
assertion_err: str,
rtol: float = 0.02,
) -> None:
"""
helper function to parse and check tensorboard logs
"""
tb_log_path = most_recent_subdir(temp_run_dir)
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
lt_val = (1 + rtol) * lt_val
if "%s" in assertion_err:
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
else:
assert df.value.values[-1] < lt_val, assertion_err
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
"""
helper function to check if a model output file exists after training
checks based on adapter or not and if safetensors saves are enabled or not
"""
if cfg.save_safetensors:
if not cfg.adapter:
assert (Path(temp_dir) / "model.safetensors").exists()
else:
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
else:
# check for both, b/c in trl, it often defaults to saving safetensors
if not cfg.adapter:
assert (Path(temp_dir) / "pytorch_model.bin").exists() or (
Path(temp_dir) / "model.safetensors"
).exists()
else:
assert (Path(temp_dir) / "adapter_model.bin").exists() or (
Path(temp_dir) / "adapter_model.safetensors"
).exists()