Respect sequence_len in config for type: llama2_chat (#926)

* Respect sequence_len in config for `type: llama2_chat`

It was hardcoded to `4096` I am not sure why?  This updates it to pull from the config. 

cc: @winglian

* Update llama2_chat.py

* apply black formatting

* fix tokenizer

* update test data

* lint fixtures
This commit is contained in:
Hamel Husain
2023-12-12 09:39:22 -08:00
committed by GitHub
parent 7fabc4d95e
commit f1de29dd1e
2 changed files with 4 additions and 3 deletions

View File

@@ -81,8 +81,9 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sequence_len = 4096
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
self.tokenizer.add_special_tokens(
{"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")}
)
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
def tokenize_prompt(self, prompt):

File diff suppressed because one or more lines are too long