small changes

This commit is contained in:
Dan Saunders
2025-03-19 17:15:30 +00:00
parent c1a58339e8
commit a26985c53c
2 changed files with 6 additions and 1 deletions

View File

@@ -45,7 +45,8 @@ except ImportError:
def update_ring_flash_attn_params(*args, **kwargs):
raise ImportError(
"ring_flash_attn is not installed. "
"Please install it with `pip install ring-flash-attn>=0.1.4`"
"Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`."
)

View File

@@ -36,6 +36,10 @@ def register_ring_attn(sequence_parallel_degree: int):
)
world_size = dist.get_world_size()
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
)
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"