SP restore buffers
@@ -207,6 +207,9 @@ class SequenceParallelContextManager:
         # Store original sequence length and padding information
         self.original_seq_len = 0
         self.pad_len = 0
 
+        # Store kwargs passed to model forward pass
+        self.original_kwargs: None | dict[str, torch.Tensor] = None
+
         # Create a partially applied version of the apply_sequence_parallelism function
         self.apply_sequence_parallelism = functools.partial(
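A minimal, self-contained sketch of the partial-application pattern this hunk sets up. The diff truncates the functools.partial(...) call, so the signature, the sp_world_size/sp_rank parameters, and the pad-then-shard logic below are assumptions, not the repository's actual implementation; only the returned (kwargs, original_seq_len, pad_len) triple is taken from the second hunk.

import functools
import torch

def apply_sequence_parallelism(
    kwargs: dict[str, torch.Tensor], sp_world_size: int, sp_rank: int
) -> tuple[dict[str, torch.Tensor], int, int]:
    # Hypothetical body: pad the sequence dimension (dim 1) so it divides
    # evenly across ranks, then keep only this rank's shard.
    original_seq_len = next(iter(kwargs.values())).shape[1]
    pad_len = -original_seq_len % sp_world_size
    sharded = {
        key: torch.nn.functional.pad(value, (0, pad_len)).chunk(sp_world_size, dim=1)[sp_rank]
        for key, value in kwargs.items()
    }
    return sharded, original_seq_len, pad_len

# Binding the static arguments once means later call sites only pass kwargs,
# mirroring the (truncated) functools.partial(...) call above.
apply_sp = functools.partial(apply_sequence_parallelism, sp_world_size=4, sp_rank=0)

kwargs = {"input_ids": torch.arange(10).reshape(1, 10)}
updated_kwargs, original_seq_len, pad_len = apply_sp(kwargs)
# original_seq_len == 10, pad_len == 2, shard shape == (1, 3)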
@@ -259,6 +262,9 @@ class SequenceParallelContextManager:
 
         # Any excess positional arguments are kept as-is
         remaining_args = args[len(forward_params) :]
 
+        # Store original kwargs
+        self.original_kwargs = {key: value.clone() for key, value in updated_kwargs.items()}
+
         # Apply sequence parallelism to updated kwargs
         updated_kwargs, self.original_seq_len, self.pad_len = (
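A sketch of the save-then-restore flow this hunk enables. The save side (cloning each tensor before sharding) is what the diff adds; the restore side is an assumption inferred from the commit title ("SP restore buffers") and is not shown in the hunks above.

import torch

updated_kwargs = {"input_ids": torch.arange(8).reshape(1, 8)}

# Save side (what the hunk adds): clone() takes a full copy of each tensor,
# so later padding/sharding of updated_kwargs cannot corrupt it.
original_kwargs = {key: value.clone() for key, value in updated_kwargs.items()}

# Simulate sequence parallelism replacing the live buffers with shards.
for key in updated_kwargs:
    updated_kwargs[key] = updated_kwargs[key][:, :4]  # keep this rank's shard

# Restore side (assumed): copy the saved full-length tensors back once the
# sharded forward pass has finished.
for key, value in original_kwargs.items():
    updated_kwargs[key] = value

assert updated_kwargs["input_ids"].shape[1] == 8  # original length recovered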