* ipo-dpo trainer * fix missing abstract method * chatml template, grad checkpointing kwargs support * fix steps calc for RL and add dataloader kwargs * wip to fix dpo and start ppo * more fixes * refactor to generalize map fn * fix dataset loop and handle argilla pref dataset * set training args * load reference model on seperate gpu if more than one device * no auto upload to hub for dpo, don't add lora adapters to ref model for dpo * fixes for rl training * support for ipo from yaml * set dpo training args from the config, add tests * chore: lint * set sequence_len for model in test * add RLHF docs
67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
"""
|
|
module for TRL PPO training
|
|
"""
|
|
import torch
|
|
from tqdm import tqdm
|
|
from trl import PPOTrainer
|
|
|
|
|
|
class TRLPPOTrainer(PPOTrainer):
|
|
"""
|
|
wrapper for ppo trainer to handle customizations
|
|
"""
|
|
|
|
def train(
|
|
self,
|
|
reward_pipe,
|
|
resume_from_checkpoint=None, # pylint: disable=unused-argument
|
|
):
|
|
generation_kwargs = {
|
|
"min_length": -1,
|
|
"top_k": 0.0,
|
|
"top_p": 1.0,
|
|
"do_sample": True,
|
|
"pad_token_id": self.tokenizer.eos_token_id,
|
|
"max_new_tokens": 32,
|
|
}
|
|
sent_kwargs = {
|
|
"return_all_scores": True,
|
|
"function_to_apply": "none",
|
|
"batch_size": 16,
|
|
}
|
|
|
|
for epoch, batch in tqdm( # pylint: disable=unused-variable
|
|
enumerate(self.dataloader)
|
|
):
|
|
query_tensors = batch["input_ids"]
|
|
|
|
# generate model response
|
|
response_tensors, ref_response_tensors = self.generate(
|
|
query_tensors,
|
|
return_prompt=False,
|
|
generate_ref_response=True,
|
|
**generation_kwargs
|
|
)
|
|
batch["response"] = self.tokenizer.batch_decode(response_tensors)
|
|
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
|
|
|
|
# Compute sentiment score
|
|
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
|
pipe_outputs = reward_pipe(texts, **sent_kwargs)
|
|
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
|
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
|
|
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
|
|
ref_rewards = [
|
|
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
|
|
]
|
|
batch["ref_rewards"] = ref_rewards
|
|
|
|
# Run PPO step
|
|
stats = self.step(query_tensors, response_tensors, rewards)
|
|
self.log_stats(
|
|
stats,
|
|
batch,
|
|
rewards,
|
|
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
|
)
|