Description
I tried to replace/update some parts of the data stored in a ReplayBuffer with new tensors.
Although this functionality was added (#2209), I found unexpected behavior, as shown below:
```python
import tensordict
import torch
from torchrl.data import LazyTensorStorage, SliceSampler, ReplayBuffer


def test():
    buffer = ReplayBuffer(
        storage=LazyTensorStorage(50),
        sampler=SliceSampler(
            traj_key="episode",
            slice_len=5,
        ),
        batch_size=10,
    )
    for i in range(50):
        n, m = divmod(i, 10)
        buffer.add(
            tensordict.TensorDict(
                {
                    "episode": torch.tensor(n),
                    "update": torch.tensor(False),
                },
            )
        )
    batch, batch_info = buffer.sample(return_info=True)
    index = batch_info["index"]
    print(f"{index=}")
    new_update = torch.ones_like(batch["update"])
    print("before: ", buffer["update"])

    # try 1 ------
    buffer[index]["update"] = new_update
    print("try 1 (This doesn't update!) ----")
    print(buffer["update"])

    # try 2 ------
    buffer[index].set_("update", new_update)
    print("try 2 (This also doesn't update!) ----")
    print(buffer["update"])

    # try 3 ------
    buffer["update"][index] = new_update
    print("try 3 (This updated!)")
    print(buffer["update"])


if __name__ == "__main__":
    test()
```
This prints:
```
index=(tensor([43, 44, 45, 46, 47, 13, 14, 15, 16, 17]),)
before:  tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
try 1 (This doesn't update!) ----
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
try 2 (This also doesn't update!) ----
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
try 3 (This updated!)
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True,  True,  True,  True,  True, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True,  True,  True,  True,  True, False, False])
```
As shown above, when I first index the buffer and then write into the returned tensordict (try 1 & 2), the stored "update" values are not replaced.
However, when I first retrieve the stored tensor by the key "update" and then index it (try 3), the replacement succeeds.
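If I understand the behavior correctly, this is probably because indexing the stored data with an index tensor goes through PyTorch advanced indexing, which returns a copy rather than a view, so try 1 & 2 write into a temporary tensordict instead of the underlying storage. A minimal plain-PyTorch sketch of that suspicion (this is just my guess, not the actual ReplayBuffer internals):

```python
import torch

# Plain-PyTorch illustration: advanced indexing with an index tensor
# returns a copy, so writing into the result never reaches the original.
storage = torch.zeros(10, dtype=torch.bool)
index = torch.tensor([3, 4, 5])

copied = storage[index]   # advanced indexing -> copy, not a view
copied[:] = True          # only the copy changes
print(storage[index])     # tensor([False, False, False])

storage[index] = True     # writing through the index works (cf. try 3)
print(storage[index])     # tensor([True, True, True])
```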
(By the way, I also noticed that assigning a whole tensordict back into the buffer works as expected, as below.)
```python
# try 4
batch.set_("update", new_update)
buffer[index] = batch  # This replaces the values, like try 3 above!
```
Although there are workarounds for now (try 3/4), this behavior is confusing, and some users may unknowingly reach for try 1/2, as I did.
A possible fix is to make the try 1/2 patterns also update the stored data.
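In the meantime, a small helper like the sketch below can avoid the trap until the underlying behavior changes (`update_at_` is a hypothetical name, not an existing torchrl API; it just wraps the try-4 pattern from above):

```python
from torchrl.data import ReplayBuffer


def update_at_(buffer: ReplayBuffer, index, key, value) -> None:
    """Hypothetical helper wrapping the try-4 pattern above: read the
    indexed data, modify it in place, then assign the whole tensordict
    back through ``buffer[index] = ...`` so the storage is updated."""
    data = buffer[index]
    data.set_(key, value)
    buffer[index] = data  # this write-back reaches the underlying storage


# Usage, following the repro above:
# update_at_(buffer, index, "update", new_update)
```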
(Finally, my system info is below.)
```python
>>> import torchrl, numpy, sys
>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.7.1 1.26.4 3.11.11 (main, Dec 4 2024, 08:55:07) [GCC 11.4.0] linux
```