From 563f5eed7a77370602526a88fa89ef3adc3c2759 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 31 Jul 2025 11:17:17 -0400 Subject: [PATCH] update dependencies - liger + trl (#2987) * update dependencies * set dataset processes for tests * add support for GSPO --- requirements.txt | 4 ++-- src/axolotl/core/trainers/grpo/__init__.py | 5 +++++ src/axolotl/utils/schemas/trl.py | 8 ++++++++ tests/e2e/patched/test_activation_checkpointing.py | 1 + 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 8e473bf6b..ae433193f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ triton>=3.0.0 mamba-ssm==1.2.0.post1 xformers>=0.0.23.post1 autoawq==0.2.7.post3 -liger-kernel==0.6.0 +liger-kernel==0.6.1 # END section packaging==23.2 @@ -18,7 +18,7 @@ tokenizers>=0.21.1 accelerate==1.9.0 datasets==4.0.0 deepspeed>=0.17.0 -trl==0.19.1 +trl==0.20.0 hf_xet==1.1.5 optimum==1.16.2 diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 2c6eb8c6f..5f8e4a8b3 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -82,6 +82,11 @@ class GRPOStrategy: grpo_args_kwargs["log_completions"] = trl.log_completions grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print + if trl.importance_sampling_level is not None: + grpo_args_kwargs["importance_sampling_level"] = ( + trl.importance_sampling_level + ) + if cfg.sequence_parallel_degree > 1: grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index e4d17bc94..980474e87 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -80,6 +80,14 @@ class TRLConfig(BaseModel): "description": "Number of completions to print when log_completions is True." }, ) + importance_sampling_level: Literal["sequence", "token"] | None = Field( + default=None, + json_schema_extra={ + "description": "Controls whether importance sampling ratios are computed at the `'token'` or `'sequence'` level. " + "For GSPO, use `sequence`, default is None which corresponds to the original GRPO paper." + }, + ) + sync_ref_model: bool | None = Field( default=False, json_schema_extra={"description": "Whether to sync the reference model."}, diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py index 3d5b3dc56..06e3de274 100644 --- a/tests/e2e/patched/test_activation_checkpointing.py +++ b/tests/e2e/patched/test_activation_checkpointing.py @@ -70,6 +70,7 @@ class TestActivationCheckpointing: "save_safetensors": True, "gradient_checkpointing": gradient_checkpointing, "save_first_step": False, + "dataset_processes": 4, } )