From 35a84f2cb8dcd9220cd7da7c12c9bc5e3eadd28d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 14 Jan 2025 21:37:10 -0500 Subject: [PATCH] more fixes --- src/axolotl/integrations/kd/chat_template.py | 8 ++++---- src/axolotl/utils/trainer.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index 32f5d0ce4..ee557c896 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -71,7 +71,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): # so we need to slice from the left/beginning of logprobs logprobs = logprobs[:-input_seq_len] input_padding_len = 0 - target_seq_len = input_seq_len + # target_seq_len = input_seq_len # fill with -inf for padding_len tokens for top_k tokens # extend target_logprobs with a padding_len x top_k 2D list filled with -inf @@ -84,9 +84,9 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_token_ids.append(list(range(top_k))) target_mask.append([0] * top_k) - for _ in range(target_seq_len): - # TODO also check against sample["labels"] - target_mask.append([1] * top_k) + # for _ in range(target_seq_len): + # # TODO also check against sample["labels"] + # target_mask.append([1] * top_k) for position in range(input_padding_len, input_seq_len): if sample["labels"][position] == -100: diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index ce4deafa9..7728ab181 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -279,6 +279,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): drop_long_kwargs["desc"] = "Dropping Long Sequences" train_dataset = train_dataset.filter( drop_long, + batched=True, **filter_map_kwargs, **drop_long_kwargs, ) @@ -310,8 +311,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): """ labels = sample["labels"] if not labels: - # Edge case: if labels is empty, decide if you want to keep or drop - return True # or False + return True # Check if single example or batch # If first element is an int, we assume a single example