small updates
This commit is contained in:
@@ -608,7 +608,8 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
inputs: Dictionary mapping.
|
inputs: Dictionary mapping.
|
||||||
"""
|
"""
|
||||||
# Set up sequence parallelism for this step if enabled
|
# Set up sequence parallelism for this step if enabled
|
||||||
self._sp_training_step_setup(inputs)
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._update_ring_flash_attn_params(inputs)
|
||||||
|
|
||||||
# Proceed with normal training step
|
# Proceed with normal training step
|
||||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ class SequenceParallelMixin:
|
|||||||
"""
|
"""
|
||||||
return self._create_sequence_parallel_sampler(
|
return self._create_sequence_parallel_sampler(
|
||||||
dataset,
|
dataset,
|
||||||
shuffle=not getattr(self.args, "curriculum_sampling", False),
|
shuffle=not self.args.curriculum_sampling,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||||
@@ -108,9 +108,6 @@ class SequenceParallelMixin:
|
|||||||
Args:
|
Args:
|
||||||
inputs: Current batch of inputs.
|
inputs: Current batch of inputs.
|
||||||
"""
|
"""
|
||||||
if not self.args.sequence_parallel_degree > 1:
|
|
||||||
return
|
|
||||||
|
|
||||||
# At this point, inputs should already be partitioned by the sequence
|
# At this point, inputs should already be partitioned by the sequence
|
||||||
# parallel data collator
|
# parallel data collator
|
||||||
batch_size = inputs["input_ids"].shape[0]
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
|
|||||||
@@ -1120,7 +1120,7 @@ class AxolotlInputConfig(
|
|||||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||||
except ImportError as exception:
|
except ImportError as exception:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"ring_flash_attn is not installed. "
|
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||||
"or `pip install ring-flash-attn>=0.1.4`."
|
"or `pip install ring-flash-attn>=0.1.4`."
|
||||||
) from exception
|
) from exception
|
||||||
|
|||||||
Reference in New Issue
Block a user