Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
cf00e20270 experiment w latent space
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-18 05:47:26 -04:00
2 changed files with 96 additions and 0 deletions

View File

@@ -1,8 +1,49 @@
"""Module loading the AlpacaInstructPromptTokenizingStrategy class"""
import logging
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
# Module-level logger; used below to warn when the tokenizer returns no tokens.
LOG = logging.getLogger("axolotl.prompt_strategies.alpaca_instruct")
class LatentSpaceAlpacaPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
    """
    Overrides the tokenization to include additional padding tokens as
    latent space on the inputs
    """

    # Number of pad tokens appended as "latent space". Subclasses may override.
    latent_space_len = 100

    def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
        """Tokenize ``prompt`` and optionally append latent-space padding.

        Args:
            prompt: Text to tokenize.
            add_eos_token: Append the tokenizer's EOS id when the sequence
                does not already end with it and is shorter than
                ``self.sequence_len``.
            strip_bos_token: Drop a leading BOS id (and its mask entry).

        Returns:
            The tokenizer result with ``input_ids`` and ``attention_mask``
            updated in step, plus ``labels`` as a copy of ``input_ids``.

        Note:
            The latent-space pad tokens are appended after truncation, so the
            returned sequence can be longer than ``self.sequence_len``.
        """
        # pylint: disable=duplicate-code
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        if not result["input_ids"]:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")

        if (
            result["input_ids"]
            and result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.sequence_len
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        # Guard against an empty result before peeking at the first token;
        # indexing [0] unconditionally raised IndexError on empty output.
        if (
            strip_bos_token
            and result["input_ids"]
            and result["input_ids"][0] == self.tokenizer.bos_token_id
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        # latent space
        if add_eos_token and not strip_bos_token:
            result["input_ids"].extend(
                [self.tokenizer.pad_token_id] * self.latent_space_len
            )
            # Keep attention_mask the same length as input_ids; previously only
            # input_ids grew, leaving the two lists out of sync.
            # NOTE(review): 1 (attended) is assumed so the model can use the
            # latent positions -- confirm this matches the experiment's intent.
            result["attention_mask"].extend([1] * self.latent_space_len)

        result["labels"] = result["input_ids"].copy()
        return result
def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
@@ -20,3 +61,12 @@ def load_no_prompt(tokenizer, cfg):
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_latent_space(tokenizer, cfg):
    """Build the latent-space Alpaca strategy with the instruct prompt style.

    Reads ``train_on_inputs`` and ``sequence_len`` from ``cfg``.
    """
    instruct_prompter = AlpacaPrompter(PromptStyle.INSTRUCT.value)
    strategy = LatentSpaceAlpacaPromptTokenizingStrategy(
        instruct_prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return strategy

View File

@@ -31,6 +31,52 @@ def load_guanaco(tokenizer, cfg):
)
def load_latent_space(tokenizer, cfg):
    """Build the latent-space ShareGPT strategy with the chat prompt style.

    Reads ``train_on_inputs`` and ``sequence_len`` from ``cfg``.
    """
    chat_prompter = ShareGPTPrompter(PromptStyle.CHAT.value)
    strategy = LatentSpaceShareGPTPromptTokenizingStrategy(
        chat_prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return strategy
class LatentSpaceShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
    """
    latent space padded sharegpt strategy to grab conversations from the sample row
    """

    # Number of pad tokens appended as "latent space". Subclasses may override.
    latent_space_len = 100

    def get_conversation_thread(self, prompt):
        """Return the ``"conversations"`` entry of the sample row ``prompt``."""
        return prompt["conversations"]

    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
        """Tokenize ``prompt`` and optionally append latent-space padding.

        Args:
            prompt: Text to tokenize.
            add_eos_token: Append the tokenizer's EOS id when the sequence
                does not already end with it and is shorter than
                ``self.sequence_len``.
            strip_bos_token: Drop a leading BOS id (and its mask entry).

        Returns:
            The tokenizer result with ``input_ids`` and ``attention_mask``
            updated in step, plus ``labels`` as a copy of ``input_ids``.

        Note:
            The latent-space pad tokens are appended after truncation, so the
            returned sequence can be longer than ``self.sequence_len``.
        """
        # pylint: disable=duplicate-code
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )

        # The emptiness checks below prevent IndexError when the tokenizer
        # returns no tokens; the original indexed [-1] and [0] unconditionally.
        if (
            result["input_ids"]
            and result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.sequence_len
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        if (
            strip_bos_token
            and result["input_ids"]
            and result["input_ids"][0] == self.tokenizer.bos_token_id
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        # latent space
        if add_eos_token and not strip_bos_token:
            result["input_ids"].extend(
                [self.tokenizer.pad_token_id] * self.latent_space_len
            )
            # Keep attention_mask the same length as input_ids; previously only
            # input_ids grew, leaving the two lists out of sync.
            # NOTE(review): 1 (attended) is assumed so the model can use the
            # latent positions -- confirm this matches the experiment's intent.
            result["attention_mask"].extend([1] * self.latent_space_len)

        result["labels"] = result["input_ids"].copy()
        return result
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row