fix kd test
This commit is contained in:
@@ -94,9 +94,15 @@ def wrap_dataset_for_tokenized_prompt(
|
||||
if prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
|
||||
# Peek at the first example to get original column names
|
||||
first_example = next(iter(dataset))
|
||||
original_columns = list(first_example.keys())
|
||||
# For IterableDataset, we need to get original columns to remove them.
|
||||
# We'll peek at the first example using a separate iterator to avoid consuming the main one.
|
||||
def peek_and_get_columns():
|
||||
# Create a fresh iterator just for peeking
|
||||
temp_iter = iter(dataset)
|
||||
first_example = next(temp_iter)
|
||||
return list(first_example.keys())
|
||||
|
||||
original_columns = peek_and_get_columns()
|
||||
|
||||
# Map the dataset and remove original columns
|
||||
# This ensures only tokenized columns remain
|
||||
|
||||
@@ -471,8 +471,13 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
|
||||
)
|
||||
|
||||
# Remove length column only if it exists
|
||||
dataset_for_loader = train_dataset
|
||||
if "length" in train_dataset.column_names:
|
||||
dataset_for_loader = train_dataset.remove_columns(["length"])
|
||||
|
||||
data_loader = DataLoader(
|
||||
train_dataset.remove_columns(["length"]),
|
||||
dataset_for_loader,
|
||||
batch_sampler=sampler,
|
||||
)
|
||||
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
|
||||
|
||||
@@ -25,7 +25,7 @@ def min_cfg(temp_dir):
|
||||
"liger_rms_norm": True,
|
||||
"liger_glu_activation": True,
|
||||
"torch_compile": True,
|
||||
"chat_template": "llama3",
|
||||
"chat_template": "qwen3",
|
||||
"kd_trainer": True,
|
||||
"kd_ce_alpha": 0.1,
|
||||
"kd_alpha": 0.9,
|
||||
|
||||
Reference in New Issue
Block a user