diff --git a/tests/e2e/multigpu/test_qwen2.py b/tests/e2e/multigpu/test_qwen2.py index 9599c3abf..9e1e6b261 100644 --- a/tests/e2e/multigpu/test_qwen2.py +++ b/tests/e2e/multigpu/test_qwen2.py @@ -28,7 +28,10 @@ class TestMultiGPUQwen2: cfg = DictDefault( { "base_model": base_model, - "load_in_4bit": True, + "quantization": { + "backend": "bnb", + "bits": 4, + }, "rl": "dpo", "chat_template": "chatml", "sequence_len": 2048,