clear cuda cache to help with memory leak/creep (#1858)

* clear cuda cache to help with memory leak/creep

* reverse order of gc
This commit is contained in:
Wing Lian
2024-08-26 15:50:26 -04:00
committed by GitHub
parent 2dac1edf72
commit 17af1d7081

View File

@@ -4,6 +4,7 @@ Builder for the training args and trainer
""" """
import abc import abc
import gc
import importlib import importlib
import importlib.util import importlib.util
import logging import logging
@@ -15,11 +16,12 @@ from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from typing import Dict, List, Literal, Optional, Type, Union from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch import torch
import transformers import transformers
from datasets import Dataset from datasets import Dataset
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import ( from transformers import (
@@ -997,6 +999,14 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
res[key] = res[key][1:] res[key] = res[key][1:]
return res return res
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs)
gc.collect()
torch.cuda.empty_cache()
return loss
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
""" """