diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 9ce9d84ae..7dd402a3f 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -13,8 +13,6 @@ import torch.distributed as dist from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy -from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group - logger = logging.getLogger(__name__) @@ -111,11 +109,11 @@ class DataCollatorForSeq2Seq: def __post_init__(self): if self.sequence_parallel_degree > 1: + from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group + # Get information about our position in the SP group sp_group = get_ring_attn_group() - self.rank = dist.get_rank() self.local_rank = dist.get_rank(group=sp_group) - self.world_size = dist.get_world_size() self.local_world_size = dist.get_world_size(group=sp_group) def __call__(self, features, return_tensors=None):