SP dataloader patching + removing custom sampler / dataloader logic (#2686)

* utilize accelerate prepare_data_loader with patching

* lint

* cleanup, fix

* update to support DPO quirk

* small change

* coderabbit commits, cleanup, remove dead code

* quarto fix

* patch fix

* review comments

* moving monkeypatch up one level

* fix
This commit is contained in:
Dan Saunders
2025-05-21 11:20:20 -04:00
committed by GitHub
parent a27b909c5c
commit 6aa41740df
20 changed files with 304 additions and 477 deletions

View File

@@ -10,7 +10,7 @@ import pytest
import torch
from accelerate.state import PartialState
-from axolotl.monkeypatch.attention.ring_attn import (
+from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
@@ -313,13 +313,13 @@ class TestApplySequenceParallelism:
# Mock the process group
monkeypatch.setattr(
-"axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
+"axolotl.monkeypatch.ring_attn.get_ring_attn_group",
MagicMock,
)
# Mock update_ring_attn_params
monkeypatch.setattr(
-"axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
+"axolotl.monkeypatch.ring_attn.update_ring_attn_params",
lambda **kwargs: None,
)