From 24146733db454aa4e35cb59e99612ce3940bc6a8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 14 Sep 2023 22:49:27 -0400 Subject: [PATCH] E2e device cuda (#575) * use torch.cuda.current_device() instead of local_rank * ignore NVML errors for gpu stats * llama lora packing e2e tests --- .github/workflows/e2e.yml | 1 + src/axolotl/utils/bench.py | 13 ++++++----- src/axolotl/utils/config.py | 2 +- tests/e2e/test_lora_llama.py | 42 ++++++++++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index ada1fd0c4..09c26c2a6 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -24,6 +24,7 @@ jobs: - name: Install dependencies run: | pip3 install -e . + pip3 install flash-attn pip3 install -r requirements-tests.txt - name: Run e2e tests diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py index 30f0985e7..b460b2ba7 100644 --- a/src/axolotl/utils/bench.py +++ b/src/axolotl/utils/bench.py @@ -2,6 +2,7 @@ import pynvml import torch +from pynvml.nvml import NVMLError def gpu_memory_usage(device=0): @@ -20,11 +21,13 @@ def gpu_memory_usage_smi(device=0): device = device.index if isinstance(device, str) and device.startswith("cuda:"): device = int(device[5:]) - - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(device) - info = pynvml.nvmlDeviceGetMemoryInfo(handle) - return info.used / 1024.0**3 + try: + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(device) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + return info.used / 1024.0**3 + except NVMLError: + return 0.0 def log_gpu_memory_usage(log, msg, device): diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 90ed409b9..a31f34b73 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -29,7 +29,7 @@ def choose_device(cfg): cfg.device_map = "auto" else: if cfg.device.startswith("cuda"): - cfg.device_map = {"": 
cfg.local_rank} + cfg.device_map = {"": torch.cuda.current_device()} else: cfg.device_map = {"": cfg.device} diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index 7873b7ec2..905c3711f 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -78,3 +78,45 @@ class TestLoraLlama(unittest.TestCase): dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + + def test_lora_packing(self): + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "base_model_config": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "sample_packing": True, + "flash_attention": True, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 32, + "lora_alpha": 64, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "<unk>", + "bos_token": "<s>", + "eos_token": "</s>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": tempfile.mkdtemp(), + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)