Scope down the problematic import

This commit is contained in:
Dan Saunders
2025-03-13 23:30:04 +00:00
parent 919b88f11b
commit 02e1a42f04

View File

@@ -13,8 +13,6 @@ import torch.distributed as dist
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
logger = logging.getLogger(__name__)
@@ -111,11 +109,11 @@ class DataCollatorForSeq2Seq:
def __post_init__(self):
    """Cache rank/world-size coordinates when sequence parallelism is enabled.

    Records both global process coordinates and coordinates within the
    sequence-parallel (ring-attention) group on the collator instance.
    No-op when ``sequence_parallel_degree`` is 1 or less.
    """
    # Guard clause: nothing to set up without sequence parallelism.
    if self.sequence_parallel_degree <= 1:
        return

    # Imported lazily so merely importing this module does not pull in
    # the ring-attention monkeypatch (the "problematic import").
    from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group

    group = get_ring_attn_group()

    # Global coordinates across all processes.
    self.rank = dist.get_rank()
    # Coordinates within this process's SP group.
    self.local_rank = dist.get_rank(group=group)
    self.world_size = dist.get_world_size()
    self.local_world_size = dist.get_world_size(group=group)
def __call__(self, features, return_tensors=None):