diff --git a/requirements.txt b/requirements.txt
index e1c7700de..b055b7c87 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,6 +20,7 @@
 datasets==3.6.0
 deepspeed>=0.17.0
 trl==0.18.1
 hf_xet==1.1.2
+mistral-common[hf-hub]==1.6.0
 optimum==1.16.2
 hf_transfer
@@ -67,5 +68,3 @@
 schedulefree==1.4.1
 axolotl-contribs-lgpl==0.0.6
 axolotl-contribs-mit==0.0.3
-
-mistral-common[hf-hub]==1.6.0
diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py
index 8671e69f8..ae7ee5baa 100644
--- a/src/axolotl/loaders/tokenizer.py
+++ b/src/axolotl/loaders/tokenizer.py
@@ -10,7 +10,6 @@
 from huggingface_hub import hf_hub_download
 from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.tokens.tokenizers.mistral import (
-    MODEL_NAME_TO_TOKENIZER_CLS,
     MistralTokenizer,
 )
 from transformers import (
@@ -366,15 +365,6 @@ class TokenizerConfiguration:
         self.cfg = cfg
         self.model_config = load_model_config(cfg)
 
-    def detect_by_model_name_mapping(self) -> bool:
-        # Extract model name from path
-        model = self.cfg.base_model.split("/")[-1]
-        for model_name in MODEL_NAME_TO_TOKENIZER_CLS.keys():
-            if model_name in model.lower():
-                return True
-
-        return False
-
     def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
         """Load Mistral tokenizer from model configuration."""
         # Instantiate Mistral tokenizer