diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 5574d3d47..baf11acbc 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -94,9 +94,15 @@ def wrap_dataset_for_tokenized_prompt( if prompt_tokenizer.supports_batched: map_kwargs["batched"] = True - # Peek at the first example to get original column names - first_example = next(iter(dataset)) - original_columns = list(first_example.keys()) + # For IterableDataset, we need to get original columns to remove them. + # We'll peek at the first example using a separate iterator to avoid consuming the main one. + def peek_and_get_columns(): + # Create a fresh iterator just for peeking + temp_iter = iter(dataset) + first_example = next(temp_iter) + return list(first_example.keys()) + + original_columns = peek_and_get_columns() # Map the dataset and remove original columns # This ensures only tokenized columns remain diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 005312733..32f472cc7 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -471,8 +471,13 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): mp_start_method=cfg.sample_packing_mp_start_method or "fork", ) + # Remove length column only if it exists + dataset_for_loader = train_dataset + if "length" in train_dataset.column_names: + dataset_for_loader = train_dataset.remove_columns(["length"]) + data_loader = DataLoader( - train_dataset.remove_columns(["length"]), + dataset_for_loader, batch_sampler=sampler, ) data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py index 1ac3b537e..b52fb3d08 100644 --- a/tests/e2e/integrations/test_kd.py +++ b/tests/e2e/integrations/test_kd.py @@ -25,7 +25,7 @@ def min_cfg(temp_dir): "liger_rms_norm": True, "liger_glu_activation": True, "torch_compile": True, - "chat_template": "llama3", + "chat_template": "qwen3", "kd_trainer": True, "kd_ce_alpha": 0.1, "kd_alpha": 0.9,