diff --git a/tests/integrations/test_nemo_gym.py b/tests/integrations/test_nemo_gym.py index 652f744ca..83206043c 100644 --- a/tests/integrations/test_nemo_gym.py +++ b/tests/integrations/test_nemo_gym.py @@ -775,19 +775,15 @@ class TestNemoGymE2E(unittest.TestCase): trainer = self._make_mock_trainer() producer._trainer = trainer - # Mock the prompt iterator (returns a batch of 1 input) - producer._prompt_iter = iter( - [ - [ - { - "prompt": [{"role": "user", "content": "Play Wordle!"}], - } - ] - ] - ) - producer._prompt_dl = [ - [{"prompt": [{"role": "user", "content": "Play Wordle!"}]}] + # Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations) + # pre-expands prompts, so the iterator yields num_generations=2 consecutive + # copies of each unique prompt — one entry per rollout. + _prompt_batch = [ + {"prompt": [{"role": "user", "content": "Play Wordle!"}]}, + {"prompt": [{"role": "user", "content": "Play Wordle!"}]}, ] + producer._prompt_iter = iter([_prompt_batch]) + producer._prompt_dl = [_prompt_batch] # Call produce result = producer.produce(model=MagicMock(), global_step=1) @@ -853,10 +849,13 @@ class TestNemoGymE2E(unittest.TestCase): producer._request_timeout = 30 producer._num_generations = 2 producer._trainer = self._make_mock_trainer() - producer._prompt_iter = iter( - [[{"prompt": [{"role": "user", "content": "Play!"}]}]] - ) - producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]] + # RepeatSampler pre-expands by num_generations=2. + _prompt_batch = [ + {"prompt": [{"role": "user", "content": "Play!"}]}, + {"prompt": [{"role": "user", "content": "Play!"}]}, + ] + producer._prompt_iter = iter([_prompt_batch]) + producer._prompt_dl = [_prompt_batch] result = producer.produce(model=MagicMock(), global_step=1)