remove debugging

This commit is contained in:
Wing Lian
2025-05-22 15:00:42 -04:00
parent 5ccfd225cb
commit 7263845207
2 changed files with 1 additions and 30 deletions

View File

@@ -435,7 +435,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval)
if collator_cls:
pass
collator = collator_cls
elif self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:

View File

@@ -253,12 +253,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
B, N, D = student_input.shape # pylint: disable=invalid-name
K = target_token_ids.shape[-1] # pylint: disable=invalid-name
# print("student_input shape: " + str(student_input.shape))
# print("target_token_ids shape: " + str(target_token_ids.shape))
# print("target_logprobs shape: " + str(target_logprobs.shape))
# print("target_mask shape: " + str(target_mask.shape))
# print("true_labels shape: " + str(true_labels.shape))
student_input_flat = student_input.reshape(-1, student_input.shape[-1])
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
@@ -266,20 +260,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
# pad and shift for cross entropy loss
true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
# true_labels_flat = true_labels.reshape(-1)
# student_input_flat = student_input[:, :-1, :].contiguous().view(-1, student_input.shape[-1])
# target_token_ids_flat = target_token_ids[:, 1:, :].contiguous().view(-1, target_token_ids.shape[-1])
# target_logprobs_flat = target_logprobs[:, 1:, :].contiguous().view(-1, target_logprobs.shape[-1])
# target_mask_flat = target_mask[:, 1:, :].contiguous().view(-1, target_mask.shape[-1])
# true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
# N = N - 1
# print("student_input_flat shape: " + str(student_input_flat.shape))
# print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape))
# print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape))
# print("target_mask_flat shape: " + str(target_mask_flat.shape))
# print("true_labels_flat shape: " + str(true_labels_flat.shape))
num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
_student_input_chunks = torch.chunk(
@@ -323,16 +304,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
grad_input_flat, grad_weight, grad_bias_maybe = (
ctx.saved_tensors
) # grad_input_flat is (B*N, D)
print("grad_input_flat")
print(grad_input_flat.shape)
# num_valid_tokens_scalar = ctx.num_valid_tokens_scalar
# normalizer = float(num_valid_tokens_scalar) if num_valid_tokens_scalar > 0 else 1.0
# grad_input_flat = grad_input_flat / normalizer
# grad_weight = grad_weight / normalizer
# if grad_bias_maybe is not None:
# grad_bias_maybe = grad_bias_maybe / normalizer
# Scale gradients by grad_output if it's not 1.0
if not torch.equal(