Files
axolotl/src/axolotl/evaluate.py
Dan Saunders 79ddaebe9a Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
2025-08-23 23:37:33 -04:00

150 lines
4.6 KiB
Python

"""Module for evaluating models."""
import csv
import os
import sys
from pathlib import Path
from typing import Dict, Optional
import torch
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.train import (
TrainDatasetMeta,
setup_model_and_tokenizer,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import setup_trainer
# NOTE(review): this sys.path manipulation looks like leftover scaffolding.
# With this file at src/axolotl/evaluate.py, project_root resolves to src/
# and src_dir to src/src — a directory that likely does not exist. Confirm
# against the package layout before removing; inserting a nonexistent path
# is harmless but misleading.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

# Module-level logger used for evaluation progress and metric reporting.
LOG = get_logger(__name__)
def evaluate_dataset(
    trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]:
    """Helper function to evaluate a single dataset.

    Args:
        trainer: The trainer instance.
        dataset: Dataset to evaluate. May be `None`, in which case evaluation
            is skipped entirely.
        dataset_type: Type of dataset ('train' or 'eval'); used as the metric
            key prefix and in log messages.
        flash_optimum: Whether to evaluate inside an explicit CUDA SDP-kernel
            selection context.

    Returns:
        Dictionary of metrics, or `None` if `dataset` is `None`.
    """
    from contextlib import nullcontext

    if dataset is None:
        return None

    LOG.info(f"Starting {dataset_type} set evaluation...")

    # NOTE(review): `torch.backends.cuda.sdp_kernel` is deprecated upstream in
    # favor of `torch.nn.attention.sdpa_kernel`; kept as-is to preserve
    # behavior. Enabling all three backends matches the default selection, so
    # both paths collapse to a single `evaluate` call via `nullcontext`.
    sdp_context = (
        torch.backends.cuda.sdp_kernel(
            enable_flash=True,
            enable_math=True,
            enable_mem_efficient=True,
        )
        if flash_optimum
        else nullcontext()
    )
    with sdp_context:
        metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type)

    LOG.info(f"{dataset_type.capitalize()} set evaluation completed!")
    LOG.info(f"{dataset_type.capitalize()} Metrics:")
    for key, value in metrics.items():
        LOG.info(f"{key}: {value}")

    return metrics
def _write_metrics_csv(
    metrics_file: Path,
    train_metrics: Optional[Dict[str, float]],
    eval_metrics: Optional[Dict[str, float]],
) -> None:
    """Write train / eval metrics side-by-side into a summary CSV.

    Args:
        metrics_file: Destination CSV path (overwritten if it exists).
        train_metrics: Metrics keyed with a `train_` prefix, or `None`.
        eval_metrics: Metrics keyed with an `eval_` prefix, or `None`.
    """
    with metrics_file.open("w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerow(["metric", "training", "validation"])

        # Map prefix-stripped metric names back to their original keys.
        # `removeprefix` strips only a leading prefix, unlike `str.replace`,
        # which would also mangle keys like "eval_train_time".
        train_metric_names = {
            k.removeprefix("train_"): k for k in (train_metrics or {})
        }
        eval_metric_names = {
            k.removeprefix("eval_"): k for k in (eval_metrics or {})
        }
        all_metric_names = sorted(
            set(train_metric_names) | set(eval_metric_names)
        )

        for metric_name in all_metric_names:
            # Leave the cell blank when a metric exists on only one side.
            train_value = (
                train_metrics.get(train_metric_names.get(metric_name, ""), "")
                if train_metrics
                else ""
            )
            eval_value = (
                eval_metrics.get(eval_metric_names.get(metric_name, ""), "")
                if eval_metrics
                else ""
            )
            writer.writerow([metric_name, train_value, eval_value])


def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
    """
    Evaluate a model on training and validation datasets.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        dataset_meta: Dataset metadata containing training and evaluation datasets.

    Returns:
        Dictionary mapping metric names (with `train_` / `eval_` prefixes) to
        their values; empty if neither dataset produced metrics.
    """
    # Load tokenizer, processor and model
    LOG.debug("loading model for evaluation...")
    model, tokenizer, _, processor = setup_model_and_tokenizer(cfg)

    # Get datasets
    train_dataset = dataset_meta.train_dataset
    eval_dataset = dataset_meta.eval_dataset
    total_num_steps = dataset_meta.total_num_steps

    # Set up trainer
    trainer = setup_trainer(
        cfg=cfg,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        model=model,
        tokenizer=tokenizer,
        processor=processor,
        total_num_steps=total_num_steps,
    )

    # Evaluate datasets; each call returns None when its dataset is None.
    all_metrics = {}
    train_metrics = evaluate_dataset(trainer, train_dataset, "train", cfg.flash_optimum)
    eval_metrics = evaluate_dataset(trainer, eval_dataset, "eval", cfg.flash_optimum)
    if train_metrics:
        all_metrics.update(train_metrics)
    if eval_metrics:
        all_metrics.update(eval_metrics)

    # Save metrics to CSV if output directory is specified and we have metrics
    if cfg.output_dir and (train_metrics or eval_metrics):
        output_dir = Path(cfg.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        metrics_file = output_dir / "eval_summary.csv"
        _write_metrics_csv(metrics_file, train_metrics, eval_metrics)
        LOG.info(f"Evaluation results saved to {metrics_file}")

    # Drop model/tokenizer references and tear down any distributed state.
    del model
    del tokenizer
    cleanup_distributed()

    return all_metrics