remove debug logs and simplify sequence-parallel slicing
@@ -200,8 +200,9 @@ class DataCollatorForSeq2Seq:
         Returns:
             Sliced batch dictionary.
         """
-        # Process keys that need to be sliced
-        for key in ["input_ids", "attention_mask", "labels"]:
+        keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
+
+        for key in keys_to_slice:
             if key in batch:
                 seq_len = batch[key].shape[1]
                 slice_size = seq_len // self.local_world_size
@@ -211,90 +212,11 @@ class DataCollatorForSeq2Seq:
                     if self.local_rank < self.local_world_size - 1
                     else seq_len
                 )
 
-                if key == "input_ids":
-                    # Before slicing
-                    non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
-                    logger.info(
-                        f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                        f"Total sequence length: {seq_len}, "
-                        f"Non-padding tokens: {non_pad_tokens_total}"
-                    )
-                    logger.info(
-                        f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                        f"GPU {self.rank} token IDs: {batch['input_ids']}"
-                    )
-                logger.info(
-                    f"GPU {self.rank}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                    f"Slicing {key} from {seq_len} tokens to "
-                    f"indices {start_idx}:{end_idx}"
-                )
-
                 batch[key] = batch[key][:, start_idx:end_idx]
 
-                # if key == "input_ids":
-                #     # After slicing
-                #     non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
-                #     logger.info(
-                #         f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                #         f"Slice {start_idx}:{end_idx}, "
-                #         f"Non-padding tokens in slice: {non_pad_tokens_slice}"
-                #     )
-                #     logger.info(
-                #         f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                #         f"GPU {self.rank} token IDs: {batch['input_ids']}"
-                #     )
-
-                # if key == "labels":
-                #     min_label = batch["labels"][batch["labels"] != -100].min().item() if (batch["labels"] != -100).any() else -100
-                #     max_label = batch["labels"][batch["labels"] != -100].max().item() if (batch["labels"] != -100).any() else -100
-                #     logger.info(f"GPU {self.rank}: Label range: {min_label} to {max_label}, Vocab size: {self.tokenizer.vocab_size}, labels: {batch['labels']}")
-
-                #     # Find any labels that are outside the valid vocabulary range (but not -100 which is the ignore index)
-                #     invalid_mask = (batch["labels"] >= self.tokenizer.vocab_size) & (batch["labels"] != -100)
-
-                #     if invalid_mask.any():
-                #         # Log this for debugging
-                #         num_invalid = invalid_mask.sum().item()
-                #         logger.warning(f"GPU {self.rank}: Found {num_invalid} invalid labels (>= vocab_size), setting to -100")
-
-                #         # Replace invalid labels with -100 (ignore index)
-                #         batch["labels"][invalid_mask] = -100
-
-                if key == "attention_mask":
-                    logger.info(
-                        f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                        f"Attention mask: {batch['attention_mask']}, "
-                    )
-
-        # Handle position_ids if present
-        if "position_ids" in batch:
-            pos_ids = batch["position_ids"]
-            seq_len = pos_ids.shape[1]
-            slice_size = seq_len // self.local_world_size
-            start_idx = self.local_rank * slice_size
-            end_idx = (
-                start_idx + slice_size
-                if self.local_rank < self.local_world_size - 1
-                else seq_len
-            )
-
-            batch["position_ids"] = pos_ids[:, start_idx:end_idx]
-
-            # Adjust position_ids to be relative to the slice start
-            if self.local_rank > 0:
-                batch["position_ids"] = adjust_position_ids_for_slice(
-                    batch["position_ids"], start_idx
-                )
-
-            logger.info(
-                f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
-                f"Position IDs: {batch['position_ids']}, "
-            )
-
-        # if dist.get_rank() == 0:
-        #     import ipdb; ipdb.set_trace()
-        #     dist.barrier()
+                # Special handling for position_ids
+                if key == "position_ids" and self.local_rank > 0:
+                    batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
 
         return batch
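
Note on the slicing arithmetic this commit keeps: every sequence-parallel rank takes a contiguous chunk of seq_len // local_world_size tokens, and the last rank also absorbs the remainder when seq_len does not divide evenly. A minimal standalone sketch of that logic (the helper name slice_for_rank and the toy shapes are illustrative, not part of the repository):

    import torch

    def slice_for_rank(t: torch.Tensor, local_rank: int, local_world_size: int) -> torch.Tensor:
        # t has shape (batch, seq_len); each rank keeps one contiguous chunk.
        seq_len = t.shape[1]
        slice_size = seq_len // local_world_size
        start_idx = local_rank * slice_size
        # All ranks but the last take exactly slice_size tokens;
        # the last rank runs to seq_len and absorbs any remainder.
        end_idx = start_idx + slice_size if local_rank < local_world_size - 1 else seq_len
        return t[:, start_idx:end_idx]

    x = torch.arange(10).unsqueeze(0)           # shape (1, 10)
    print(slice_for_rank(x, 0, 4).shape[1])     # 2
    print(slice_for_rank(x, 3, 4).shape[1])     # 4  (2 + remainder of 2)

One consequence of this scheme is load imbalance: with seq_len = 10 and four ranks, the last rank processes twice as many tokens as the others.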
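The body of adjust_position_ids_for_slice is outside this diff, but the removed comment ("Adjust position_ids to be relative to the slice start") suggests it rebases each rank's positions against its slice offset; rank 0 needs no adjustment because its slice already starts at index 0. A plausible reconstruction, offered only as an assumption:

    import torch

    def adjust_position_ids_for_slice(position_ids: torch.Tensor, start_idx: int) -> torch.Tensor:
        # Hypothetical reconstruction; the real helper lives elsewhere in the repo.
        # Shift positions so each slice starts counting from its own offset.
        return position_ids - start_idx

Folding position_ids into keys_to_slice is the simplification the commit message refers to: the old standalone "if position_ids in batch" block duplicated the slice-index arithmetic and is now handled by the same loop as input_ids, attention_mask, and labels.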
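For context on the deleted commented-out labels block: -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so clamping out-of-vocabulary labels to -100 makes the loss skip those positions. A sketch of that pattern in isolation (vocab_size here is a stand-in for self.tokenizer.vocab_size):

    import torch

    labels = torch.tensor([[5, 200000, -100, 7]])
    vocab_size = 128256
    # Mask labels outside the vocabulary, but leave the ignore index alone.
    invalid_mask = (labels >= vocab_size) & (labels != -100)
    labels[invalid_mask] = -100   # the loss will ignore these positions
    print(labels)                 # tensor([[   5, -100, -100,    7]])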