update dependencies - liger + trl (#2987)
* update dependencies
* set dataset processes for tests
* add support for GSPO
This commit is contained in:
@@ -82,6 +82,11 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||
|
||||
if trl.importance_sampling_level is not None:
|
||||
grpo_args_kwargs["importance_sampling_level"] = (
|
||||
trl.importance_sampling_level
|
||||
)
|
||||
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
||||
|
||||
|
||||
@@ -80,6 +80,14 @@ class TRLConfig(BaseModel):
|
||||
"description": "Number of completions to print when log_completions is True."
|
||||
},
|
||||
)
|
||||
# Optional importance-sampling granularity setting.
# Accepts "sequence" or "token"; defaults to None (unset).
importance_sampling_level: Literal["sequence", "token"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Controls whether importance sampling ratios are computed at the `'token'` or `'sequence'` level. "
|
||||
"For GSPO, use `sequence`, default is None which corresponds to the original GRPO paper."
|
||||
},
|
||||
)
|
||||
|
||||
sync_ref_model: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Whether to sync the reference model."},
|
||||
|
||||
Reference in New Issue
Block a user