small updates

This commit is contained in:
Dan Saunders
2025-03-20 02:45:53 +00:00
parent 0b2c2ed68c
commit 22cfa42961
3 changed files with 4 additions and 6 deletions

View File

@@ -608,7 +608,8 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
inputs: Dictionary mapping.
"""
# Set up sequence parallelism for this step if enabled
self._sp_training_step_setup(inputs)
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
loss = super().training_step(model, inputs, num_items_in_batch)

View File

@@ -82,7 +82,7 @@ class SequenceParallelMixin:
"""
return self._create_sequence_parallel_sampler(
dataset,
shuffle=not getattr(self.args, "curriculum_sampling", False),
shuffle=not self.args.curriculum_sampling,
)
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
@@ -108,9 +108,6 @@ class SequenceParallelMixin:
Args:
inputs: Current batch of inputs.
"""
if not self.args.sequence_parallel_degree > 1:
return
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]

View File

@@ -1120,7 +1120,7 @@ class AxolotlInputConfig(
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
raise ImportError(
"ring_flash_attn is not installed. "
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn] "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception