Fix shape
This commit is contained in:
@@ -775,19 +775,15 @@ class TestNemoGymE2E(unittest.TestCase):
|
||||
trainer = self._make_mock_trainer()
|
||||
producer._trainer = trainer
|
||||
|
||||
# Mock the prompt iterator (returns a batch of 1 input)
|
||||
producer._prompt_iter = iter(
|
||||
[
|
||||
[
|
||||
{
|
||||
"prompt": [{"role": "user", "content": "Play Wordle!"}],
|
||||
}
|
||||
]
|
||||
]
|
||||
)
|
||||
producer._prompt_dl = [
|
||||
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
|
||||
# Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations)
|
||||
# pre-expands prompts, so the iterator yields num_generations=2 consecutive
|
||||
# copies of each unique prompt — one entry per rollout.
|
||||
_prompt_batch = [
|
||||
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
||||
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
||||
]
|
||||
producer._prompt_iter = iter([_prompt_batch])
|
||||
producer._prompt_dl = [_prompt_batch]
|
||||
|
||||
# Call produce
|
||||
result = producer.produce(model=MagicMock(), global_step=1)
|
||||
@@ -853,10 +849,13 @@ class TestNemoGymE2E(unittest.TestCase):
|
||||
producer._request_timeout = 30
|
||||
producer._num_generations = 2
|
||||
producer._trainer = self._make_mock_trainer()
|
||||
producer._prompt_iter = iter(
|
||||
[[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
||||
)
|
||||
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
||||
# RepeatSampler pre-expands by num_generations=2.
|
||||
_prompt_batch = [
|
||||
{"prompt": [{"role": "user", "content": "Play!"}]},
|
||||
{"prompt": [{"role": "user", "content": "Play!"}]},
|
||||
]
|
||||
producer._prompt_iter = iter([_prompt_batch])
|
||||
producer._prompt_dl = [_prompt_batch]
|
||||
|
||||
result = producer.produce(model=MagicMock(), global_step=1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user