improve GPU logging to break out pytorch cache and system mem
This commit is contained in:
committed by
Aman Gupta Karmani
parent
e029ab34ea
commit
7b55fe6419
@@ -18,7 +18,6 @@ from optimum.bettertransformer import BetterTransformer
|
|||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -226,8 +225,6 @@ def train(
|
|||||||
LOG.info("Finished preparing dataset. Exiting...")
|
LOG.info("Finished preparing dataset. Exiting...")
|
||||||
return
|
return
|
||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
LOG.info("loading model and (optionally) peft_config...")
|
LOG.info("loading model and (optionally) peft_config...")
|
||||||
model, peft_config = load_model(cfg, tokenizer)
|
model, peft_config = load_model(cfg, tokenizer)
|
||||||
|
|||||||
@@ -4,13 +4,23 @@ import pynvml
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def gpu_memory_usage(device):
|
def gpu_memory_usage(device=0):
|
||||||
|
return torch.cuda.memory_allocated(device) / 1024.0**3
|
||||||
|
|
||||||
|
|
||||||
|
def gpu_memory_usage_all(device=0):
|
||||||
|
usage = torch.cuda.memory_allocated(device) / 1024.0**3
|
||||||
|
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
|
||||||
|
smi = gpu_memory_usage_smi(device)
|
||||||
|
return usage, reserved - usage, max(0, smi - reserved)
|
||||||
|
|
||||||
|
|
||||||
|
def gpu_memory_usage_smi(device=0):
|
||||||
if isinstance(device, torch.device):
|
if isinstance(device, torch.device):
|
||||||
device = device.index
|
device = device.index
|
||||||
if isinstance(device, str) and device.startswith("cuda:"):
|
if isinstance(device, str) and device.startswith("cuda:"):
|
||||||
device = int(device[5:])
|
device = int(device[5:])
|
||||||
|
|
||||||
# NB torch.cuda.memory_usage returns zero so we use lower level api
|
|
||||||
pynvml.nvmlInit()
|
pynvml.nvmlInit()
|
||||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||||
@@ -18,6 +28,13 @@ def gpu_memory_usage(device):
|
|||||||
|
|
||||||
|
|
||||||
def log_gpu_memory_usage(log, msg, device):
|
def log_gpu_memory_usage(log, msg, device):
|
||||||
|
usage, cache, misc = gpu_memory_usage_all(device)
|
||||||
|
extras = []
|
||||||
|
if cache > 0:
|
||||||
|
extras.append(f"+{cache:.03f}GB cache")
|
||||||
|
if misc > 0:
|
||||||
|
extras.append(f"+{misc:.03f}GB misc")
|
||||||
log.info(
|
log.info(
|
||||||
f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
|
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
|
||||||
)
|
)
|
||||||
|
return usage, cache, misc
|
||||||
|
|||||||
@@ -74,10 +74,10 @@ class SaveBetterTransformerModelCallback(
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class PrintGPUStatsCallback(
|
class GPUStatsCallback(
|
||||||
TrainerCallback
|
TrainerCallback
|
||||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
): # pylint: disable=too-few-public-methods disable=unused-argument
|
||||||
"""Callback to print GPU utilization"""
|
"""Callback to track GPU utilization"""
|
||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
@@ -90,7 +90,7 @@ class PrintGPUStatsCallback(
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if not self.logged:
|
if not self.logged and state.global_step > 1:
|
||||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
||||||
self.logged = True
|
self.logged = True
|
||||||
return control
|
return control
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
@@ -54,6 +56,8 @@ def normalize_config(cfg):
|
|||||||
else:
|
else:
|
||||||
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
||||||
|
|
||||||
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
||||||
|
|||||||
@@ -381,9 +381,6 @@ def load_model(
|
|||||||
module.scales = module.scales.half()
|
module.scales = module.scales.half()
|
||||||
module.bias = module.bias.half()
|
module.bias = module.bias.half()
|
||||||
|
|
||||||
if model.device.type == "cuda":
|
|
||||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
torch.cuda.device_count() > 1
|
torch.cuda.device_count() > 1
|
||||||
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
||||||
@@ -406,6 +403,9 @@ def load_model(
|
|||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.transform(model)
|
model = BetterTransformer.transform(model)
|
||||||
|
|
||||||
|
if cfg.adapter is not None:
|
||||||
|
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
|||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
PrintGPUStatsCallback,
|
GPUStatsCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
SavePeftModelCallback,
|
||||||
)
|
)
|
||||||
@@ -555,7 +555,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
||||||
|
|
||||||
callbacks = []
|
callbacks = []
|
||||||
callbacks.append(PrintGPUStatsCallback(cfg))
|
callbacks.append(GPUStatsCallback(cfg))
|
||||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||||
if cfg.early_stopping_patience:
|
if cfg.early_stopping_patience:
|
||||||
early_stop_cb = EarlyStoppingCallback(
|
early_stop_cb = EarlyStoppingCallback(
|
||||||
|
|||||||
Reference in New Issue
Block a user