diff --git a/scripts/finetune.py b/scripts/finetune.py
index da08fda0b..e7c70b581 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -18,7 +18,6 @@ from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer
 
 from axolotl.logging_config import configure_logging
-from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
@@ -226,8 +225,6 @@ def train(
         LOG.info("Finished preparing dataset. Exiting...")
         return
 
-    log_gpu_memory_usage(LOG, "baseline", cfg.device)
-
     # Load the model and tokenizer
     LOG.info("loading model and (optionally) peft_config...")
     model, peft_config = load_model(cfg, tokenizer)
diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index 759fb6e21..c39d07772 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -4,13 +4,23 @@ import pynvml
 import torch
 
 
-def gpu_memory_usage(device):
+def gpu_memory_usage(device=0):
+    return torch.cuda.memory_allocated(device) / 1024.0**3
+
+
+def gpu_memory_usage_all(device=0):
+    usage = torch.cuda.memory_allocated(device) / 1024.0**3
+    reserved = torch.cuda.memory_reserved(device) / 1024.0**3
+    smi = gpu_memory_usage_smi(device)
+    return usage, reserved - usage, max(0, smi - reserved)
+
+
+def gpu_memory_usage_smi(device=0):
     if isinstance(device, torch.device):
         device = device.index
     if isinstance(device, str) and device.startswith("cuda:"):
         device = int(device[5:])
 
-    # NB torch.cuda.memory_usage returns zero so we use lower level api
     pynvml.nvmlInit()
     handle = pynvml.nvmlDeviceGetHandleByIndex(device)
     info = pynvml.nvmlDeviceGetMemoryInfo(handle)
@@ -18,6 +28,13 @@ def gpu_memory_usage(device):
 
 
 def log_gpu_memory_usage(log, msg, device):
+    usage, cache, misc = gpu_memory_usage_all(device)
+    extras = []
+    if cache > 0:
+        extras.append(f"+{cache:.03f}GB cache")
+    if misc > 0:
+        extras.append(f"+{misc:.03f}GB misc")
     log.info(
-        f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
+        f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
     )
+    return usage, cache, misc
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index f06762b6b..32a7f0c99 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -74,10 +74,10 @@ class SaveBetterTransformerModelCallback(
         return control
 
 
-class PrintGPUStatsCallback(
+class GPUStatsCallback(
     TrainerCallback
 ):  # pylint: disable=too-few-public-methods disable=unused-argument
-    """Callback to print GPU utilization"""
+    """Callback to track GPU utilization"""
 
     def __init__(self, cfg):
         self.cfg = cfg
@@ -90,7 +90,7 @@ class PrintGPUStatsCallback(
         control: TrainerControl,
         **kwargs,
     ):
-        if not self.logged:
+        if not self.logged and state.global_step > 1:
             log_gpu_memory_usage(LOG, "while training", self.cfg.device)
             self.logged = True
         return control
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index e69bffa7a..9873687aa 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -5,6 +5,8 @@ import os
 
 import torch
 
+from axolotl.utils.bench import log_gpu_memory_usage
+
 LOG = logging.getLogger("axolotl")
 
 
@@ -54,6 +56,8 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
 
+    log_gpu_memory_usage(LOG, "baseline", cfg.device)
+
 
 def validate_config(cfg):
     if cfg.max_packed_sequence_len and cfg.sample_packing:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2f672433d..fcc3a3ac4 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -381,9 +381,6 @@ def load_model(
                 module.scales = module.scales.half()
                 module.bias = module.bias.half()
 
-    if model.device.type == "cuda":
-        log_gpu_memory_usage(LOG, "after adapters", model.device)
-
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
@@ -406,6 +403,9 @@
     if cfg.flash_optimum:
         model = BetterTransformer.transform(model)
 
+    if cfg.adapter is not None:
+        log_gpu_memory_usage(LOG, "after adapters", model.device)
+
     # TODO resume_from_checkpoint handling
     return model, lora_config
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 25d0b1e82..b143cb01f 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -22,7 +22,7 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import (
-    PrintGPUStatsCallback,
+    GPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
@@ -555,7 +555,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
 
     callbacks = []
-    callbacks.append(PrintGPUStatsCallback(cfg))
+    callbacks.append(GPUStatsCallback(cfg))
     # TODO on_save callback to sync checkpoints to GCP/AWS in background
    if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
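
Not part of the diff above: a minimal usage sketch of the refactored bench helpers, assuming a CUDA device 0 is available and `axolotl` is importable. The three-way split (allocated / cache / misc) comes from `gpu_memory_usage_all` as changed in `src/axolotl/utils/bench.py`; the logger setup here is illustrative only.

```python
# Hypothetical sketch, not from the patch: exercising the new bench API.
import logging

from axolotl.utils.bench import gpu_memory_usage_all, log_gpu_memory_usage

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger("axolotl")

# Raw breakdown: allocated tensors, reserved-but-unallocated cache,
# and memory nvidia-smi sees beyond torch's reservation (CUDA context, etc.).
usage, cache, misc = gpu_memory_usage_all(0)
LOG.info("allocated=%.3fGB cache=%.3fGB misc=%.3fGB", usage, cache, misc)

# Or via the logging helper, which now also returns the same tuple.
usage, cache, misc = log_gpu_memory_usage(LOG, "after model load", 0)
```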