When using Ray use prepare for dataloader fixes (#3198)
* make sure to use ray prepare for dataloader fixes * ray tests use 2.7.0+ * don't call init_distributed w ray and deepspeed * handle dict deepspeed config * better handling of dict deepspeed config * use json.dumps * guard to_dict * wrap import for optional ray
This commit is contained in:
@@ -13,7 +13,6 @@ from axolotl.utils.dict import DictDefault
|
||||
from tests.e2e.utils import (
|
||||
check_tensorboard,
|
||||
require_torch_2_7_0,
|
||||
require_torch_lt_2_6_0,
|
||||
)
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
@@ -24,7 +23,7 @@ class TestMultiGPURay:
|
||||
Test cases for AnyScale Ray post training
|
||||
"""
|
||||
|
||||
@require_torch_lt_2_6_0
|
||||
@require_torch_2_7_0
|
||||
def test_lora_ddp(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -83,7 +82,7 @@ class TestMultiGPURay:
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@require_torch_lt_2_6_0
|
||||
@require_torch_2_7_0
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
|
||||
Reference in New Issue
Block a user