small fixes, improvements
This commit is contained in:
@@ -19,7 +19,7 @@ from axolotl.utils.dict import DictDefault
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
|
||||
"""
|
||||
Evaluates a `transformers` model by first loading the dataset(s) specified in the
|
||||
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
|
||||
@@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
evaluate(cfg=cfg, dataset_meta=dataset_meta)
|
||||
return evaluate(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
|
||||
@@ -4,7 +4,7 @@ import csv
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
@@ -26,7 +26,7 @@ LOG = get_logger("axolotl.evaluate")
|
||||
|
||||
def evaluate_dataset(
|
||||
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
||||
) -> Optional[Dict[str, float]]:
|
||||
) -> Optional[dict[str, float]]:
|
||||
"""Helper function to evaluate a single dataset safely.
|
||||
|
||||
Args:
|
||||
@@ -61,7 +61,7 @@ def evaluate_dataset(
|
||||
return metrics
|
||||
|
||||
|
||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]:
|
||||
"""
|
||||
Evaluate a model on training and validation datasets
|
||||
|
||||
|
||||
@@ -709,45 +709,19 @@ class ModelLoader:
|
||||
if self.cfg.flash_attention:
|
||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||
pass
|
||||
|
||||
if self.cfg.diff_attention:
|
||||
self.model_kwargs[
|
||||
"attn_implementation"
|
||||
] = "differential_flash_attention_2"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"differential_flash_attention_2"
|
||||
)
|
||||
else:
|
||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flash_attention_2"
|
||||
)
|
||||
elif self.cfg.sdp_attention:
|
||||
if self.cfg.diff_attention:
|
||||
self.model_kwargs["attn_implementation"] = "differential_sdpa"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"differential_sdpa"
|
||||
)
|
||||
else:
|
||||
self.model_kwargs["attn_implementation"] = "sdpa"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"sdpa"
|
||||
)
|
||||
elif self.cfg.eager_attention:
|
||||
if self.cfg.diff_attention:
|
||||
self.model_kwargs["attn_implementation"] = "differential_eager"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"differential_eager"
|
||||
)
|
||||
else:
|
||||
self.model_kwargs["attn_implementation"] = "eager"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"eager"
|
||||
)
|
||||
elif self.cfg.diff_attention:
|
||||
self.model_kwargs["attn_implementation"] = "differential_eager"
|
||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"differential_eager"
|
||||
"flash_attention_2"
|
||||
)
|
||||
elif self.cfg.sdp_attention:
|
||||
self.model_kwargs["attn_implementation"] = "sdpa"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"sdpa"
|
||||
)
|
||||
elif self.cfg.eager_attention:
|
||||
self.model_kwargs["attn_implementation"] = "eager"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"eager"
|
||||
)
|
||||
|
||||
if self.cfg.low_cpu_mem_usage:
|
||||
|
||||
Reference in New Issue
Block a user