another import scoping change

Dan Saunders
2025-03-13 23:32:07 +00:00
parent 02e1a42f04
commit 0ade60d455

@@ -27,7 +27,6 @@ from trl.trainer.utils import pad_to_length
 from typing_extensions import override
 
 from axolotl.integrations.base import BaseOptimizerFactory
-from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
 from axolotl.monkeypatch.relora import ReLoRAScheduler
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
@@ -386,6 +385,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
 
+        if self.args.sequence_parallel_degree > 1:
+            from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
+
+            self.ring_attn_group = get_ring_attn_group()
+
     def _wrap_model(self, model, training=True, dataloader=None):
         if self.args.torch_compile:
             torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
@@ -984,7 +988,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                 F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
             )
 
-            update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
+            update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
 
 
 class AxolotlMambaTrainer(AxolotlTrainer):
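
The pattern being applied: the module-level import of get_ring_attn_group moves into the one branch that needs it, and the resolved process group is cached on the trainer at construction time, so the hot path reads self.ring_attn_group instead of re-resolving the group on every call. A minimal sketch of that pattern, under stated assumptions: DemoTrainer, DemoArgs, and the stubbed update_ring_flash_attn_params are hypothetical stand-ins; only the get_ring_attn_group import and the ring_attn_group attribute come from the diff itself.

    from dataclasses import dataclass


    @dataclass
    class DemoArgs:
        # Hypothetical stand-in for the trainer args referenced in the diff.
        sequence_parallel_degree: int = 1


    def update_ring_flash_attn_params(cu_seqlens, group):
        # Stub for the real ring-flash-attn helper called in the diff;
        # its actual import path is omitted here.
        pass


    class DemoTrainer:
        # Hypothetical trainer illustrating the import-scoping + caching pattern.
        def __init__(self, args: DemoArgs):
            self.args = args
            self.ring_attn_group = None
            if self.args.sequence_parallel_degree > 1:
                # Deferred import: the ring-attention monkeypatch is only
                # pulled in when sequence parallelism is actually enabled.
                from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group

                # Resolve the process group once at construction ...
                self.ring_attn_group = get_ring_attn_group()

        def training_step(self, cu_seqlens):
            # ... so per-step code reuses the cached attribute instead of
            # calling get_ring_attn_group() each time.
            update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)

Besides avoiding a repeated lookup per step, scoping the import keeps the ring-attention monkeypatch from being loaded at all for users who never enable sequence parallelism.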