Fix shape
This commit is contained in:
@@ -775,19 +775,15 @@ class TestNemoGymE2E(unittest.TestCase):
|
|||||||
trainer = self._make_mock_trainer()
|
trainer = self._make_mock_trainer()
|
||||||
producer._trainer = trainer
|
producer._trainer = trainer
|
||||||
|
|
||||||
# Mock the prompt iterator (returns a batch of 1 input)
|
# Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations)
|
||||||
producer._prompt_iter = iter(
|
# pre-expands prompts, so the iterator yields num_generations=2 consecutive
|
||||||
[
|
# copies of each unique prompt — one entry per rollout.
|
||||||
[
|
_prompt_batch = [
|
||||||
{
|
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
||||||
"prompt": [{"role": "user", "content": "Play Wordle!"}],
|
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
|
||||||
}
|
|
||||||
]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
producer._prompt_dl = [
|
|
||||||
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
|
|
||||||
]
|
]
|
||||||
|
producer._prompt_iter = iter([_prompt_batch])
|
||||||
|
producer._prompt_dl = [_prompt_batch]
|
||||||
|
|
||||||
# Call produce
|
# Call produce
|
||||||
result = producer.produce(model=MagicMock(), global_step=1)
|
result = producer.produce(model=MagicMock(), global_step=1)
|
||||||
@@ -853,10 +849,13 @@ class TestNemoGymE2E(unittest.TestCase):
|
|||||||
producer._request_timeout = 30
|
producer._request_timeout = 30
|
||||||
producer._num_generations = 2
|
producer._num_generations = 2
|
||||||
producer._trainer = self._make_mock_trainer()
|
producer._trainer = self._make_mock_trainer()
|
||||||
producer._prompt_iter = iter(
|
# RepeatSampler pre-expands by num_generations=2.
|
||||||
[[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
_prompt_batch = [
|
||||||
)
|
{"prompt": [{"role": "user", "content": "Play!"}]},
|
||||||
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]]
|
{"prompt": [{"role": "user", "content": "Play!"}]},
|
||||||
|
]
|
||||||
|
producer._prompt_iter = iter([_prompt_batch])
|
||||||
|
producer._prompt_dl = [_prompt_batch]
|
||||||
|
|
||||||
result = producer.produce(model=MagicMock(), global_step=1)
|
result = producer.produce(model=MagicMock(), global_step=1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user