Register Plugins in Ray Workers (#2901) [skip ci]
* Access plugins in ray cluster * Add comment * chore: lint --------- Co-authored-by: Ed Sealing <ed.sealing@patapsco.ai> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -109,6 +109,13 @@ def ray_train_func(kwargs: dict):
|
|||||||
# initialize accelerator before model instantiation
|
# initialize accelerator before model instantiation
|
||||||
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
|
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
|
||||||
|
|
||||||
|
# Register plugins in Ray workers
|
||||||
|
if cfg.get("plugins"):
|
||||||
|
from axolotl.cli.config import plugin_set_cfg, prepare_plugins
|
||||||
|
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
kwargs["cfg"] = cfg
|
kwargs["cfg"] = cfg
|
||||||
|
|
||||||
do_train(**kwargs)
|
do_train(**kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user