Rename padding_len to input_padding_len and add a TODO to check against labels
This commit is contained in:
@@ -494,7 +494,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
logprobs = sample.pop(self.logprobs_field)
|
logprobs = sample.pop(self.logprobs_field)
|
||||||
target_seq_len = len(logprobs)
|
target_seq_len = len(logprobs)
|
||||||
input_seq_len = len(sample["input_ids"])
|
input_seq_len = len(sample["input_ids"])
|
||||||
padding_len = input_seq_len - target_seq_len
|
input_padding_len = input_seq_len - target_seq_len
|
||||||
top_k = len(logprobs[0])
|
top_k = len(logprobs[0])
|
||||||
target_logprobs = []
|
target_logprobs = []
|
||||||
target_token_ids = []
|
target_token_ids = []
|
||||||
@@ -502,12 +502,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
|
|
||||||
# fill with -inf for padding_len tokens for top_k tokens
|
# fill with -inf for input_padding_len tokens for top_k tokens
|
||||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
# extend target_logprobs with an input_padding_len x top_k 2D list filled with -inf
|
||||||
for _ in range(padding_len):
|
for _ in range(input_padding_len):
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
target_token_ids.append(list(range(top_k)))
|
target_token_ids.append(list(range(top_k)))
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
|
|
||||||
for _ in range(target_seq_len):
|
for _ in range(target_seq_len):
|
||||||
|
# TODO also check against sample["labels"]
|
||||||
target_mask.append([1] * top_k)
|
target_mask.append([1] * top_k)
|
||||||
|
|
||||||
for _, token_pos_logprobs in enumerate(logprobs):
|
for _, token_pos_logprobs in enumerate(logprobs):
|
||||||
|
|||||||
Reference in New Issue
Block a user