Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
e86dd76154 attempt to set start method to spwan to prevent cuda issues for DPO 2024-07-17 09:29:15 -04:00

View File

@@ -13,6 +13,7 @@ from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from multiprocessing import set_start_method
from pathlib import Path
from typing import Dict, List, Literal, Optional, Type, Union
@@ -1770,6 +1771,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback)
# prevents multiprocessing issues for datasets on multiple GPUs
set_start_method("spawn")
return dpo_trainer