feat: enable trl's autounwrap (#1060)
* feat: test trl's autounwrap * fix: add check for adapter * feat: add config to disable autounwrap * chore: fix lint
This commit is contained in:
@@ -63,10 +63,15 @@ def train(
|
||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
model_ref = None
|
||||
if cfg.rl:
|
||||
# load the model again for model_ref/baseline
|
||||
model_ref, _ = load_model(
|
||||
cfg, tokenizer, inference=cli_args.inference, reference_model=True
|
||||
)
|
||||
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
# use built-in trl autounwrap
|
||||
LOG.debug("Passing model_ref: None to RL trainer")
|
||||
model_ref = None # explicit setting to None
|
||||
else:
|
||||
# load the model again for model_ref/baseline
|
||||
model_ref, _ = load_model(
|
||||
cfg, tokenizer, inference=cli_args.inference, reference_model=True
|
||||
)
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user