small fixes, improvements
This commit is contained in:
@@ -19,7 +19,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluates a `transformers` model by first loading the dataset(s) specified in the
|
Evaluates a `transformers` model by first loading the dataset(s) specified in the
|
||||||
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
|
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
|
||||||
@@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
evaluate(cfg=cfg, dataset_meta=dataset_meta)
|
return evaluate(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import csv
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
@@ -26,7 +26,7 @@ LOG = get_logger("axolotl.evaluate")
|
|||||||
|
|
||||||
def evaluate_dataset(
|
def evaluate_dataset(
|
||||||
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
||||||
) -> Optional[Dict[str, float]]:
|
) -> Optional[dict[str, float]]:
|
||||||
"""Helper function to evaluate a single dataset safely.
|
"""Helper function to evaluate a single dataset safely.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -61,7 +61,7 @@ def evaluate_dataset(
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate a model on training and validation datasets
|
Evaluate a model on training and validation datasets
|
||||||
|
|
||||||
|
|||||||
@@ -709,45 +709,19 @@ class ModelLoader:
|
|||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
pass
|
pass
|
||||||
|
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
if self.cfg.diff_attention:
|
|
||||||
self.model_kwargs[
|
|
||||||
"attn_implementation"
|
|
||||||
] = "differential_flash_attention_2"
|
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"differential_flash_attention_2"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"flash_attention_2"
|
|
||||||
)
|
|
||||||
elif self.cfg.sdp_attention:
|
|
||||||
if self.cfg.diff_attention:
|
|
||||||
self.model_kwargs["attn_implementation"] = "differential_sdpa"
|
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"differential_sdpa"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.model_kwargs["attn_implementation"] = "sdpa"
|
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"sdpa"
|
|
||||||
)
|
|
||||||
elif self.cfg.eager_attention:
|
|
||||||
if self.cfg.diff_attention:
|
|
||||||
self.model_kwargs["attn_implementation"] = "differential_eager"
|
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"differential_eager"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.model_kwargs["attn_implementation"] = "eager"
|
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"eager"
|
|
||||||
)
|
|
||||||
elif self.cfg.diff_attention:
|
|
||||||
self.model_kwargs["attn_implementation"] = "differential_eager"
|
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"differential_eager"
|
"flash_attention_2"
|
||||||
|
)
|
||||||
|
elif self.cfg.sdp_attention:
|
||||||
|
self.model_kwargs["attn_implementation"] = "sdpa"
|
||||||
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"sdpa"
|
||||||
|
)
|
||||||
|
elif self.cfg.eager_attention:
|
||||||
|
self.model_kwargs["attn_implementation"] = "eager"
|
||||||
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"eager"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.low_cpu_mem_usage:
|
if self.cfg.low_cpu_mem_usage:
|
||||||
|
|||||||
Reference in New Issue
Block a user