gradient accumulation tests, embeddings w pad_token fix, smaller models (#2059)
* add more test cases for gradient accumulation and fix zero3 * swap out for smaller model * fix missing return * fix missing pad_token in config * support concurrency for multigpu testing * cast empty deepspeed to empty string for zero3 check * fix temp_dir as fixture so parametrize works properly * fix test file for multigpu evals * don't use default * don't use default for fsdp_state_dict_type * don't use llama tokenizer w smollm * also automatically cancel multigpu for concurrency
This commit is contained in:
5
.github/workflows/multi-gpu-e2e.yml
vendored
5
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -8,6 +8,11 @@ on:
|
|||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||||
|
|
||||||
|
# Cancel jobs on the same ref if a new one is triggered
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test-axolotl-multigpu:
|
test-axolotl-multigpu:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
# only run one test at a time so as not to OOM the GPU
|
# run at most two tests in parallel so as not to OOM the GPU
|
||||||
pytest -n1 /workspace/axolotl/tests/e2e/multigpu/
|
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
|
||||||
|
|||||||
@@ -1291,6 +1291,25 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
@classmethod
def warn_qlora_zero3_w_use_reentrant(cls, data):
    """Warn about a risky QLoRA + ZeRO-3 + non-reentrant checkpointing config.

    Emits a warning (never an error) when all of the following hold in the
    raw config mapping:

    * ``adapter`` is ``"qlora"``
    * ``gradient_checkpointing_kwargs.use_reentrant`` is exactly ``False``
    * the ``deepspeed`` value contains the substring ``"zero3"``

    Args:
        data: the raw, pre-validation config mapping.

    Returns:
        ``data`` unchanged.
    """
    # Either key may be present with a null value; fall back to an empty
    # container so the lookups below cannot fail on ``None``.
    checkpointing_kwargs = data.get("gradient_checkpointing_kwargs") or {}
    # ``data.get("deepspeed", "")`` still returns None when the key exists but
    # is null, and ``"zero3" in None`` raises TypeError; coerce to "".
    deepspeed_cfg = data.get("deepspeed") or ""

    if (
        data.get("adapter") == "qlora"
        and checkpointing_kwargs.get("use_reentrant") is False
        and "zero3" in deepspeed_cfg
    ):
        # may result in:
        # torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint:
        # Recomputed values for the following tensors have different metadata
        # than during the forward pass.
        LOG.warning(
            "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
        )
    return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_val_w_test_datasets(cls, data):
|
def check_val_w_test_datasets(cls, data):
|
||||||
|
|||||||
@@ -238,6 +238,7 @@ def load_tokenizer(cfg):
|
|||||||
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
and k != "pad_token"
|
||||||
):
|
):
|
||||||
lora_modules_to_save = ", ".join(
|
lora_modules_to_save = ", ".join(
|
||||||
[f"`{x}`" for x in lora_modules_to_save]
|
[f"`{x}`" for x in lora_modules_to_save]
|
||||||
|
|||||||
16
tests/e2e/conftest.py
Normal file
16
tests/e2e/conftest.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
shared pytest fixtures
|
||||||
|
"""
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def temp_dir():
    """Hand a test a freshly created temporary directory.

    Yields the path returned by ``tempfile.mkdtemp()``; once the test is
    done, the directory and everything inside it is removed.
    """
    scratch_path = tempfile.mkdtemp()
    yield scratch_path
    # teardown: drop the whole tree created for this test
    shutil.rmtree(scratch_path)
|
||||||
@@ -3,28 +3,25 @@ E2E tests for multigpu eval
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
class TestMultiGPUEval(unittest.TestCase):
|
class TestMultiGPUEval:
|
||||||
"""
|
"""
|
||||||
Test case for MultiGPU Eval Sample Packing
|
Test case for MultiGPU Eval Sample Packing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_eval_sample_packing(self, temp_dir):
|
def test_eval_sample_packing(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -83,13 +80,14 @@ class TestMultiGPUEval(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_eval(self, temp_dir):
|
def test_eval(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -148,6 +146,8 @@ class TestMultiGPUEval(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
|||||||
@@ -4,17 +4,17 @@ E2E tests for multigpu lora tinyllama
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import is_hopper, with_temp_dir
|
from ..utils import is_hopper
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -28,18 +28,16 @@ def download_model():
|
|||||||
snapshot_download("TinyLlama/TinyLlama_v1.1")
|
snapshot_download("TinyLlama/TinyLlama_v1.1")
|
||||||
|
|
||||||
|
|
||||||
class TestMultiGPULlama(unittest.TestCase):
|
class TestMultiGPULlama:
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using LoRA
|
Test case for Llama models using LoRA
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_lora_ddp(self, temp_dir):
|
def test_lora_ddp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
@@ -48,9 +46,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -81,19 +77,23 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize(
|
||||||
def test_lora_ddp_packed(self, temp_dir):
|
"gradient_accumulation_steps",
|
||||||
|
[1, 4],
|
||||||
|
)
|
||||||
|
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
@@ -105,9 +105,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -118,7 +116,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 15,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
@@ -138,6 +136,8 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
@@ -145,7 +145,6 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
|
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
|
||||||
@with_temp_dir
|
|
||||||
def test_dpo_lora_ddp(self, temp_dir):
|
def test_dpo_lora_ddp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -210,13 +209,14 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_dpo_qlora_ddp(self, temp_dir):
|
def test_dpo_qlora_ddp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -278,25 +278,27 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize(
|
||||||
def test_fsdp(self, temp_dir):
|
"gradient_accumulation_steps",
|
||||||
|
[1, 4],
|
||||||
|
)
|
||||||
|
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.01,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -305,9 +307,9 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 10,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
@@ -324,7 +326,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"fsdp_use_orig_params": False,
|
"fsdp_use_orig_params": False,
|
||||||
"fsdp_cpu_ram_efficient_loading": False,
|
"fsdp_cpu_ram_efficient_loading": False,
|
||||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
||||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -341,28 +343,29 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize(
|
||||||
def test_fsdp_packed(self, temp_dir):
|
"fsdp_state_dict_type",
|
||||||
|
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
|
||||||
|
)
|
||||||
|
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_sample_packing": False,
|
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -390,7 +393,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"fsdp_use_orig_params": False,
|
"fsdp_use_orig_params": False,
|
||||||
"fsdp_cpu_ram_efficient_loading": False,
|
"fsdp_cpu_ram_efficient_loading": False,
|
||||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
"fsdp_state_dict_type": fsdp_state_dict_type,
|
||||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -407,13 +410,14 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -483,28 +487,29 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize(
|
||||||
def test_ds_zero3_packed(self, temp_dir):
|
"gradient_accumulation_steps",
|
||||||
|
[1, 4],
|
||||||
|
)
|
||||||
|
def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_sample_packing": False,
|
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -515,7 +520,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 15,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
@@ -536,19 +541,19 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_ds_zero3_qlora_packed(self, temp_dir):
|
def test_ds_zero3_qlora_packed(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
@@ -561,9 +566,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -595,6 +598,8 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
|||||||
@@ -4,31 +4,30 @@ E2E tests for multigpu qwen2
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
class TestMultiGPUQwen2(unittest.TestCase):
|
class TestMultiGPUQwen2:
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using LoRA
|
Test case for Qwen2 models using QLoRA with FSDP and DPO
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
|
||||||
def test_qlora_fsdp_dpo(self, temp_dir):
|
def test_qlora_fsdp_dpo(self, base_model, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "Qwen/Qwen2-1.5B",
|
"base_model": base_model,
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"rl": "dpo",
|
"rl": "dpo",
|
||||||
"chat_template": "chatml",
|
"chat_template": "chatml",
|
||||||
@@ -47,9 +46,9 @@ class TestMultiGPUQwen2(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 5,
|
||||||
"warmup_steps": 20,
|
"warmup_steps": 20,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
@@ -91,6 +90,8 @@ class TestMultiGPUQwen2(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
|||||||
Reference in New Issue
Block a user