Skip to content

Commit 5e03a55

Browse files
kurtamohlervmoens
authored andcommitted
[Benchmark] Add benchmark for compiled ReplayBuffer.extend/sample
ghstack-source-id: d456269 Pull Request resolved: #2514
1 parent 0f29c7e commit 5e03a55

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

benchmarks/test_replaybuffer_benchmark.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
LazyMemmapStorage,
1414
LazyTensorStorage,
1515
ListStorage,
16+
ReplayBuffer,
1617
TensorDictPrioritizedReplayBuffer,
1718
TensorDictReplayBuffer,
1819
)
@@ -172,6 +173,65 @@ def test_rb_populate(benchmark, rb, storage, sampler, size):
172173
)
173174

174175

176+
class create_tensor_rb:
177+
def __init__(self, rb, storage, sampler, size=1_000_000, iters=100):
178+
self.storage = storage
179+
self.rb = rb
180+
self.sampler = sampler
181+
self.size = size
182+
self.iters = iters
183+
184+
def __call__(self):
185+
kwargs = {}
186+
if self.sampler is not None:
187+
kwargs["sampler"] = self.sampler()
188+
if self.storage is not None:
189+
kwargs["storage"] = self.storage(10 * self.size)
190+
191+
rb = self.rb(batch_size=3, **kwargs)
192+
data = torch.randn(self.size, 1)
193+
return ((rb, data, self.iters), {})
194+
195+
196+
def extend_and_sample(rb, td, iters):
197+
for _ in range(iters):
198+
rb.extend(td)
199+
rb.sample()
200+
201+
202+
def extend_and_sample_compiled(rb, td, iters):
203+
@torch.compile
204+
def fn(td):
205+
rb.extend(td)
206+
rb.sample()
207+
208+
for _ in range(iters):
209+
fn(td)
210+
211+
212+
@pytest.mark.parametrize(
213+
"rb,storage,sampler,size,iters,compiled",
214+
[
215+
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, True],
216+
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, False],
217+
],
218+
)
219+
def test_rb_extend_sample(benchmark, rb, storage, sampler, size, iters, compiled):
220+
benchmark.pedantic(
221+
extend_and_sample_compiled if compiled else extend_and_sample,
222+
setup=create_tensor_rb(
223+
rb=rb,
224+
storage=storage,
225+
sampler=sampler,
226+
size=size,
227+
iters=iters,
228+
),
229+
iterations=1,
230+
warmup_rounds=10,
231+
rounds=50,
232+
)
233+
234+
175235
if __name__ == "__main__":
176236
args, unknown = argparse.ArgumentParser().parse_known_args()
177237
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)