fixes for alpaca w chatml, and don't include attention_mask w mistral for flash attention (#728)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||
"""Module for Alpaca prompt strategy classes"""
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
@@ -9,9 +9,13 @@ from axolotl.prompt_tokenizers import (
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
prompt_style = PromptStyle.CHAT.value
|
||||
if ds_cfg and "conversation" in ds_cfg:
|
||||
prompt_style = ds_cfg["conversation"]
|
||||
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
AlpacaPrompter(prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
@@ -423,7 +423,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||
)
|
||||
|
||||
# Phi doesn't want the attention_mask feature when training
|
||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__:
|
||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
|
||||
cfg.is_mistral_derived_model and cfg.flash_attention
|
||||
):
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||
|
||||
Reference in New Issue
Block a user