attempt to also run e2e tests that needs gpus (#1070)

* attempt to also run e2e tests that needs gpus

* fix stray quote

* checkout specific github ref

* dockerfile for tests with proper checkout

ensure wandb is dissabled for docker pytests
clear wandb env after testing
clear wandb env after testing
make sure to provide a default val for pop
tryin skipping wandb validation tests
explicitly disable wandb in the e2e tests
explicitly report_to None to see if that fixes the docker e2e tests
split gpu from non-gpu unit tests
skip bf16 check in test for now
build docker w/o cache since it uses branch name ref
revert some changes now that caching is fixed
skip bf16 check if on gpu w support

* pytest skip for auto-gptq requirements

* skip mamba tests for now, split multipack and non packed lora llama tests

* split tests that use monkeypatches

* fix relative import for prev commit

* move other tests using monkeypatches to the correct run
This commit is contained in:
Wing Lian
2024-01-09 21:23:23 -05:00
committed by GitHub
parent 9be92d1448
commit 788649fe95
13 changed files with 214 additions and 105 deletions

View File

@@ -8,6 +8,7 @@ import unittest
from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -59,7 +60,6 @@ class TestPhi(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"bf16": True,
"flash_attention": True,
"max_steps": 10,
"save_steps": 10,
@@ -67,6 +67,10 @@ class TestPhi(unittest.TestCase):
"save_safetensors": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -110,9 +114,13 @@ class TestPhi(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"bf16": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)