Sequence parallelism quick follow-ups; remove ModelCallback (#2450)

* guard return if ring attn alrady registered

* add docs link, bits in multi-gpu docs, remove save model callback (subsumed by HF trainers)

* configurable heads_k_stride from ring-flash-attn hf adapter
This commit is contained in:
Dan Saunders
2025-03-31 09:13:42 -04:00
committed by GitHub
parent cf0c79d52e
commit 5410195e0b
10 changed files with 56 additions and 31 deletions

View File

@@ -110,7 +110,7 @@ class TestRingAttention:
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4)
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2