diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 28d4d543a..d0cf8455b 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -109,6 +109,13 @@ def ray_train_func(kwargs: dict): # initialize accelerator before model instantiation Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps) + # Register plugins in Ray workers + if cfg.get("plugins"): + from axolotl.cli.config import plugin_set_cfg, prepare_plugins + + prepare_plugins(cfg) + plugin_set_cfg(cfg) + kwargs["cfg"] = cfg do_train(**kwargs)