update dependencies - liger + trl (#2987)
* update dependencies * set dataset processes for tests * add support for GSPO
This commit is contained in:
@@ -6,7 +6,7 @@ triton>=3.0.0
|
|||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
autoawq==0.2.7.post3
|
autoawq==0.2.7.post3
|
||||||
liger-kernel==0.6.0
|
liger-kernel==0.6.1
|
||||||
# END section
|
# END section
|
||||||
|
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
@@ -18,7 +18,7 @@ tokenizers>=0.21.1
|
|||||||
accelerate==1.9.0
|
accelerate==1.9.0
|
||||||
datasets==4.0.0
|
datasets==4.0.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.19.1
|
trl==0.20.0
|
||||||
hf_xet==1.1.5
|
hf_xet==1.1.5
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
|
|||||||
@@ -82,6 +82,11 @@ class GRPOStrategy:
|
|||||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||||
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||||
|
|
||||||
|
if trl.importance_sampling_level is not None:
|
||||||
|
grpo_args_kwargs["importance_sampling_level"] = (
|
||||||
|
trl.importance_sampling_level
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.sequence_parallel_degree > 1:
|
if cfg.sequence_parallel_degree > 1:
|
||||||
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
||||||
|
|
||||||
|
|||||||
@@ -80,6 +80,14 @@ class TRLConfig(BaseModel):
|
|||||||
"description": "Number of completions to print when log_completions is True."
|
"description": "Number of completions to print when log_completions is True."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
importance_sampling_level: Literal["sequence", "token"] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Controls whether importance sampling ratios are computed at the `'token'` or `'sequence'` level. "
|
||||||
|
"For GSPO, use `sequence`, default is None which corresponds to the original GRPO paper."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
sync_ref_model: bool | None = Field(
|
sync_ref_model: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to sync the reference model."},
|
json_schema_extra={"description": "Whether to sync the reference model."},
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ class TestActivationCheckpointing:
|
|||||||
"save_safetensors": True,
|
"save_safetensors": True,
|
||||||
"gradient_checkpointing": gradient_checkpointing,
|
"gradient_checkpointing": gradient_checkpointing,
|
||||||
"save_first_step": False,
|
"save_first_step": False,
|
||||||
|
"dataset_processes": 4,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user