Fix shape

This commit is contained in:
Wing Lian
2026-04-19 01:53:05 +00:00
parent a892d8cce1
commit 4a5281e61a

View File

@@ -775,19 +775,15 @@ class TestNemoGymE2E(unittest.TestCase):
trainer = self._make_mock_trainer() trainer = self._make_mock_trainer()
producer._trainer = trainer producer._trainer = trainer
# Mock the prompt iterator (returns a batch of 1 input) # Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations)
producer._prompt_iter = iter( # pre-expands prompts, so the iterator yields num_generations=2 consecutive
[ # copies of each unique prompt — one entry per rollout.
[ _prompt_batch = [
{ {"prompt": [{"role": "user", "content": "Play Wordle!"}]},
"prompt": [{"role": "user", "content": "Play Wordle!"}], {"prompt": [{"role": "user", "content": "Play Wordle!"}]},
}
]
]
)
producer._prompt_dl = [
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
] ]
producer._prompt_iter = iter([_prompt_batch])
producer._prompt_dl = [_prompt_batch]
# Call produce # Call produce
result = producer.produce(model=MagicMock(), global_step=1) result = producer.produce(model=MagicMock(), global_step=1)
@@ -853,10 +849,13 @@ class TestNemoGymE2E(unittest.TestCase):
producer._request_timeout = 30 producer._request_timeout = 30
producer._num_generations = 2 producer._num_generations = 2
producer._trainer = self._make_mock_trainer() producer._trainer = self._make_mock_trainer()
producer._prompt_iter = iter( # RepeatSampler pre-expands by num_generations=2.
[[{"prompt": [{"role": "user", "content": "Play!"}]}]] _prompt_batch = [
) {"prompt": [{"role": "user", "content": "Play!"}]},
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]] {"prompt": [{"role": "user", "content": "Play!"}]},
]
producer._prompt_iter = iter([_prompt_batch])
producer._prompt_dl = [_prompt_batch]
result = producer.produce(model=MagicMock(), global_step=1) result = producer.produce(model=MagicMock(), global_step=1)