small changes
This commit is contained in:
@@ -45,7 +45,8 @@ except ImportError:
|
|||||||
def update_ring_flash_attn_params(*args, **kwargs):
|
def update_ring_flash_attn_params(*args, **kwargs):
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"ring_flash_attn is not installed. "
|
"ring_flash_attn is not installed. "
|
||||||
"Please install it with `pip install ring-flash-attn>=0.1.4`"
|
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||||
|
"or `pip install ring-flash-attn>=0.1.4`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,10 @@ def register_ring_attn(sequence_parallel_degree: int):
|
|||||||
)
|
)
|
||||||
|
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
|
assert sequence_parallel_degree <= world_size, (
|
||||||
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
|
f"must be less than or equal to world_size ({world_size})"
|
||||||
|
)
|
||||||
assert world_size % sequence_parallel_degree == 0, (
|
assert world_size % sequence_parallel_degree == 0, (
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
f"must evenly divide world_size ({world_size})"
|
f"must evenly divide world_size ({world_size})"
|
||||||
|
|||||||
Reference in New Issue
Block a user