diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 967179903..d57cb463e 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -115,8 +115,11 @@ def setup_reference_model( LOG.debug("Passing model_ref: None to RL trainer") model_ref = None # explicit setting to None else: + reference_model: bool = True + if cfg.rl == RLType.GRPO and cfg.trl.beta == 0: + reference_model = False # load the model again for model_ref/baseline - model_loader = ModelLoader(cfg, tokenizer, reference_model=True) + model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model) model_ref, _ = model_loader.load() return model_ref