Compare commits

...

1 Commit

Author SHA1 Message Date
Dan Saunders
979632f59c SP restore buffers 2025-06-26 02:44:58 +00:00

View File

@@ -207,6 +207,9 @@ class SequenceParallelContextManager:
# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0
# Store kwargs passed to model forward pass
self.original_kwargs: None | dict[str, torch.Tensor] = None
# Create a partially applied version of the apply_sequence_parallelism function
self.apply_sequence_parallelism = functools.partial(
@@ -259,6 +262,9 @@ class SequenceParallelContextManager:
# Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :]
# Store original kwargs
self.original_kwargs = {key: value.clone() for key, value in updated_kwargs.items()}
# Apply sequence parallelism to updated kwargs
updated_kwargs, self.original_seq_len, self.pad_len = (