log GPU memory usage

This commit is contained in:
Aman Karmani
2023-08-09 08:10:37 +00:00
parent 176b888a63
commit e303d64728
6 changed files with 63 additions and 0 deletions

View File

@@ -18,6 +18,7 @@ from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.logging_config import configure_logging
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -250,6 +251,8 @@ def train(
LOG.info("Finished preparing dataset. Exiting...")
return
log_gpu_memory_usage(LOG, "baseline", cfg.device)
# Load the model and tokenizer
LOG.info("loading model and peft_config...")
model, peft_config = load_model(