diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch_ring.py b/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch_ring.py index ea4f9c57a..91f1d0c67 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch_ring.py +++ b/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch_ring.py @@ -34,8 +34,7 @@ def create_ring_flash_attention_forward(process_group: dist.ProcessGroup) -> Cal Create a ring flash attention forward function compatible with HuggingFace's interface. Args: - process_group: A PyTorch distributed process group that defines the communication - topology for the ring attention pattern. + process_group: A PyTorch distributed process group. Returns: A function that implements the ring flash attention forward pass with the