diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index e0eeea6b3..21436bf41 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -7,6 +7,7 @@ from pathlib import Path import fire import transformers from colorama import Fore +from datasets import disable_caching from axolotl.cli import ( check_accelerate_default_config, @@ -28,9 +29,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs): check_accelerate_default_config() check_user_token() parser = transformers.HfArgumentParser((PreprocessCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( + parsed_cli_args, remaining_args = parser.parse_args_into_dataclasses( return_remaining_strings=True ) + + if ( + remaining_args.get("disable_caching") is not None + and remaining_args["disable_caching"] + ): + disable_caching() if not parsed_cfg.dataset_prepared_path: msg = ( Fore.RED diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 2248784df..6b5f49686 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -6,6 +6,7 @@ from pathlib import Path import fire import transformers +from datasets import disable_caching from axolotl.cli import ( check_accelerate_default_config, @@ -28,9 +29,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs): check_accelerate_default_config() check_user_token() parser = transformers.HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( + parsed_cli_args, remaining_args = parser.parse_args_into_dataclasses( return_remaining_strings=True ) + + if ( + remaining_args.get("disable_caching") is not None + and remaining_args["disable_caching"] + ): + disable_caching() if parsed_cfg.rl: dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) else: