download model weights on preprocess step (#1693)
This commit is contained in:
@@ -7,7 +7,9 @@ from typing import Union
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from colorama import Fore
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
@@ -71,6 +73,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
else:
|
||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
if parsed_cli_args.download:
|
||||
model_name = parsed_cfg.base_model
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
LOG.info(
|
||||
Fore.GREEN
|
||||
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
||||
|
||||
@@ -40,6 +40,7 @@ class PreprocessCliArgs:
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=1)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
download: Optional[bool] = field(default=True)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
|
||||
Reference in New Issue
Block a user