fix kd test

Author: Dan Saunders
Date: 2025-08-21 19:37:15 +00:00
parent 0caa24eab0
commit 4121bcbc33
3 changed files with 16 additions and 5 deletions

@@ -94,9 +94,15 @@ def wrap_dataset_for_tokenized_prompt(
     if prompt_tokenizer.supports_batched:
         map_kwargs["batched"] = True
-    # Peek at the first example to get original column names
-    first_example = next(iter(dataset))
-    original_columns = list(first_example.keys())
+    # For IterableDataset, we need to get original columns to remove them.
+    # We'll peek at the first example using a separate iterator to avoid consuming the main one.
+    def peek_and_get_columns():
+        # Create a fresh iterator just for peeking
+        temp_iter = iter(dataset)
+        first_example = next(temp_iter)
+        return list(first_example.keys())
+
+    original_columns = peek_and_get_columns()
     # Map the dataset and remove original columns
     # This ensures only tokenized columns remain
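
The new helper relies on each call to iter() over a streaming dataset producing an independent iterator. A minimal standalone sketch of why the peek is safe, assuming HF datasets' IterableDataset.from_generator (toy data, not from this repo):

    from datasets import IterableDataset

    def gen():
        # Hypothetical two-example stream, for illustration only
        yield {"text": "hello", "label": 0}
        yield {"text": "world", "label": 1}

    dataset = IterableDataset.from_generator(gen)

    # Peek with a throwaway iterator to discover the original column names;
    # the generator restarts for each fresh iterator, so nothing is consumed.
    peek_iter = iter(dataset)
    original_columns = list(next(peek_iter).keys())
    assert original_columns == ["text", "label"]

    # The main pass still sees every example from the start.
    assert sum(1 for _ in dataset) == 2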

@@ -471,8 +471,13 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             mp_start_method=cfg.sample_packing_mp_start_method or "fork",
         )
+        # Remove length column only if it exists
+        dataset_for_loader = train_dataset
+        if "length" in train_dataset.column_names:
+            dataset_for_loader = train_dataset.remove_columns(["length"])
         data_loader = DataLoader(
-            train_dataset.remove_columns(["length"]),
+            dataset_for_loader,
             batch_sampler=sampler,
         )
         data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
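
The guard matters because datasets' remove_columns raises a ValueError when a named column is absent. A minimal standalone sketch of the same pattern (toy data, not from this repo):

    from datasets import Dataset

    ds = Dataset.from_dict({"input_ids": [[1, 2], [3]], "length": [2, 1]})

    def drop_length(ds):
        # Only drop "length" when it is actually present
        if "length" in ds.column_names:
            ds = ds.remove_columns(["length"])
        return ds

    stripped = drop_length(ds)
    assert stripped.column_names == ["input_ids"]
    # A second call is now a no-op instead of raising
    assert drop_length(stripped).column_names == ["input_ids"]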

@@ -25,7 +25,7 @@ def min_cfg(temp_dir):
         "liger_rms_norm": True,
         "liger_glu_activation": True,
         "torch_compile": True,
-        "chat_template": "llama3",
+        "chat_template": "qwen3",
         "kd_trainer": True,
         "kd_ce_alpha": 0.1,
         "kd_alpha": 0.9,