Tokenization open assistant (#1)
* refactor prompt tokenization to more easily support open assistant * add open assistant handling, more logging, black formatting
This commit is contained in:
@@ -37,6 +37,7 @@ from axolotl.prompt_tokenizers import (
|
|||||||
ShareGPTPromptTokenizingStrategy,
|
ShareGPTPromptTokenizingStrategy,
|
||||||
LLAMA_DEFAULT_PAD_TOKEN,
|
LLAMA_DEFAULT_PAD_TOKEN,
|
||||||
GPTeacherPromptTokenizingStrategy,
|
GPTeacherPromptTokenizingStrategy,
|
||||||
|
OpenAssistantPromptTokenizingStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
|
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
|
||||||
|
|
||||||
@@ -56,7 +57,15 @@ def setup_wandb_env_vars(cfg):
|
|||||||
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
||||||
|
|
||||||
|
|
||||||
def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
|
def load_model(
|
||||||
|
base_model,
|
||||||
|
base_model_config,
|
||||||
|
model_type,
|
||||||
|
tokenizer_type,
|
||||||
|
cfg,
|
||||||
|
adapter="lora",
|
||||||
|
inference: bool = False,
|
||||||
|
):
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
@@ -67,13 +76,17 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
|
|||||||
if is_llama_derived_model and cfg.flash_attention:
|
if is_llama_derived_model and cfg.flash_attention:
|
||||||
if cfg.device not in ["mps", "cpu"] and inference is False:
|
if cfg.device not in ["mps", "cpu"] and inference is False:
|
||||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||||
|
|
||||||
logging.info("patching with flash attention")
|
logging.info("patching with flash attention")
|
||||||
replace_llama_attn_with_flash_attn()
|
replace_llama_attn_with_flash_attn()
|
||||||
|
|
||||||
torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
|
torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
|
||||||
try:
|
try:
|
||||||
if cfg.load_4bit:
|
if cfg.load_4bit:
|
||||||
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
|
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
|
||||||
|
replace_peft_model_with_int4_lora_model,
|
||||||
|
)
|
||||||
|
|
||||||
replace_peft_model_with_int4_lora_model()
|
replace_peft_model_with_int4_lora_model()
|
||||||
|
|
||||||
from peft import (
|
from peft import (
|
||||||
@@ -92,18 +105,26 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
cache_model_path = Path(snapshot_download(base_model))
|
cache_model_path = Path(snapshot_download(base_model))
|
||||||
files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensors')) + list(cache_model_path.glob('*.bin'))
|
files = (
|
||||||
|
list(cache_model_path.glob("*.pt"))
|
||||||
|
+ list(cache_model_path.glob("*.safetensors"))
|
||||||
|
+ list(cache_model_path.glob("*.bin"))
|
||||||
|
)
|
||||||
if len(files) > 0:
|
if len(files) > 0:
|
||||||
model_path = str(files[0])
|
model_path = str(files[0])
|
||||||
else:
|
else:
|
||||||
logging.warning("unable to find a cached model file, this will likely fail...")
|
logging.warning(
|
||||||
|
"unable to find a cached model file, this will likely fail..."
|
||||||
|
)
|
||||||
model_path = str(cache_model_path)
|
model_path = str(cache_model_path)
|
||||||
model, tokenizer = load_llama_model_4bit_low_ram(
|
model, tokenizer = load_llama_model_4bit_low_ram(
|
||||||
base_model_config if base_model_config else base_model,
|
base_model_config if base_model_config else base_model,
|
||||||
model_path,
|
model_path,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
|
groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
|
||||||
is_v1_model=cfg.gptq_model_v1 if cfg.gptq_model_v1 is not None else True,
|
is_v1_model=cfg.gptq_model_v1
|
||||||
|
if cfg.gptq_model_v1 is not None
|
||||||
|
else True,
|
||||||
)
|
)
|
||||||
load_in_8bit = False
|
load_in_8bit = False
|
||||||
elif is_llama_derived_model:
|
elif is_llama_derived_model:
|
||||||
@@ -120,7 +141,11 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
|
|||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
)
|
)
|
||||||
except:
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
||||||
|
)
|
||||||
|
logging.exception(e)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
load_in_8bit=cfg.load_in_8bit,
|
load_in_8bit=cfg.load_in_8bit,
|
||||||
@@ -145,7 +170,6 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
|
|||||||
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
|
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
|
||||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -165,7 +189,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.lora_model_dir:
|
if cfg.lora_model_dir:
|
||||||
model = PeftModel.from_pretrained(model, cfg.lora_model_dir, device_map = cfg.device_map, torch_dtype=torch.float16)
|
model = PeftModel.from_pretrained(
|
||||||
|
model,
|
||||||
|
cfg.lora_model_dir,
|
||||||
|
device_map=cfg.device_map,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
@@ -174,9 +203,11 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
|
|||||||
|
|
||||||
if cfg.load_4bit:
|
if cfg.load_4bit:
|
||||||
# Scales to half
|
# Scales to half
|
||||||
logging.info('Fitting 4bit scales and zeros to half')
|
logging.info("Fitting 4bit scales and zeros to half")
|
||||||
for n, m in model.named_modules():
|
for n, m in model.named_modules():
|
||||||
if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)):
|
if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
|
||||||
|
type(m)
|
||||||
|
):
|
||||||
if hasattr(m, "is_v1_model") and m.is_v1_model:
|
if hasattr(m, "is_v1_model") and m.is_v1_model:
|
||||||
m.zeros = m.zeros.half()
|
m.zeros = m.zeros.half()
|
||||||
m.scales = m.scales.half()
|
m.scales = m.scales.half()
|
||||||
@@ -236,37 +267,44 @@ def check_dataset_labels(dataset, tokenizer):
|
|||||||
|
|
||||||
|
|
||||||
def do_inference(cfg, model, tokenizer):
|
def do_inference(cfg, model, tokenizer):
|
||||||
tokenizer.add_special_tokens({'unk_token': '<unk>'})
|
tokenizer.add_special_tokens({"unk_token": "<unk>"})
|
||||||
tokenizer.add_special_tokens({'bos_token': '<s>'})
|
tokenizer.add_special_tokens({"bos_token": "<s>"})
|
||||||
tokenizer.add_special_tokens({'eos_token': '</s>'})
|
tokenizer.add_special_tokens({"eos_token": "</s>"})
|
||||||
|
|
||||||
instruction = "Tell me a joke about dromedaries."
|
instruction = "Tell me a joke about dromedaries."
|
||||||
input = ""
|
input = ""
|
||||||
prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(instruction=instruction, input=input)
|
prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(
|
||||||
|
instruction=instruction, input=input
|
||||||
|
)
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# gc = GenerationConfig() # TODO swap out and use this
|
# gc = GenerationConfig() # TODO swap out and use this
|
||||||
generated = model.generate(inputs=batch["input_ids"].to("cuda"),
|
generated = model.generate(
|
||||||
do_sample=True, use_cache=True,
|
inputs=batch["input_ids"].to("cuda"),
|
||||||
repetition_penalty=1.1,
|
do_sample=True,
|
||||||
max_new_tokens=100,
|
use_cache=True,
|
||||||
temperature=0.9,
|
repetition_penalty=1.1,
|
||||||
top_p=0.95,
|
max_new_tokens=100,
|
||||||
top_k=40,
|
temperature=0.9,
|
||||||
return_dict_in_generate=True,
|
top_p=0.95,
|
||||||
output_attentions=False,
|
top_k=40,
|
||||||
output_hidden_states=False,
|
return_dict_in_generate=True,
|
||||||
output_scores=False)
|
output_attentions=False,
|
||||||
print(tokenizer.decode(generated['sequences'].cpu().tolist()[0]))
|
output_hidden_states=False,
|
||||||
|
output_scores=False,
|
||||||
|
)
|
||||||
|
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||||
|
|
||||||
|
|
||||||
def choose_config(path: Path):
|
def choose_config(path: Path):
|
||||||
yaml_files = [file for file in path.glob("*.yml")]
|
yaml_files = [file for file in path.glob("*.yml")]
|
||||||
|
|
||||||
if not yaml_files:
|
if not yaml_files:
|
||||||
raise ValueError("No YAML config files found in the specified directory. Are you using a .yml extension?")
|
raise ValueError(
|
||||||
|
"No YAML config files found in the specified directory. Are you using a .yml extension?"
|
||||||
|
)
|
||||||
|
|
||||||
print("Choose a YAML file:")
|
print("Choose a YAML file:")
|
||||||
for idx, file in enumerate(yaml_files):
|
for idx, file in enumerate(yaml_files):
|
||||||
@@ -376,6 +414,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
config: Path = Path("configs/"),
|
config: Path = Path("configs/"),
|
||||||
prepare_ds_only: bool = False,
|
prepare_ds_only: bool = False,
|
||||||
@@ -420,7 +459,13 @@ def train(
|
|||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
logging.info("loading model, tokenizer, and lora_config...")
|
logging.info("loading model, tokenizer, and lora_config...")
|
||||||
model, tokenizer, lora_config = load_model(
|
model, tokenizer, lora_config = load_model(
|
||||||
cfg.base_model, cfg.base_model_config, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
|
cfg.base_model,
|
||||||
|
cfg.base_model_config,
|
||||||
|
cfg.model_type,
|
||||||
|
cfg.tokenizer_type,
|
||||||
|
cfg,
|
||||||
|
adapter=cfg.adapter,
|
||||||
|
inference=("inference" in kwargs),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "inference" in kwargs:
|
if "inference" in kwargs:
|
||||||
@@ -428,10 +473,26 @@ def train(
|
|||||||
do_inference(cfg, model, tokenizer)
|
do_inference(cfg, model, tokenizer)
|
||||||
return
|
return
|
||||||
|
|
||||||
max_packed_sequence_len = cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
max_packed_sequence_len = (
|
||||||
max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
|
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||||
ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
|
)
|
||||||
prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
max_packed_sequence_len = min(
|
||||||
|
max_packed_sequence_len, cfg.sequence_len
|
||||||
|
) # make sure we don't accidentally set it larger than sequence_len
|
||||||
|
ds_hash = str(
|
||||||
|
md5(
|
||||||
|
(
|
||||||
|
str(max_packed_sequence_len)
|
||||||
|
+ "@"
|
||||||
|
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
||||||
|
).encode("utf-8")
|
||||||
|
).hexdigest()
|
||||||
|
)
|
||||||
|
prepared_ds_path = (
|
||||||
|
Path(cfg.dataset_prepared_path) / ds_hash
|
||||||
|
if cfg.dataset_prepared_path
|
||||||
|
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
||||||
|
)
|
||||||
|
|
||||||
if any(prepared_ds_path.glob("*")):
|
if any(prepared_ds_path.glob("*")):
|
||||||
logging.info("Loading prepared dataset from disk...")
|
logging.info("Loading prepared dataset from disk...")
|
||||||
@@ -464,9 +525,18 @@ def train(
|
|||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
|
elif d.type == "oasst":
|
||||||
|
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||||
|
AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
elif d.type == "gpteacher":
|
elif d.type == "gpteacher":
|
||||||
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
||||||
GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
GPTeacherPrompter(),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
@@ -476,13 +546,17 @@ def train(
|
|||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
|
else:
|
||||||
|
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
|
||||||
constant_len_dataset = ConstantLengthDataset(
|
constant_len_dataset = ConstantLengthDataset(
|
||||||
tokenizer, datasets, seq_length=max_packed_sequence_len,
|
tokenizer,
|
||||||
|
datasets,
|
||||||
|
seq_length=max_packed_sequence_len,
|
||||||
)
|
)
|
||||||
logging.info("merging, packing, shuffling, and splitting master dataset")
|
logging.info("merging, packing, shuffling, and splitting master dataset")
|
||||||
dataset = Dataset.from_list(
|
dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
|
||||||
[_ for _ in constant_len_dataset]
|
test_size=cfg.val_set_size, shuffle=True, seed=42
|
||||||
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
|
)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
|
logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
|
||||||
@@ -525,7 +599,9 @@ def train(
|
|||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||||
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
logging.info(
|
||||||
|
f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
|
||||||
|
)
|
||||||
model.save_pretrained(cfg.output_dir)
|
model.save_pretrained(cfg.output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,14 +31,18 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
|
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||||
|
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
full_prompt = self._tokenize_full_prompt(prompt)
|
instruction, input, response = self.parse_instruction_fields(prompt)
|
||||||
|
full_prompt = self._build_full_prompt(instruction, input, response)
|
||||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||||
if not self.train_on_inputs:
|
if not self.train_on_inputs:
|
||||||
user_prompt = self.prompter.build_prompt(
|
user_prompt = self.prompter.build_prompt(
|
||||||
prompt["instruction"],
|
instruction,
|
||||||
prompt["input"] if "input" in prompt else "",
|
input,
|
||||||
)
|
)
|
||||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||||
@@ -49,11 +53,11 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
return tokenized_full_prompt
|
return tokenized_full_prompt
|
||||||
|
|
||||||
def _tokenize_full_prompt(self, prompt):
|
def _build_full_prompt(self, instruction, input, response):
|
||||||
return self.prompter.build_prompt(
|
return self.prompter.build_prompt(
|
||||||
prompt["instruction"],
|
instruction,
|
||||||
prompt["input"] if "input" in prompt else "",
|
input,
|
||||||
prompt["output"],
|
response,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _tokenize(self, prompt, add_eos_token=True):
|
def _tokenize(self, prompt, add_eos_token=True):
|
||||||
@@ -76,11 +80,29 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
|
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||||
def _tokenize_full_prompt(self, prompt):
|
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||||
return self.prompter.build_prompt(
|
return (
|
||||||
prompt["instruction"],
|
prompt["instruction"],
|
||||||
prompt["input"],
|
prompt["input"] if "input" in prompt else "",
|
||||||
|
prompt["output"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||||
|
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||||
|
return (
|
||||||
|
prompt["INSTRUCTION"],
|
||||||
|
"",
|
||||||
|
prompt["RESPONSE"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||||
|
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||||
|
return (
|
||||||
|
prompt["instruction"],
|
||||||
|
prompt["input"] if "input" in prompt else "",
|
||||||
prompt["response"],
|
prompt["response"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user