clear cuda cache to help with memory leak/creep (#1858)

* clear cuda cache to help with memory leak/creep

* reverse order of gc
This commit is contained in:
Wing Lian
2024-08-26 15:50:26 -04:00
committed by GitHub
parent 2dac1edf72
commit 17af1d7081

View File

@@ -4,6 +4,7 @@ Builder for the training args and trainer
"""
import abc
import gc
import importlib
import importlib.util
import logging
@@ -15,11 +16,12 @@ from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Dict, List, Literal, Optional, Type, Union
from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -997,6 +999,14 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
res[key] = res[key][1:]
return res
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs)
gc.collect()
torch.cuda.empty_cache()
return loss
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""