test fix
This commit is contained in:
@@ -977,7 +977,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
packed_seq_lens=[seq_len] * batch_size, total_seq_len=total_seq_len
|
packed_seq_lens=[seq_len] * batch_size, total_seq_len=total_seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the loss from the parent implementation
|
|
||||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@@ -552,7 +552,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
if self.cfg.sequence_parallel_degree > 1:
|
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||||
|
|
||||||
# Initialize ring attn for sequence parallelism. This must be done after
|
# Initialize ring attn for sequence parallelism. This must be done after
|
||||||
|
|||||||
Reference in New Issue
Block a user