From 0e9bfa6dee5bdadf47df52e9a343e0b3bad85619 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 24 Jan 2025 19:53:54 +0000 Subject: [PATCH] small fixes, improvements --- src/axolotl/cli/evaluate.py | 4 +-- src/axolotl/evaluate.py | 6 ++--- src/axolotl/utils/models.py | 50 +++++++++---------------------------- 3 files changed, 17 insertions(+), 43 deletions(-) diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index c89715719..9370921fd 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -19,7 +19,7 @@ from axolotl.utils.dict import DictDefault LOG = logging.getLogger(__name__) -def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: +def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]: """ Evaluates a `transformers` model by first loading the dataset(s) specified in the `axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes @@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - evaluate(cfg=cfg, dataset_meta=dataset_meta) + return evaluate(cfg=cfg, dataset_meta=dataset_meta) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index 8d9ddc6ab..db8490432 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -4,7 +4,7 @@ import csv import os import sys from pathlib import Path -from typing import Dict, Optional +from typing import Optional import torch from accelerate.logging import get_logger @@ -26,7 +26,7 @@ LOG = get_logger("axolotl.evaluate") def evaluate_dataset( trainer, dataset, dataset_type: str, flash_optimum: bool = False -) -> Optional[Dict[str, float]]: +) -> Optional[dict[str, float]]: """Helper function to evaluate a single dataset safely. Args: @@ -61,7 +61,7 @@ def evaluate_dataset( return metrics -def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]: +def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]: """ Evaluate a model on training and validation datasets diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 4c612f48f..48e2c6558 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -709,45 +709,19 @@ class ModelLoader: if self.cfg.flash_attention: if not self.cfg.sample_packing and self.cfg.s2_attention: pass - - if self.cfg.diff_attention: - self.model_kwargs[ - "attn_implementation" - ] = "differential_flash_attention_2" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "differential_flash_attention_2" - ) - else: - self.model_kwargs["attn_implementation"] = "flash_attention_2" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - elif self.cfg.sdp_attention: - if self.cfg.diff_attention: - self.model_kwargs["attn_implementation"] = "differential_sdpa" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "differential_sdpa" - ) - else: - self.model_kwargs["attn_implementation"] = "sdpa" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "sdpa" - ) - elif self.cfg.eager_attention: - if self.cfg.diff_attention: - self.model_kwargs["attn_implementation"] = "differential_eager" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "differential_eager" - ) - else: - self.model_kwargs["attn_implementation"] = "eager" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" - ) - elif self.cfg.diff_attention: - self.model_kwargs["attn_implementation"] = "differential_eager" + self.model_kwargs["attn_implementation"] = "flash_attention_2" self.model_config._attn_implementation = ( # pylint: disable=protected-access - "differential_eager" + "flash_attention_2" + ) + elif self.cfg.sdp_attention: + self.model_kwargs["attn_implementation"] = "sdpa" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "sdpa" + ) + elif self.cfg.eager_attention: + self.model_kwargs["attn_implementation"] = "eager" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "eager" ) if self.cfg.low_cpu_mem_usage: