remove debugging
This commit is contained in:
@@ -435,7 +435,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval)
|
||||
|
||||
if collator_cls:
|
||||
pass
|
||||
collator = collator_cls
|
||||
elif self.cfg.reward_model:
|
||||
collator = RewardDataCollatorWithPadding
|
||||
elif use_batch_sampler_collator:
|
||||
|
||||
@@ -253,12 +253,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||
B, N, D = student_input.shape # pylint: disable=invalid-name
|
||||
K = target_token_ids.shape[-1] # pylint: disable=invalid-name
|
||||
|
||||
# print("student_input shape: " + str(student_input.shape))
|
||||
# print("target_token_ids shape: " + str(target_token_ids.shape))
|
||||
# print("target_logprobs shape: " + str(target_logprobs.shape))
|
||||
# print("target_mask shape: " + str(target_mask.shape))
|
||||
# print("true_labels shape: " + str(true_labels.shape))
|
||||
|
||||
student_input_flat = student_input.reshape(-1, student_input.shape[-1])
|
||||
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
|
||||
target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
|
||||
@@ -266,20 +260,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||
# pad and shift for cross entropy loss
|
||||
true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
|
||||
true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
|
||||
# true_labels_flat = true_labels.reshape(-1)
|
||||
|
||||
# student_input_flat = student_input[:, :-1, :].contiguous().view(-1, student_input.shape[-1])
|
||||
# target_token_ids_flat = target_token_ids[:, 1:, :].contiguous().view(-1, target_token_ids.shape[-1])
|
||||
# target_logprobs_flat = target_logprobs[:, 1:, :].contiguous().view(-1, target_logprobs.shape[-1])
|
||||
# target_mask_flat = target_mask[:, 1:, :].contiguous().view(-1, target_mask.shape[-1])
|
||||
# true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
|
||||
# N = N - 1
|
||||
|
||||
# print("student_input_flat shape: " + str(student_input_flat.shape))
|
||||
# print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape))
|
||||
# print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape))
|
||||
# print("target_mask_flat shape: " + str(target_mask_flat.shape))
|
||||
# print("true_labels_flat shape: " + str(true_labels_flat.shape))
|
||||
num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
|
||||
|
||||
_student_input_chunks = torch.chunk(
|
||||
@@ -323,16 +304,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||
grad_input_flat, grad_weight, grad_bias_maybe = (
|
||||
ctx.saved_tensors
|
||||
) # grad_input_flat is (B*N, D)
|
||||
print("grad_input_flat")
|
||||
print(grad_input_flat.shape)
|
||||
|
||||
# num_valid_tokens_scalar = ctx.num_valid_tokens_scalar
|
||||
# normalizer = float(num_valid_tokens_scalar) if num_valid_tokens_scalar > 0 else 1.0
|
||||
|
||||
# grad_input_flat = grad_input_flat / normalizer
|
||||
# grad_weight = grad_weight / normalizer
|
||||
# if grad_bias_maybe is not None:
|
||||
# grad_bias_maybe = grad_bias_maybe / normalizer
|
||||
|
||||
# Scale gradients by grad_output if it's not 1.0
|
||||
if not torch.equal(
|
||||
|
||||
Reference in New Issue
Block a user