remove debug logs and simplify

This commit is contained in:
Dan Saunders
2025-03-13 15:47:45 +00:00
parent 5731cdc0cf
commit d0e178d52f

View File

@@ -200,8 +200,9 @@ class DataCollatorForSeq2Seq:
Returns:
Sliced batch dictionary.
"""
# Process keys that need to be sliced
for key in ["input_ids", "attention_mask", "labels"]:
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
for key in keys_to_slice:
if key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // self.local_world_size
@@ -211,90 +212,11 @@ class DataCollatorForSeq2Seq:
if self.local_rank < self.local_world_size - 1
else seq_len
)
if key == "input_ids":
# Before slicing
non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
logger.info(
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
f"Total sequence length: {seq_len}, "
f"Non-padding tokens: {non_pad_tokens_total}"
)
logger.info(
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
f"GPU {self.rank} token IDs: {batch['input_ids']}"
)
logger.info(
f"GPU {self.rank}, SP Rank {self.local_rank}/{self.local_world_size}: "
f"Slicing {key} from {seq_len} tokens to "
f"indices {start_idx}:{end_idx}"
)
batch[key] = batch[key][:, start_idx:end_idx]
# if key == "input_ids":
# # After slicing
# non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
# logger.info(
# f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
# f"Slice {start_idx}:{end_idx}, "
# f"Non-padding tokens in slice: {non_pad_tokens_slice}"
# )
# logger.info(
# f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
# f"GPU {self.rank} token IDs: {batch['input_ids']}"
# )
# if key == "labels":
# min_label = batch["labels"][batch["labels"] != -100].min().item() if (batch["labels"] != -100).any() else -100
# max_label = batch["labels"][batch["labels"] != -100].max().item() if (batch["labels"] != -100).any() else -100
# logger.info(f"GPU {self.rank}: Label range: {min_label} to {max_label}, Vocab size: {self.tokenizer.vocab_size}, labels: {batch['labels']}")
# # Find any labels that are outside the valid vocabulary range (but not -100 which is the ignore index)
# invalid_mask = (batch["labels"] >= self.tokenizer.vocab_size) & (batch["labels"] != -100)
# if invalid_mask.any():
# # Log this for debugging
# num_invalid = invalid_mask.sum().item()
# logger.warning(f"GPU {self.rank}: Found {num_invalid} invalid labels (>= vocab_size), setting to -100")
# # Replace invalid labels with -100 (ignore index)
# batch["labels"][invalid_mask] = -100
if key == "attention_mask":
logger.info(
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
f"Attention mask: {batch['attention_mask']}, "
)
# Handle position_ids if present
if "position_ids" in batch:
pos_ids = batch["position_ids"]
seq_len = pos_ids.shape[1]
slice_size = seq_len // self.local_world_size
start_idx = self.local_rank * slice_size
end_idx = (
start_idx + slice_size
if self.local_rank < self.local_world_size - 1
else seq_len
)
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
# Adjust position_ids to be relative to the slice start
if self.local_rank > 0:
batch["position_ids"] = adjust_position_ids_for_slice(
batch["position_ids"], start_idx
)
logger.info(
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
f"Position IDs: {batch['position_ids']}, "
)
# if dist.get_rank() == 0:
# import ipdb; ipdb.set_trace()
# dist.barrier()
# Special handling for position_ids
if key == "position_ids" and self.local_rank > 0:
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
return batch