small changes
This commit is contained in:
@@ -45,7 +45,8 @@ except ImportError:
|
||||
def update_ring_flash_attn_params(*args, **kwargs):
|
||||
raise ImportError(
|
||||
"ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install ring-flash-attn>=0.1.4`"
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -36,6 +36,10 @@ def register_ring_attn(sequence_parallel_degree: int):
|
||||
)
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
)
|
||||
assert world_size % sequence_parallel_degree == 0, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
|
||||
Reference in New Issue
Block a user