scoping down problematic import
This commit is contained in:
@@ -13,8 +13,6 @@ import torch.distributed as dist
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -111,11 +109,11 @@ class DataCollatorForSeq2Seq:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.sequence_parallel_degree > 1:
|
if self.sequence_parallel_degree > 1:
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
# Get information about our position in the SP group
|
# Get information about our position in the SP group
|
||||||
sp_group = get_ring_attn_group()
|
sp_group = get_ring_attn_group()
|
||||||
self.rank = dist.get_rank()
|
|
||||||
self.local_rank = dist.get_rank(group=sp_group)
|
self.local_rank = dist.get_rank(group=sp_group)
|
||||||
self.world_size = dist.get_world_size()
|
|
||||||
self.local_world_size = dist.get_world_size(group=sp_group)
|
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
|
|||||||
Reference in New Issue
Block a user