clear cuda cache to help with memory leak/creep (#1858)
* clear cuda cache to help with memory leak/creep * reverse order of gc
This commit is contained in:
@@ -4,6 +4,7 @@ Builder for the training args and trainer
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import gc
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
@@ -15,11 +16,12 @@ from collections import defaultdict
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Literal, Optional, Type, Union
|
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -997,6 +999,14 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
res[key] = res[key][1:]
|
res[key] = res[key][1:]
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def training_step(
|
||||||
|
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
loss: torch.Tensor = super().training_step(model, inputs)
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user