From 0ade60d455bfac996e6b685c5918358df607a213 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Thu, 13 Mar 2025 23:32:07 +0000
Subject: [PATCH] another import scoping change

---
 src/axolotl/core/trainers/base.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index ebc51d5fa..109017145 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -27,7 +27,6 @@ from trl.trainer.utils import pad_to_length
 from typing_extensions import override
 
 from axolotl.integrations.base import BaseOptimizerFactory
-from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
 from axolotl.monkeypatch.relora import ReLoRAScheduler
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
@@ -386,6 +385,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
 
+        if self.args.sequence_parallel_degree > 1:
+            from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
+
+            self.ring_attn_group = get_ring_attn_group()
+
     def _wrap_model(self, model, training=True, dataloader=None):
         if self.args.torch_compile:
             torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
@@ -984,7 +988,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
         )
 
-        update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
+        update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
 
 
 class AxolotlMambaTrainer(AxolotlTrainer):
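
For reference, the pattern this patch applies is a conditional, function-local import whose result is cached as an instance attribute: the ring-attention module is only imported, and the process group only resolved, when sequence parallelism is actually enabled. Below is a minimal sketch of that pattern. TrainerSketch and its constructor signature are illustrative stand-ins, not Axolotl's real trainer API; only the import path is taken from the patch itself.

class TrainerSketch:
    """Illustrative stand-in (hypothetical class) showing the patch's
    pattern: a conditional, function-local import cached on the instance."""

    def __init__(self, sequence_parallel_degree: int = 1):
        self.ring_attn_group = None
        if sequence_parallel_degree > 1:
            # Deferred import: the ring-attention module is only loaded when
            # sequence parallelism is enabled, so the common single-GPU path
            # never touches it at module import time.
            from axolotl.monkeypatch.attention.ring_attn import (
                get_ring_attn_group,
            )

            # Resolve the process group once at construction; later hot-path
            # code reads the cached attribute instead of re-resolving the
            # group on every call.
            self.ring_attn_group = get_ring_attn_group()

Caching the group at construction is presumably also why the call site at the bottom of the patch switches from get_ring_attn_group() to self.ring_attn_group: update_ring_flash_attn_params sits on a per-step path, and reading an attribute avoids repeating the lookup there.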