gradient accumulation tests, embeddings with pad_token fix, smaller models (#2059)

* add more test cases for gradient accumulation and fix zero3

* swap out for smaller model

* fix missing return

* fix missing pad_token in config

* support concurrency for multigpu testing

* cast empty deepspeed to empty string for zero3 check

* fix temp_dir as fixture so parametrize works properly

* fix test file for multigpu evals

* don't use default

* don't use default for fsdp_state_dict_type

* don't use llama tokenizer with smollm

* also automatically cancel multigpu for concurrency
This commit is contained in:
Wing Lian
2024-11-14 12:59:00 -05:00
committed by GitHub
parent f3a5d119af
commit 71d4030b79
8 changed files with 118 additions and 71 deletions

View File

@@ -4,31 +4,30 @@ E2E tests for multigpu qwen2
import logging
import os
import unittest
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
class TestMultiGPUQwen2(unittest.TestCase):
class TestMultiGPUQwen2:
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_qlora_fsdp_dpo(self, temp_dir):
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
def test_qlora_fsdp_dpo(self, base_model, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2-1.5B",
"base_model": base_model,
"load_in_4bit": True,
"rl": "dpo",
"chat_template": "chatml",
@@ -47,9 +46,9 @@ class TestMultiGPUQwen2(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 5,
"warmup_steps": 20,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
@@ -91,6 +90,8 @@ class TestMultiGPUQwen2(unittest.TestCase):
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),