diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 0aae48300..3b94ed282 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -608,7 +608,8 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai inputs: Dictionary mapping. """ # Set up sequence parallelism for this step if enabled - self._sp_training_step_setup(inputs) + if self.args.sequence_parallel_degree > 1: + self._update_ring_flash_attn_params(inputs) # Proceed with normal training step loss = super().training_step(model, inputs, num_items_in_batch) diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py index 3f511a7de..f52c044b6 100644 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -82,7 +82,7 @@ class SequenceParallelMixin: """ return self._create_sequence_parallel_sampler( dataset, - shuffle=not getattr(self.args, "curriculum_sampling", False), + shuffle=not self.args.curriculum_sampling, ) def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None: @@ -108,9 +108,6 @@ class SequenceParallelMixin: Args: inputs: Current batch of inputs. """ - if not self.args.sequence_parallel_degree > 1: - return - # At this point, inputs should already be partitioned by the sequence # parallel data collator batch_size = inputs["input_ids"].shape[0] diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index cc130c2c6..847d4e510 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1120,7 +1120,7 @@ class AxolotlInputConfig( import ring_flash_attn # noqa: F401 # pylint:disable=unused-import except ImportError as exception: raise ImportError( - "ring_flash_attn is not installed. " + "sequence_parallel_degree > 1 but ring_flash_attn is not installed. " "Please install it with `pip install axolotl[ring-flash-attn] " "or `pip install ring-flash-attn>=0.1.4`." ) from exception