SP restore buffers
This commit is contained in:
@@ -207,6 +207,9 @@ class SequenceParallelContextManager:
|
|||||||
# Store original sequence length and padding information
|
# Store original sequence length and padding information
|
||||||
self.original_seq_len = 0
|
self.original_seq_len = 0
|
||||||
self.pad_len = 0
|
self.pad_len = 0
|
||||||
|
|
||||||
|
# Store kwargs passed to model forward pass
|
||||||
|
self.original_kwargs: None | dict[str, torch.Tensor] = None
|
||||||
|
|
||||||
# Create a partially applied version of the apply_sequence_parallelism function
|
# Create a partially applied version of the apply_sequence_parallelism function
|
||||||
self.apply_sequence_parallelism = functools.partial(
|
self.apply_sequence_parallelism = functools.partial(
|
||||||
@@ -259,6 +262,9 @@ class SequenceParallelContextManager:
|
|||||||
|
|
||||||
# Any excess positional arguments are kept as-is
|
# Any excess positional arguments are kept as-is
|
||||||
remaining_args = args[len(forward_params) :]
|
remaining_args = args[len(forward_params) :]
|
||||||
|
|
||||||
|
# Store original kwargs
|
||||||
|
self.original_kwargs = {key: value.clone() for key, value in updated_kwargs.items()}
|
||||||
|
|
||||||
# Apply sequence parallelism to updated kwargs
|
# Apply sequence parallelism to updated kwargs
|
||||||
updated_kwargs, self.original_seq_len, self.pad_len = (
|
updated_kwargs, self.original_seq_len, self.pad_len = (
|
||||||
|
|||||||
Reference in New Issue
Block a user