From b582d340b047345f74aad5567a6ccd4763a81308 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 20 Jan 2025 13:31:26 -0500
Subject: [PATCH] save tokenizer too

---
 src/axolotl/integrations/rrt/cli/convert.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py
index e408dc66d..b067e6b64 100644
--- a/src/axolotl/integrations/rrt/cli/convert.py
+++ b/src/axolotl/integrations/rrt/cli/convert.py
@@ -11,7 +11,7 @@ import torch
 from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
 from safetensors.torch import save_file
 from tqdm import tqdm
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoTokenizer
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
 
 from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
@@ -277,6 +277,7 @@ def convert_llama_to_rrt(
     ]
 
     config = AutoConfig.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     num_hidden_layers = config.num_hidden_layers
     if num_hidden_layers % recurse_layers != 0:
         raise ValueError(
@@ -293,6 +294,7 @@ def convert_llama_to_rrt(
         }
     )
     config.save_pretrained(output_dir)
+    tokenizer.save_pretrained(output_dir)
     model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
 
     # create a new state_dict to store the RRT model weights