fix kd test
This commit is contained in:
@@ -94,9 +94,15 @@ def wrap_dataset_for_tokenized_prompt(
|
|||||||
if prompt_tokenizer.supports_batched:
|
if prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
|
|
||||||
# Peek at the first example to get original column names
|
# For IterableDataset, we need to get original columns to remove them.
|
||||||
first_example = next(iter(dataset))
|
# We'll peek at the first example using a separate iterator to avoid consuming the main one.
|
||||||
original_columns = list(first_example.keys())
|
def peek_and_get_columns():
|
||||||
|
# Create a fresh iterator just for peeking
|
||||||
|
temp_iter = iter(dataset)
|
||||||
|
first_example = next(temp_iter)
|
||||||
|
return list(first_example.keys())
|
||||||
|
|
||||||
|
original_columns = peek_and_get_columns()
|
||||||
|
|
||||||
# Map the dataset and remove original columns
|
# Map the dataset and remove original columns
|
||||||
# This ensures only tokenized columns remain
|
# This ensures only tokenized columns remain
|
||||||
|
|||||||
@@ -471,8 +471,13 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
|
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Remove length column only if it exists
|
||||||
|
dataset_for_loader = train_dataset
|
||||||
|
if "length" in train_dataset.column_names:
|
||||||
|
dataset_for_loader = train_dataset.remove_columns(["length"])
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
train_dataset.remove_columns(["length"]),
|
dataset_for_loader,
|
||||||
batch_sampler=sampler,
|
batch_sampler=sampler,
|
||||||
)
|
)
|
||||||
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
|
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ def min_cfg(temp_dir):
|
|||||||
"liger_rms_norm": True,
|
"liger_rms_norm": True,
|
||||||
"liger_glu_activation": True,
|
"liger_glu_activation": True,
|
||||||
"torch_compile": True,
|
"torch_compile": True,
|
||||||
"chat_template": "llama3",
|
"chat_template": "qwen3",
|
||||||
"kd_trainer": True,
|
"kd_trainer": True,
|
||||||
"kd_ce_alpha": 0.1,
|
"kd_ce_alpha": 0.1,
|
||||||
"kd_alpha": 0.9,
|
"kd_alpha": 0.9,
|
||||||
|
|||||||
Reference in New Issue
Block a user