feat: enable trl's autounwrap (#1060)
* feat: test trl's autounwrap * fix: add check for adapter * feat: add config to disable autounwrap * chore: fix lint
This commit is contained in:
@@ -33,3 +33,12 @@ datasets:
|
|||||||
```yaml
|
```yaml
|
||||||
rl: ipo
|
rl: ipo
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Trl autounwrap for peft
|
||||||
|
|
||||||
|
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# load ref model when adapter training.
|
||||||
|
rl_adapter_ref_model: true
|
||||||
|
```
|
||||||
|
|||||||
@@ -63,10 +63,15 @@ def train(
|
|||||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||||
model_ref = None
|
model_ref = None
|
||||||
if cfg.rl:
|
if cfg.rl:
|
||||||
# load the model again for model_ref/baseline
|
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||||
model_ref, _ = load_model(
|
# use built-in trl autounwrap
|
||||||
cfg, tokenizer, inference=cli_args.inference, reference_model=True
|
LOG.debug("Passing model_ref: None to RL trainer")
|
||||||
)
|
model_ref = None # explicit setting to None
|
||||||
|
else:
|
||||||
|
# load the model again for model_ref/baseline
|
||||||
|
model_ref, _ = load_model(
|
||||||
|
cfg, tokenizer, inference=cli_args.inference, reference_model=True
|
||||||
|
)
|
||||||
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user