clear cuda cache to help with memory leak/creep (#1858)
* clear cuda cache to help with memory leak/creep

* reverse order of gc
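For context, a minimal sketch of the pattern this commit applies after every training step, and of why the order of the two calls (the "reverse order of gc" fixup above) matters; the helper name release_cuda_memory is illustrative and not part of the diff:

import gc
import torch

def release_cuda_memory() -> None:
    # Run Python's garbage collector first, so tensors kept alive only by
    # reference cycles are actually destroyed and their CUDA blocks are
    # returned to PyTorch's caching allocator ...
    gc.collect()
    # ... then ask the caching allocator to hand its now-unused blocks back
    # to the driver. In the reverse order, empty_cache() would run before
    # the cycle-held tensors were collected and would miss their blocks.
    torch.cuda.empty_cache()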
@@ -4,6 +4,7 @@ Builder for the training args and trainer
 """
 
 import abc
+import gc
 import importlib
 import importlib.util
 import logging
@@ -15,11 +16,12 @@ from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Type, Union
 
 import torch
 import transformers
 from datasets import Dataset
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
@@ -997,6 +999,14 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
                 res[key] = res[key][1:]
         return res
 
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
+    ) -> torch.Tensor:
+        loss: torch.Tensor = super().training_step(model, inputs)
+        gc.collect()
+        torch.cuda.empty_cache()
+        return loss
+
 
 class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
     """