diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 2e3728cc8..41c8fa9e5 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -197,6 +197,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if eval_dataset: eval_dataset = eval_dataset.remove_columns("attention_mask") + if cfg.model_config_type == "olmo": + LOG.info("dropping position_ids column") + train_dataset = train_dataset.remove_columns("position_ids") + if eval_dataset: + eval_dataset = eval_dataset.remove_columns("position_ids") + if cfg.model_config_type == "falcon": LOG.info("dropping token_type_ids column if it exists") if "token_type_ids" in train_dataset.column_names: