Disable caching on --disable_caching in CLI (#1110)

* Disable caching on `--disable_caching` in CLI

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Casper
2024-01-13 10:13:35 +01:00
committed by GitHub
parent 304ea1b814
commit d66b10141e
2 changed files with 16 additions and 2 deletions

View File

@@ -7,6 +7,7 @@ from pathlib import Path
import fire
import transformers
from colorama import Fore
from datasets import disable_caching
from axolotl.cli import (
check_accelerate_default_config,
@@ -28,9 +29,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
parsed_cli_args, remaining_args = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if (
remaining_args.get("disable_caching") is not None
and remaining_args["disable_caching"]
):
disable_caching()
if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED

View File

@@ -6,6 +6,7 @@ from pathlib import Path
import fire
import transformers
from datasets import disable_caching
from axolotl.cli import (
check_accelerate_default_config,
@@ -28,9 +29,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
parsed_cli_args, remaining_args = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if (
remaining_args.get("disable_caching") is not None
and remaining_args["disable_caching"]
):
disable_caching()
if parsed_cfg.rl:
dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else: