fixes for alpaca w chatml, and don't include attention_mask w mistral for flash attention (#728)
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
"""Module for Alpaca prompt strategy classes"""
|
||||||
|
|
||||||
from typing import Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
@@ -9,9 +9,13 @@ from axolotl.prompt_tokenizers import (
|
|||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
prompt_style = PromptStyle.CHAT.value
|
||||||
|
if ds_cfg and "conversation" in ds_cfg:
|
||||||
|
prompt_style = ds_cfg["conversation"]
|
||||||
|
|
||||||
return AlpacaPromptTokenizingStrategy(
|
return AlpacaPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
AlpacaPrompter(prompt_style),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
|
|||||||
@@ -423,7 +423,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Phi doesn't want the attention_mask feature when training
|
# Phi doesn't want the attention_mask feature when training
|
||||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__:
|
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
|
||||||
|
cfg.is_mistral_derived_model and cfg.flash_attention
|
||||||
|
):
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||||
|
|||||||
Reference in New Issue
Block a user