small updates
This commit is contained in:
@@ -608,7 +608,8 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
||||
inputs: Dictionary mapping.
|
||||
"""
|
||||
# Set up sequence parallelism for this step if enabled
|
||||
self._sp_training_step_setup(inputs)
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
self._update_ring_flash_attn_params(inputs)
|
||||
|
||||
# Proceed with normal training step
|
||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||
|
||||
@@ -82,7 +82,7 @@ class SequenceParallelMixin:
|
||||
"""
|
||||
return self._create_sequence_parallel_sampler(
|
||||
dataset,
|
||||
shuffle=not getattr(self.args, "curriculum_sampling", False),
|
||||
shuffle=not self.args.curriculum_sampling,
|
||||
)
|
||||
|
||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||
@@ -108,9 +108,6 @@ class SequenceParallelMixin:
|
||||
Args:
|
||||
inputs: Current batch of inputs.
|
||||
"""
|
||||
if not self.args.sequence_parallel_degree > 1:
|
||||
return
|
||||
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = inputs["input_ids"].shape[0]
|
||||
|
||||
@@ -1120,7 +1120,7 @@ class AxolotlInputConfig(
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"ring_flash_attn is not installed. "
|
||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
|
||||
Reference in New Issue
Block a user