diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 8d5e41d34..cab8a0634 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -45,7 +45,8 @@ except ImportError: def update_ring_flash_attn_params(*args, **kwargs): raise ImportError( "ring_flash_attn is not installed. " - "Please install it with `pip install ring-flash-attn>=0.1.4`" + "Please install it with `pip install axolotl[ring-flash-attn]` " + "or `pip install ring-flash-attn>=0.1.4`." ) diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py index eb146609e..fe333ad32 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn.py +++ b/src/axolotl/monkeypatch/attention/ring_attn.py @@ -36,6 +36,10 @@ def register_ring_attn(sequence_parallel_degree: int): ) world_size = dist.get_world_size() + assert sequence_parallel_degree <= world_size, ( + f"sequence_parallel_degree ({sequence_parallel_degree}) " + f"must be less than or equal to world_size ({world_size})" + ) assert world_size % sequence_parallel_degree == 0, ( f"sequence_parallel_degree ({sequence_parallel_degree}) " f"must evenly divide world_size ({world_size})"