diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index f429cd2ae..09872445b 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -207,6 +207,9 @@ class SequenceParallelContextManager: # Store original sequence length and padding information self.original_seq_len = 0 self.pad_len = 0 + + # Store kwargs passed to model forward pass + self.original_kwargs: None | dict[str, torch.Tensor] = None # Create a partially applied version of the apply_sequence_parallelism function self.apply_sequence_parallelism = functools.partial( @@ -259,6 +262,9 @@ class SequenceParallelContextManager: # Any excess positional arguments are kept as-is remaining_args = args[len(forward_params) :] + + # Snapshot the kwargs before sequence parallelism is applied; clone tensors only, since non-tensor values have no .clone() + self.original_kwargs = {key: value.clone() if isinstance(value, torch.Tensor) else value for key, value in updated_kwargs.items()} # Apply sequence parallelism to updated kwargs updated_kwargs, self.original_seq_len, self.pad_len = (