scoping down problematic import
This commit is contained in:
@@ -13,8 +13,6 @@ import torch.distributed as dist
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -111,11 +109,11 @@ class DataCollatorForSeq2Seq:
|
||||
|
||||
def __post_init__(self):
|
||||
if self.sequence_parallel_degree > 1:
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
# Get information about our position in the SP group
|
||||
sp_group = get_ring_attn_group()
|
||||
self.rank = dist.get_rank()
|
||||
self.local_rank = dist.get_rank(group=sp_group)
|
||||
self.world_size = dist.get_world_size()
|
||||
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
|
||||
Reference in New Issue
Block a user