diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py index 5f1aec8ff..fab01a644 100644 --- a/tests/core/test_builders.py +++ b/tests/core/test_builders.py @@ -281,7 +281,9 @@ class TestHFRLTrainerBuilder: # Other settings assert training_arguments.dataloader_num_workers == 1 assert training_arguments.dataloader_pin_memory is True - assert training_arguments.gradient_checkpointing is False + + # TODO(wing): restore once trl releases 0.22.0 + # assert training_arguments.gradient_checkpointing is True def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer): builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)