Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
e86dd76154 attempt to set start method to spwan to prevent cuda issues for DPO 2024-07-17 09:29:15 -04:00

View File

@@ -13,6 +13,7 @@ from abc import abstractmethod
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import wraps from functools import wraps
from multiprocessing import set_start_method
from pathlib import Path from pathlib import Path
from typing import Dict, List, Literal, Optional, Type, Union from typing import Dict, List, Literal, Optional, Type, Union
@@ -1770,6 +1771,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
for callback in self.get_post_trainer_create_callbacks(dpo_trainer): for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback) dpo_trainer.add_callback(callback)
# prevents multiprocessing issues for datasets on multiple GPUs
set_start_method("spawn")
return dpo_trainer return dpo_trainer