feat(grpo): add reward_weights config and refactor (#2365)
This commit is contained in:
@@ -9,6 +9,7 @@ import logging
|
|||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -31,30 +32,44 @@ class GRPOStrategy:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def set_training_args_kwargs(cls, cfg):
|
def set_training_args_kwargs(cls, cfg):
|
||||||
grpo_args_kwargs = {}
|
grpo_args_kwargs = {}
|
||||||
if cfg.trl and cfg.trl.use_vllm:
|
|
||||||
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
if not hasattr(cfg, "trl") or not cfg.trl:
|
||||||
if cfg.trl and cfg.trl.vllm_device:
|
return grpo_args_kwargs
|
||||||
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
|
||||||
else:
|
trl: TRLConfig = cfg.trl # type: ignore
|
||||||
grpo_args_kwargs["vllm_device"] = "auto"
|
|
||||||
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
if trl.use_vllm:
|
||||||
|
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||||
|
grpo_args_kwargs["vllm_device"] = (
|
||||||
|
trl.vllm_device if trl.vllm_device else "auto"
|
||||||
|
)
|
||||||
|
|
||||||
|
if trl.vllm_gpu_memory_utilization:
|
||||||
grpo_args_kwargs[
|
grpo_args_kwargs[
|
||||||
"vllm_gpu_memory_utilization"
|
"vllm_gpu_memory_utilization"
|
||||||
] = cfg.trl.vllm_gpu_memory_utilization
|
] = trl.vllm_gpu_memory_utilization
|
||||||
if cfg.trl and cfg.trl.vllm_max_model_len:
|
|
||||||
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
if trl.vllm_max_model_len:
|
||||||
if cfg.trl and cfg.trl.num_generations:
|
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
|
||||||
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
|
||||||
if cfg.trl and cfg.trl.sync_ref_model:
|
if trl.num_generations:
|
||||||
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
grpo_args_kwargs["num_generations"] = trl.num_generations
|
||||||
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
|
||||||
grpo_args_kwargs[
|
if trl.sync_ref_model:
|
||||||
"ref_model_mixup_alpha"
|
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
|
||||||
] = cfg.trl.ref_model_mixup_alpha
|
|
||||||
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
if trl.ref_model_mixup_alpha:
|
||||||
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha
|
||||||
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
|
||||||
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
|
if trl.ref_model_sync_steps:
|
||||||
|
grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
|
||||||
|
|
||||||
|
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
|
||||||
|
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||||
|
|
||||||
|
if trl.reward_weights:
|
||||||
|
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||||
|
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class TRLConfig(BaseModel):
|
|||||||
vllm_dtype: Optional[str] = "auto"
|
vllm_dtype: Optional[str] = "auto"
|
||||||
|
|
||||||
reward_funcs: Optional[List[str]] = None
|
reward_funcs: Optional[List[str]] = None
|
||||||
|
reward_weights: Optional[List[float]] = None
|
||||||
num_generations: Optional[int] = None
|
num_generations: Optional[int] = None
|
||||||
log_completions: Optional[bool] = False
|
log_completions: Optional[bool] = False
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user