download model weights on preprocess step (#1693)

This commit is contained in:
Wing Lian
2024-06-09 20:10:17 -04:00
committed by GitHub
parent cbbf039a46
commit 5783839c6e
2 changed files with 8 additions and 0 deletions

View File

@@ -7,7 +7,9 @@ from typing import Union
import fire
import transformers
from accelerate import init_empty_weights
from colorama import Fore
from transformers import AutoModelForCausalLM
from axolotl.cli import (
check_accelerate_default_config,
@@ -71,6 +73,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.download:
model_name = parsed_cfg.base_model
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"

View File

@@ -40,6 +40,7 @@ class PreprocessCliArgs:
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
def load_model_and_tokenizer(