From 5783839c6e29bb148041338772040c85aaae4646 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 9 Jun 2024 20:10:17 -0400 Subject: [PATCH] download model weights on preprocess step (#1693) --- src/axolotl/cli/preprocess.py | 7 +++++++ src/axolotl/common/cli.py | 1 + 2 files changed, 8 insertions(+) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index e7b3596a4..f43277e49 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -7,7 +7,9 @@ from typing import Union import fire import transformers +from accelerate import init_empty_weights from colorama import Fore +from transformers import AutoModelForCausalLM from axolotl.cli import ( check_accelerate_default_config, @@ -71,6 +73,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): else: load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + if parsed_cli_args.download: + model_name = parsed_cfg.base_model + with init_empty_weights(): + AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + LOG.info( Fore.GREEN + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`" diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 636a23ba5..c96f8f81f 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -40,6 +40,7 @@ class PreprocessCliArgs: debug_text_only: bool = field(default=False) debug_num_examples: int = field(default=1) prompter: Optional[str] = field(default=None) + download: Optional[bool] = field(default=True) def load_model_and_tokenizer(