From 7263845207c0653a6cdd7d88e2d21efae665d2c5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 22 May 2025 15:00:42 -0400 Subject: [PATCH] remove debugging --- src/axolotl/core/builders/causal.py | 2 +- src/axolotl/integrations/kd/kernels/liger.py | 29 -------------------- 2 files changed, 1 insertion(+), 30 deletions(-) diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index e8838014f..fa6a9ec37 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -435,7 +435,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator_cls = plugin_manager.get_collator_cls(self.cfg, is_eval=is_eval) if collator_cls: - pass + collator = collator_cls elif self.cfg.reward_model: collator = RewardDataCollatorWithPadding elif use_batch_sampler_collator: diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py index 9fa0b663d..35f82855f 100644 --- a/src/axolotl/integrations/kd/kernels/liger.py +++ b/src/axolotl/integrations/kd/kernels/liger.py @@ -253,12 +253,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): B, N, D = student_input.shape # pylint: disable=invalid-name K = target_token_ids.shape[-1] # pylint: disable=invalid-name - # print("student_input shape: " + str(student_input.shape)) - # print("target_token_ids shape: " + str(target_token_ids.shape)) - # print("target_logprobs shape: " + str(target_logprobs.shape)) - # print("target_mask shape: " + str(target_mask.shape)) - # print("true_labels shape: " + str(true_labels.shape)) - student_input_flat = student_input.reshape(-1, student_input.shape[-1]) target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1]) target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1]) @@ -266,20 +260,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): # pad and shift for cross entropy loss true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index) true_labels_flat = true_labels[:, 1:].contiguous().view(-1) - # true_labels_flat = true_labels.reshape(-1) - # student_input_flat = student_input[:, :-1, :].contiguous().view(-1, student_input.shape[-1]) - # target_token_ids_flat = target_token_ids[:, 1:, :].contiguous().view(-1, target_token_ids.shape[-1]) - # target_logprobs_flat = target_logprobs[:, 1:, :].contiguous().view(-1, target_logprobs.shape[-1]) - # target_mask_flat = target_mask[:, 1:, :].contiguous().view(-1, target_mask.shape[-1]) - # true_labels_flat = true_labels[:, 1:].contiguous().view(-1) - # N = N - 1 - - # print("student_input_flat shape: " + str(student_input_flat.shape)) - # print("target_token_ids_flat shape: " + str(target_token_ids_flat.shape)) - # print("target_logprobs_flat shape: " + str(target_logprobs_flat.shape)) - # print("target_mask_flat shape: " + str(target_mask_flat.shape)) - # print("true_labels_flat shape: " + str(true_labels_flat.shape)) num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE) _student_input_chunks = torch.chunk( @@ -323,16 +304,6 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): grad_input_flat, grad_weight, grad_bias_maybe = ( ctx.saved_tensors ) # grad_input_flat is (B*N, D) - print("grad_input_flat") - print(grad_input_flat.shape) - - # num_valid_tokens_scalar = ctx.num_valid_tokens_scalar - # normalizer = float(num_valid_tokens_scalar) if num_valid_tokens_scalar > 0 else 1.0 - - # grad_input_flat = grad_input_flat / normalizer - # grad_weight = grad_weight / normalizer - # if grad_bias_maybe is not None: - # grad_bias_maybe = grad_bias_maybe / normalizer # Scale gradients by grad_output if it's not 1.0 if not torch.equal(