feat: enable trl's autounwrap (#1060)

* feat: test trl's autounwrap

* fix: add check for adapter

* feat: add config to disable autounwrap

* chore: fix lint
This commit is contained in:
NanoCode012
2024-01-11 22:43:41 +09:00
committed by GitHub
parent 54fe07a905
commit b432889256
5 changed files with 24 additions and 10 deletions

View File

@@ -33,3 +33,12 @@ datasets:
```yaml
rl: ipo
```
#### Trl autounwrap for peft
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
```yaml
# load ref model when adapter training.
rl_adapter_ref_model: true
```

View File

@@ -63,10 +63,15 @@ def train(
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model_ref = None
if cfg.rl:
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)
safe_serialization = cfg.save_safetensors is True