diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 662b64896..0bb5f11e7 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -13,6 +13,7 @@ from abc import abstractmethod from collections import defaultdict from dataclasses import dataclass, field from functools import wraps +from multiprocessing import set_start_method from pathlib import Path from typing import Dict, List, Literal, Optional, Type, Union @@ -1770,6 +1771,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): for callback in self.get_post_trainer_create_callbacks(dpo_trainer): dpo_trainer.add_callback(callback) + # prevents multiprocessing issues for datasets on multiple GPUs + set_start_method("spawn") + return dpo_trainer