diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 32e54c9a8..34b505ff1 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -196,7 +196,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if eval_dataset: eval_dataset = eval_dataset.remove_columns("attention_mask") - if cfg.model_config_type == "falcon": + if cfg.model_config_type in ["falcon", "mistral"]: LOG.info("dropping token_type_ids column if it exists") if "token_type_ids" in train_dataset.column_names: train_dataset = train_dataset.remove_columns("token_type_ids")