log GPU memory usage
This commit is contained in:
@@ -19,3 +19,4 @@ evaluate==0.4.0
|
|||||||
rouge-score==0.1.2
|
rouge-score==0.1.2
|
||||||
scipy
|
scipy
|
||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
|
nvidia-ml-py3
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from optimum.bettertransformer import BetterTransformer
|
|||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
@@ -250,6 +251,8 @@ def train(
|
|||||||
LOG.info("Finished preparing dataset. Exiting...")
|
LOG.info("Finished preparing dataset. Exiting...")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
LOG.info("loading model and peft_config...")
|
LOG.info("loading model and peft_config...")
|
||||||
model, peft_config = load_model(
|
model, peft_config = load_model(
|
||||||
|
|||||||
23
src/axolotl/utils/bench.py
Normal file
23
src/axolotl/utils/bench.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""Benchmarking and measurement utilities"""
|
||||||
|
|
||||||
|
import pynvml
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def gpu_memory_usage(device):
|
||||||
|
if isinstance(device, torch.device):
|
||||||
|
device = device.index
|
||||||
|
if isinstance(device, str) and device.startswith("cuda:"):
|
||||||
|
device = int(device[5:])
|
||||||
|
|
||||||
|
# NB torch.cuda.memory_usage returns zero so we use lower level api
|
||||||
|
pynvml.nvmlInit()
|
||||||
|
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||||
|
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||||
|
return info.used / 1024.0**3
|
||||||
|
|
||||||
|
|
||||||
|
def log_gpu_memory_usage(log, msg, device):
|
||||||
|
log.info(
|
||||||
|
f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
|
||||||
|
)
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Callbacks for Trainer class"""
|
"""Callbacks for Trainer class"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
@@ -11,6 +12,10 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||||
|
|
||||||
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
|
|
||||||
|
|
||||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||||
"""Callback to save the PEFT adapter"""
|
"""Callback to save the PEFT adapter"""
|
||||||
@@ -67,3 +72,25 @@ class SaveBetterTransformerModelCallback(
|
|||||||
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
||||||
control.should_save = False
|
control.should_save = False
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class PrintGPUStatsCallback(
|
||||||
|
TrainerCallback
|
||||||
|
): # pylint: disable=too-few-public-methods disable=unused-argument
|
||||||
|
"""Callback to print GPU utilization"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
self.cfg = cfg
|
||||||
|
self.logged = False
|
||||||
|
|
||||||
|
def on_step_end(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if not self.logged:
|
||||||
|
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
||||||
|
self.logged = True
|
||||||
|
return control
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from transformers import ( # noqa: F401
|
|||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -324,6 +325,9 @@ def load_model(
|
|||||||
)
|
)
|
||||||
model.config.max_position_embeddings = cfg.sequence_len
|
model.config.max_position_embeddings = cfg.sequence_len
|
||||||
|
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||||
|
|
||||||
if not cfg.gptq and (
|
if not cfg.gptq and (
|
||||||
(cfg.adapter == "lora" and load_in_8bit)
|
(cfg.adapter == "lora" and load_in_8bit)
|
||||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||||
@@ -360,6 +364,9 @@ def load_model(
|
|||||||
module.scales = module.scales.half()
|
module.scales = module.scales.half()
|
||||||
module.bias = module.bias.half()
|
module.bias = module.bias.half()
|
||||||
|
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
torch.cuda.device_count() > 1
|
torch.cuda.device_count() > 1
|
||||||
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
|||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
|
PrintGPUStatsCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
SavePeftModelCallback,
|
||||||
)
|
)
|
||||||
@@ -292,6 +293,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
||||||
|
|
||||||
callbacks = []
|
callbacks = []
|
||||||
|
callbacks.append(PrintGPUStatsCallback(cfg))
|
||||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||||
if cfg.early_stopping_patience:
|
if cfg.early_stopping_patience:
|
||||||
early_stop_cb = EarlyStoppingCallback(
|
early_stop_cb = EarlyStoppingCallback(
|
||||||
|
|||||||
Reference in New Issue
Block a user