Support for Async GRPO (#3486)
* async grpo support * implement data producer * use fast async * handle call to create data producer * fix liger kernel setup * fix replay buffer * chore: lint * make gpus go brrr * chore: lint * inplace div_, unwrap model for logits in bf16 * fuse selective softmax and empty cuda cache on each scoring step * remove waiting for synch time and fix race * make fp8 work and allow lora kernels w rl * grpo with lora vllm sync and fixes for sharded distributed * update docs * more patches so it works against trl main * address PR feedback for corerabbit
This commit is contained in:
220
tests/core/test_async_grpo.py
Normal file
220
tests/core/test_async_grpo.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Unit tests for async GRPO"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TestReplayBuffer(unittest.TestCase):
    """Edge-case behavior of ReplayBuffer: zero/negative capacity and eviction."""

    def test_add_noop_when_max_size_zero(self):
        """add() must be a no-op when the buffer capacity is zero."""
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        replay = ReplayBuffer(max_size=0)
        replay.add(1.0, {"data": "test"})
        self.assertEqual(len(replay), 0)

    def test_add_noop_when_max_size_negative(self):
        """add() must be a no-op when the buffer capacity is negative."""
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        replay = ReplayBuffer(max_size=-1)
        replay.add(1.0, {"data": "test"})
        self.assertEqual(len(replay), 0)

    def test_sample_returns_none_when_max_size_zero(self):
        """sample() must return None from a zero-capacity buffer."""
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        replay = ReplayBuffer(max_size=0)
        self.assertIsNone(replay.sample(1))

    def test_sample_returns_none_when_empty(self):
        """sample() must return None before anything has been added."""
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        replay = ReplayBuffer(max_size=5)
        self.assertIsNone(replay.sample(1))

    def test_normal_add_and_sample(self):
        """Adding up to capacity keeps every item; sample(1) yields one item."""
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        replay = ReplayBuffer(max_size=3)
        for score in (1.0, 2.0, 3.0):
            replay.add(score, {"a": int(score)})
        self.assertEqual(len(replay), 3)
        sampled = replay.sample(1)
        self.assertIsNotNone(sampled)
        self.assertEqual(len(sampled), 1)

    def test_replaces_lowest_when_full(self):
        """When full, a higher-scored add evicts the lowest-scored entry."""
        from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer

        replay = ReplayBuffer(max_size=2)
        replay.add(1.0, {"a": 1})
        replay.add(2.0, {"a": 2})
        # Buffer is full; score 3.0 should displace the score-1.0 entry.
        replay.add(3.0, {"a": 3})
        self.assertEqual(len(replay), 2)
        # NOTE(review): reaches into the private _heap; assumes heap entries
        # are (score, ...) tuples — confirm against ReplayBuffer internals.
        retained = sorted(entry[0] for entry in replay._heap)
        self.assertEqual(retained, [2.0, 3.0])
|
||||
|
||||
|
||||
class TestGRPOStrategyConflict(unittest.TestCase):
    """Trainer-class selection for sequence_parallel / async_grpo combinations."""

    def test_raises_on_both_enabled(self):
        """Enabling both options is invalid and must raise a descriptive error."""
        from axolotl.core.trainers.grpo import GRPOStrategy

        with self.assertRaises(ValueError) as ctx:
            GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=True)
        message = str(ctx.exception)
        # The error must name both conflicting options.
        self.assertIn("sequence_parallel", message)
        self.assertIn("async_grpo", message)

    def test_sequence_parallel_only(self):
        """sequence_parallel alone selects the sequence-parallel trainer."""
        from axolotl.core.trainers.grpo import GRPOStrategy
        from axolotl.core.trainers.grpo.trainer import (
            AxolotlGRPOSequenceParallelTrainer,
        )

        trainer_cls = GRPOStrategy.get_trainer_class(
            sequence_parallel=True, async_grpo=False
        )
        self.assertIs(trainer_cls, AxolotlGRPOSequenceParallelTrainer)

    def test_async_only(self):
        """async_grpo alone selects the async trainer."""
        from axolotl.core.trainers.grpo import GRPOStrategy
        from axolotl.core.trainers.grpo.trainer import AxolotlAsyncGRPOTrainer

        trainer_cls = GRPOStrategy.get_trainer_class(
            sequence_parallel=False, async_grpo=True
        )
        self.assertIs(trainer_cls, AxolotlAsyncGRPOTrainer)

    def test_neither(self):
        """With both options off, the plain GRPO trainer is selected."""
        from axolotl.core.trainers.grpo import GRPOStrategy
        from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer

        trainer_cls = GRPOStrategy.get_trainer_class(
            sequence_parallel=False, async_grpo=False
        )
        self.assertIs(trainer_cls, AxolotlGRPOTrainer)
|
||||
|
||||
|
||||
class TestDequantizeFP8TailBlocks(unittest.TestCase):
    """Tests for FP8 dequantization with non-divisible dimensions."""

    def test_exact_divisible_shape(self):
        """Both dims divide evenly into the scale grid (2 row blocks, 1 col block)."""
        from axolotl.kernels.quantize import dequantize_fp8

        W = torch.randn(256, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
        scale_inv = torch.ones(2, 1, dtype=torch.bfloat16)
        result = dequantize_fp8(W, scale_inv)
        self.assertEqual(result.shape, (256, 128))
        self.assertEqual(result.dtype, torch.bfloat16)

    def test_non_divisible_rows(self):
        """Row count not divisible by the number of row scale blocks."""
        from axolotl.kernels.quantize import dequantize_fp8

        # 131 rows against 2 scale blocks: 131 is odd, so the rows cannot split
        # evenly between the blocks and the tail-block handling must be hit.
        W = torch.ones(131, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
        scale_inv = torch.tensor([[2.0], [3.0]], dtype=torch.bfloat16)
        result = dequantize_fp8(W, scale_inv)
        self.assertEqual(result.shape, (131, 128))
        self.assertEqual(result.dtype, torch.bfloat16)

    def test_non_divisible_cols(self):
        """Column count (200) not divisible by the 2 column scale blocks."""
        from axolotl.kernels.quantize import dequantize_fp8

        W = torch.ones(128, 200, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
        scale_inv = torch.ones(1, 2, dtype=torch.bfloat16)
        result = dequantize_fp8(W, scale_inv)
        self.assertEqual(result.shape, (128, 200))

    def test_scalar_scale(self):
        """A 0-d scale tensor is accepted and the output keeps the weight's shape."""
        from axolotl.kernels.quantize import dequantize_fp8

        W = torch.ones(64, 64, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
        scale_inv = torch.tensor(2.0, dtype=torch.bfloat16)
        result = dequantize_fp8(W, scale_inv)
        self.assertEqual(result.shape, (64, 64))
|
||||
|
||||
|
||||
class TestLoraFP8Guard(unittest.TestCase):
    """get_lora_parameters must use weight_scale_inv as quant state only for FP8."""

    def test_non_fp8_weight_skips_scale_inv(self):
        """Non-FP8 weight should NOT pick up weight_scale_inv as quant_state."""
        from axolotl.kernels.lora import get_lora_parameters

        module = MagicMock()
        module.disable_adapters = True
        # spec=[] so only the attributes we explicitly assign exist on the
        # layer; any other attribute access raises AttributeError.
        layer = MagicMock(spec=[])

        # A real bf16 tensor for the weight (has no quant_state attribute).
        layer.weight = torch.randn(64, 64, dtype=torch.bfloat16)
        layer.bias = None
        layer.weight_scale_inv = torch.ones(1)  # must be ignored for bf16

        module.base_layer = layer

        _W, _b, quant_state, _A, _B, _s = get_lora_parameters(module)
        # Weight is bf16, not FP8, so no quant state should be returned.
        self.assertIsNone(quant_state)

    def test_fp8_weight_uses_scale_inv(self):
        """FP8 weight should pick up weight_scale_inv as quant_state."""
        from axolotl.kernels.lora import get_lora_parameters

        module = MagicMock()
        module.disable_adapters = True
        layer = MagicMock()
        module.base_layer = layer

        # FP8-typed weight tensor.
        layer.weight = torch.randn(64, 64, dtype=torch.bfloat16).to(
            torch.float8_e4m3fn
        )
        layer.bias = None
        scale_inv = torch.ones(1)
        layer.weight_scale_inv = scale_inv

        _W, _b, quant_state, _A, _B, _s = get_lora_parameters(module)
        self.assertIs(quant_state, scale_inv)
|
||||
|
||||
|
||||
class TestValidateQuantPatchRestore(unittest.TestCase):
    """Test that validate_quantization_for_training is restored after trainer creation."""

    # NOTE(review): both tests below monkeypatch AND restore the attribute
    # themselves, so they only verify the local try/finally *pattern* — they
    # never exercise the production build() code path that is supposed to do
    # this. Consider a test that drives build() directly.

    def test_patch_restored_on_success(self):
        """Monkeypatch should be restored even after successful trainer creation."""
        import transformers.trainer as _trainer_module

        original = _trainer_module.validate_quantization_for_training

        # After the build() method runs, original should be restored.
        # We can't easily test the full build(), but we can test the pattern.
        _orig = _trainer_module.validate_quantization_for_training
        _trainer_module.validate_quantization_for_training = lambda model: None
        try:
            pass  # simulate trainer_cls() succeeding
        finally:
            _trainer_module.validate_quantization_for_training = _orig

        self.assertIs(_trainer_module.validate_quantization_for_training, original)

    def test_patch_restored_on_error(self):
        """Monkeypatch should be restored even if trainer creation raises."""
        import transformers.trainer as _trainer_module

        original = _trainer_module.validate_quantization_for_training

        _orig = _trainer_module.validate_quantization_for_training
        _trainer_module.validate_quantization_for_training = lambda model: None
        try:
            raise ValueError("test error")  # simulate trainer_cls() failing
        except ValueError:
            pass
        finally:
            _trainer_module.validate_quantization_for_training = _orig

        self.assertIs(_trainer_module.validate_quantization_for_training, original)
|
||||
|
||||
|
||||
# Allow running this test module directly (outside the pytest/unittest runner).
if __name__ == "__main__":
    unittest.main()
|
||||
286
tests/monkeypatch/test_trl_vllm.py
Normal file
286
tests/monkeypatch/test_trl_vllm.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""Unit tests for TRL vLLM monkeypatches.
|
||||
|
||||
Tests:
|
||||
- split_tensor_dict: scalar type preservation (int/float/bool)
|
||||
- shuffle_sequence_dict: scalar type preservation
|
||||
- extract_logprobs: NaN → 0.0 replacement
|
||||
- VLLMClient.batch_update_named_params: method exists after patch
|
||||
- VLLMGeneration: weight_sync_chunk_size attribute after patch
|
||||
- Patch idempotency: applying patch twice doesn't break anything
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TestSplitTensorDict(unittest.TestCase):
    """Patched split_tensor_dict: scalars/None broadcast, tensors/lists split."""

    def setUp(self):
        from axolotl.monkeypatch.trainer.trl_vllm import _patched_split_tensor_dict

        self.split = _patched_split_tensor_dict

    def test_scalar_int_preserved(self):
        """An int scalar is copied unchanged into each chunk."""
        chunks = self.split({"a": torch.randn(4, 3), "count": 42}, 2)
        self.assertEqual(len(chunks), 2)
        for chunk in chunks:
            self.assertEqual(chunk["count"], 42)

    def test_scalar_float_preserved(self):
        """A float scalar is copied unchanged into each chunk."""
        chunks = self.split({"a": torch.randn(6, 2), "lr": 1e-5}, 3)
        for chunk in chunks:
            self.assertEqual(chunk["lr"], 1e-5)

    def test_scalar_bool_preserved(self):
        """A bool scalar is copied unchanged into each chunk."""
        chunks = self.split({"a": torch.randn(4, 2), "flag": True}, 2)
        for chunk in chunks:
            self.assertTrue(chunk["flag"])

    def test_none_preserved(self):
        """None values are carried into each chunk untouched."""
        chunks = self.split({"a": torch.randn(4, 2), "b": None}, 2)
        for chunk in chunks:
            self.assertIsNone(chunk["b"])

    def test_tensor_split(self):
        """Tensors split row-wise across chunks; scalar entries broadcast."""
        data = torch.arange(8).reshape(4, 2)
        chunks = self.split({"a": data, "n": 10}, 2)
        self.assertEqual(chunks[0]["a"].shape, (2, 2))
        self.assertEqual(chunks[1]["a"].shape, (2, 2))
        torch.testing.assert_close(chunks[0]["a"], data[:2])
        torch.testing.assert_close(chunks[1]["a"], data[2:])

    def test_0d_tensor_preserved(self):
        """A 0-d tensor is not split; its value reaches every chunk."""
        chunks = self.split(
            {"a": torch.randn(4, 2), "scalar_t": torch.tensor(3.14)}, 2
        )
        for chunk in chunks:
            self.assertAlmostEqual(chunk["scalar_t"].item(), 3.14, places=5)

    def test_list_split(self):
        """Lists are split positionally alongside the tensors."""
        chunks = self.split({"a": torch.randn(4, 2), "names": ["a", "b", "c", "d"]}, 2)
        self.assertEqual(chunks[0]["names"], ["a", "b"])
        self.assertEqual(chunks[1]["names"], ["c", "d"])
|
||||
|
||||
|
||||
class TestShuffleSequenceDict(unittest.TestCase):
    """Patched shuffle_sequence_dict: scalars/None pass through, sequences permute."""

    def setUp(self):
        from axolotl.monkeypatch.trainer.trl_vllm import _patched_shuffle_sequence_dict

        self.shuffle = _patched_shuffle_sequence_dict

    def test_scalar_int_preserved(self):
        """An int scalar survives shuffling unchanged."""
        shuffled = self.shuffle({"a": torch.randn(4, 3), "count": 42})
        self.assertEqual(shuffled["count"], 42)

    def test_scalar_float_preserved(self):
        """A float scalar survives shuffling unchanged."""
        shuffled = self.shuffle({"a": torch.randn(4, 3), "lr": 1e-5})
        self.assertEqual(shuffled["lr"], 1e-5)

    def test_scalar_bool_preserved(self):
        """A bool scalar survives shuffling unchanged."""
        shuffled = self.shuffle({"a": torch.randn(4, 3), "flag": False})
        self.assertFalse(shuffled["flag"])

    def test_none_preserved(self):
        """None values pass through untouched."""
        shuffled = self.shuffle({"a": torch.randn(4, 3), "b": None})
        self.assertIsNone(shuffled["b"])

    def test_tensor_permuted(self):
        """Shuffling permutes a tensor without changing its elements or shape."""
        torch.manual_seed(42)
        values = torch.arange(4).float()
        shuffled = self.shuffle({"a": values})
        # Same multiset of elements, same shape, order may differ.
        self.assertEqual(sorted(shuffled["a"].tolist()), sorted(values.tolist()))
        self.assertEqual(shuffled["a"].shape, values.shape)

    def test_list_permuted(self):
        """Shuffling permutes a list while preserving its elements."""
        torch.manual_seed(42)
        shuffled = self.shuffle({"a": torch.randn(3, 2), "names": ["x", "y", "z"]})
        self.assertEqual(sorted(shuffled["names"]), ["x", "y", "z"])
        self.assertEqual(len(shuffled["names"]), 3)

    def test_0d_tensor_preserved(self):
        """A 0-d tensor passes through with its value intact."""
        shuffled = self.shuffle(
            {"a": torch.randn(4, 2), "scalar_t": torch.tensor(3.14)}
        )
        self.assertAlmostEqual(shuffled["scalar_t"].item(), 3.14, places=5)
|
||||
|
||||
|
||||
class TestExtractLogprobs(unittest.TestCase):
|
||||
"""Tests for patched extract_logprobs (NaN → 0.0)."""
|
||||
|
||||
def setUp(self):
|
||||
from axolotl.monkeypatch.trainer.trl_vllm import _patched_extract_logprobs
|
||||
|
||||
self.extract = _patched_extract_logprobs
|
||||
|
||||
def _make_output(self, logprob_values):
|
||||
"""Create a mock vLLM RequestOutput with given logprob values."""
|
||||
|
||||
@dataclass
|
||||
class LogprobItem:
|
||||
logprob: float
|
||||
rank: int
|
||||
|
||||
@dataclass
|
||||
class SeqOutput:
|
||||
logprobs: list[dict[int, LogprobItem]] | None
|
||||
|
||||
@dataclass
|
||||
class RequestOutput:
|
||||
outputs: list[SeqOutput]
|
||||
|
||||
logprobs_list = []
|
||||
for vals in logprob_values:
|
||||
lp_dict = {i: LogprobItem(logprob=v, rank=i) for i, v in enumerate(vals)}
|
||||
logprobs_list.append(lp_dict)
|
||||
|
||||
return RequestOutput(outputs=[SeqOutput(logprobs=logprobs_list)])
|
||||
|
||||
def test_nan_replaced_with_zero(self):
|
||||
output = self._make_output([[float("nan"), 0.5], [-0.3, float("nan")]])
|
||||
logprobs, token_ids = self.extract([output])
|
||||
self.assertEqual(logprobs[0][0][0], 0.0) # NaN → 0.0
|
||||
self.assertEqual(logprobs[0][0][1], 0.5)
|
||||
self.assertEqual(logprobs[0][1][0], -0.3)
|
||||
self.assertEqual(logprobs[0][1][1], 0.0) # NaN → 0.0
|
||||
|
||||
def test_normal_values_preserved(self):
|
||||
output = self._make_output([[-0.5, -1.2], [-0.1, -2.0]])
|
||||
logprobs, token_ids = self.extract([output])
|
||||
self.assertAlmostEqual(logprobs[0][0][0], -0.5)
|
||||
self.assertAlmostEqual(logprobs[0][0][1], -1.2)
|
||||
|
||||
def test_none_logprobs_returns_none(self):
|
||||
@dataclass
|
||||
class SeqOutput:
|
||||
logprobs: None = None
|
||||
|
||||
@dataclass
|
||||
class RequestOutput:
|
||||
outputs: list
|
||||
|
||||
output = RequestOutput(outputs=[SeqOutput()])
|
||||
logprobs, token_ids = self.extract([output])
|
||||
self.assertIsNone(logprobs)
|
||||
self.assertIsNone(token_ids)
|
||||
|
||||
def test_token_ids_extracted(self):
|
||||
output = self._make_output([[-0.5]])
|
||||
logprobs, token_ids = self.extract([output])
|
||||
self.assertEqual(token_ids[0][0], [0]) # token_id=0 from enumerate
|
||||
|
||||
|
||||
class TestPatchApplication(unittest.TestCase):
    """Tests for patch_trl_vllm() application."""

    def test_batch_update_added_to_client(self):
        """After patching, VLLMClient gains batch_update_named_params."""
        from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm

        patch_trl_vllm()
        # Import after patching so we observe the patched class.
        from trl.generation.vllm_client import VLLMClient

        self.assertTrue(hasattr(VLLMClient, "batch_update_named_params"))

    def test_extract_logprobs_patched(self):
        """After patching, vllm_generation.extract_logprobs is our replacement."""
        from axolotl.monkeypatch.trainer.trl_vllm import (
            _patched_extract_logprobs,
            patch_trl_vllm,
        )

        patch_trl_vllm()
        from trl.generation import vllm_generation

        self.assertIs(vllm_generation.extract_logprobs, _patched_extract_logprobs)

    def test_utils_patched(self):
        """After patching, trl.trainer.utils split/shuffle helpers are ours."""
        from axolotl.monkeypatch.trainer.trl_vllm import (
            _patched_shuffle_sequence_dict,
            _patched_split_tensor_dict,
            patch_trl_vllm,
        )

        patch_trl_vllm()
        import trl.trainer.utils

        self.assertIs(trl.trainer.utils.split_tensor_dict, _patched_split_tensor_dict)
        self.assertIs(
            trl.trainer.utils.shuffle_sequence_dict, _patched_shuffle_sequence_dict
        )

    def test_patch_idempotent(self):
        """Applying the patch twice must not raise or undo the patch."""
        from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm

        patch_trl_vllm()
        patch_trl_vllm()  # second call should not error
        from trl.generation.vllm_client import VLLMClient

        self.assertTrue(hasattr(VLLMClient, "batch_update_named_params"))
|
||||
|
||||
|
||||
class TestBatchUpdateChunking(unittest.TestCase):
|
||||
"""Tests for batch_update_named_params chunking logic."""
|
||||
|
||||
def test_no_chunk_single_batch(self):
|
||||
from axolotl.monkeypatch.trainer.trl_vllm import _batch_update_named_params
|
||||
|
||||
# Test that with chunk_size=None, all params go in one chunk
|
||||
client = MagicMock()
|
||||
client.base_url = "http://localhost:8000"
|
||||
client.session.post.return_value = MagicMock(status_code=200)
|
||||
client.communicator = MagicMock()
|
||||
client.communicator.group = MagicMock()
|
||||
client.rank = 0
|
||||
|
||||
params = [
|
||||
("layer.0.weight", torch.randn(10, 10)),
|
||||
("layer.1.weight", torch.randn(10, 10)),
|
||||
]
|
||||
_batch_update_named_params(client, params, chunk_size=None)
|
||||
|
||||
# Should make exactly 1 HTTP call
|
||||
self.assertEqual(client.session.post.call_count, 1)
|
||||
|
||||
def test_chunk_splits_params(self):
|
||||
from axolotl.monkeypatch.trainer.trl_vllm import _batch_update_named_params
|
||||
|
||||
client = MagicMock()
|
||||
client.base_url = "http://localhost:8000"
|
||||
client.session.post.return_value = MagicMock(status_code=200)
|
||||
client.communicator = MagicMock()
|
||||
client.communicator.group = MagicMock()
|
||||
client.rank = 0
|
||||
|
||||
params = [
|
||||
("a", torch.randn(100)), # 100 elements
|
||||
("b", torch.randn(100)), # 100 elements
|
||||
("c", torch.randn(100)), # 100 elements
|
||||
]
|
||||
_batch_update_named_params(client, params, chunk_size=150)
|
||||
|
||||
# Should make 2 HTTP calls: [a,b] then [c] (100+100 > 150 triggers split)
|
||||
# Actually: a=100 < 150, a+b=200 > 150 → chunk [a], then b=100 < 150,
|
||||
# b+c=200 > 150 → chunk [b], then [c]. So 3 calls.
|
||||
# Wait: first a added (100 < 150), then b: 100+100=200 > 150, so chunk=[a],
|
||||
# new chunk starts with b (100 < 150), then c: 100+100=200 > 150, so chunk=[b],
|
||||
# final chunk=[c]. 3 HTTP calls.
|
||||
self.assertEqual(client.session.post.call_count, 3)
|
||||
|
||||
|
||||
# Allow running this test module directly (outside the pytest/unittest runner).
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user