Remove debug logs and simplify sequence-parallel batch slicing
This commit is contained in:
@@ -200,8 +200,9 @@ class DataCollatorForSeq2Seq:
|
|||||||
Returns:
|
Returns:
|
||||||
Sliced batch dictionary.
|
Sliced batch dictionary.
|
||||||
"""
|
"""
|
||||||
# Process keys that need to be sliced
|
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||||
for key in ["input_ids", "attention_mask", "labels"]:
|
|
||||||
|
for key in keys_to_slice:
|
||||||
if key in batch:
|
if key in batch:
|
||||||
seq_len = batch[key].shape[1]
|
seq_len = batch[key].shape[1]
|
||||||
slice_size = seq_len // self.local_world_size
|
slice_size = seq_len // self.local_world_size
|
||||||
@@ -211,90 +212,11 @@ class DataCollatorForSeq2Seq:
|
|||||||
if self.local_rank < self.local_world_size - 1
|
if self.local_rank < self.local_world_size - 1
|
||||||
else seq_len
|
else seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
if key == "input_ids":
|
|
||||||
# Before slicing
|
|
||||||
non_pad_tokens_total = (batch["input_ids"] != 128001).sum().item()
|
|
||||||
logger.info(
|
|
||||||
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
|
||||||
f"Total sequence length: {seq_len}, "
|
|
||||||
f"Non-padding tokens: {non_pad_tokens_total}"
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
|
||||||
f"GPU {self.rank} token IDs: {batch['input_ids']}"
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"GPU {self.rank}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
|
||||||
f"Slicing {key} from {seq_len} tokens to "
|
|
||||||
f"indices {start_idx}:{end_idx}"
|
|
||||||
)
|
|
||||||
|
|
||||||
batch[key] = batch[key][:, start_idx:end_idx]
|
batch[key] = batch[key][:, start_idx:end_idx]
|
||||||
|
|
||||||
# if key == "input_ids":
|
# Special handling for position_ids
|
||||||
# # After slicing
|
if key == "position_ids" and self.local_rank > 0:
|
||||||
# non_pad_tokens_slice = (batch["input_ids"] != 128001).sum().item()
|
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
|
||||||
# logger.info(
|
|
||||||
# f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
|
||||||
# f"Slice {start_idx}:{end_idx}, "
|
|
||||||
# f"Non-padding tokens in slice: {non_pad_tokens_slice}"
|
|
||||||
# )
|
|
||||||
# logger.info(
|
|
||||||
# f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
|
||||||
# f"GPU {self.rank} token IDs: {batch['input_ids']}"
|
|
||||||
# )
|
|
||||||
|
|
||||||
# if key == "labels":
|
|
||||||
# min_label = batch["labels"][batch["labels"] != -100].min().item() if (batch["labels"] != -100).any() else -100
|
|
||||||
# max_label = batch["labels"][batch["labels"] != -100].max().item() if (batch["labels"] != -100).any() else -100
|
|
||||||
# logger.info(f"GPU {self.rank}: Label range: {min_label} to {max_label}, Vocab size: {self.tokenizer.vocab_size}, labels: {batch['labels']}")
|
|
||||||
|
|
||||||
# # Find any labels that are outside the valid vocabulary range (but not -100 which is the ignore index)
|
|
||||||
# invalid_mask = (batch["labels"] >= self.tokenizer.vocab_size) & (batch["labels"] != -100)
|
|
||||||
|
|
||||||
# if invalid_mask.any():
|
|
||||||
# # Log this for debugging
|
|
||||||
# num_invalid = invalid_mask.sum().item()
|
|
||||||
# logger.warning(f"GPU {self.rank}: Found {num_invalid} invalid labels (>= vocab_size), setting to -100")
|
|
||||||
|
|
||||||
# # Replace invalid labels with -100 (ignore index)
|
|
||||||
# batch["labels"][invalid_mask] = -100
|
|
||||||
|
|
||||||
if key == "attention_mask":
|
|
||||||
logger.info(
|
|
||||||
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
|
||||||
f"Attention mask: {batch['attention_mask']}, "
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle position_ids if present
|
|
||||||
if "position_ids" in batch:
|
|
||||||
pos_ids = batch["position_ids"]
|
|
||||||
seq_len = pos_ids.shape[1]
|
|
||||||
slice_size = seq_len // self.local_world_size
|
|
||||||
start_idx = self.local_rank * slice_size
|
|
||||||
end_idx = (
|
|
||||||
start_idx + slice_size
|
|
||||||
if self.local_rank < self.local_world_size - 1
|
|
||||||
else seq_len
|
|
||||||
)
|
|
||||||
|
|
||||||
batch["position_ids"] = pos_ids[:, start_idx:end_idx]
|
|
||||||
|
|
||||||
# Adjust position_ids to be relative to the slice start
|
|
||||||
if self.local_rank > 0:
|
|
||||||
batch["position_ids"] = adjust_position_ids_for_slice(
|
|
||||||
batch["position_ids"], start_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"GPU {self.rank}/{self.world_size}, SP Rank {self.local_rank}/{self.local_world_size}: "
|
|
||||||
f"Position IDs: {batch['position_ids']}, "
|
|
||||||
)
|
|
||||||
|
|
||||||
# if dist.get_rank() == 0:
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
# dist.barrier()
|
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user