[BUG] Unexpected behavior when replacing/updating part of the data stored in ReplayBuffer #2810

Open
@tmparticle

Description

I tried to replace/update some parts of the data stored in ReplayBuffer with new tensors.

Although this functionality was added (#2209), I found the following unexpected behavior:

import tensordict
import torch
from torchrl.data import LazyTensorStorage, SliceSampler, ReplayBuffer


def test():
    buffer = ReplayBuffer(
        storage=LazyTensorStorage(50),
        sampler=SliceSampler(
            traj_key="episode",
            slice_len=5,
        ),
        batch_size=10,
    )
    for i in range(50):
        n, m = divmod(i, 10)
        buffer.add(
            tensordict.TensorDict(
                {
                    "episode": torch.tensor(n),
                    "update": torch.tensor(False),
                },
            )
        )
    batch, batch_info = buffer.sample(return_info=True)
    index = batch_info["index"]
    print(f"{index=}")
    new_update = torch.ones_like(batch["update"])
    
    print("before: ", buffer["update"])
    
    # try 1 ------
    buffer[index]["update"] = new_update
    print("try 1 (This doesn't update!) ----")
    print(buffer["update"])
    
    # try 2 ------
    buffer[index].set_("update", new_update)
    print("try 2 (This also doesn't update!) ----")
    print(buffer["update"])
    
    # try 3 ------
    buffer["update"][index] = new_update
    print("try 3 (This updated!)")
    print(buffer["update"])


if __name__ == "__main__":
    test()

This prints something like:

index=(tensor([43, 44, 45, 46, 47, 13, 14, 15, 16, 17]),)
before:  tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
try 1 (This doesn't update!) ----
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
try 2 (This also doesn't update!) ----
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
try 3 (This updated!)
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True,  True,  True,  True,  True, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True,  True,  True,  True,  True, False, False])

As shown above, when I first index the buffer to get a tensordict (tries 1 and 2), assigning to "update" does not change the stored data.
However, when I first select the "update" tensor by key and then index it (try 3), the assignment does change the stored data.
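
A minimal sketch of one possible explanation, using plain PyTorch tensors rather than ReplayBuffer itself. It assumes (this is not confirmed in this issue) that buffer[index] gathers the selected entries into a new tensordict, in the same way that indexing a tensor with an index tensor returns a copy. The names storage and gathered are illustrative only.

```python
import torch

# Stand-in for the stored "update" column (50 entries, all False).
storage = torch.zeros(50, dtype=torch.bool)
index = torch.tensor([43, 44, 45, 46, 47, 13, 14, 15, 16, 17])

# Analogous to tries 1 and 2: indexing with a tensor returns a copy,
# so writing into the result never reaches the original storage.
gathered = storage[index]
gathered[:] = True
changed_via_copy = int(storage.sum())

# Analogous to try 3: index assignment writes into the original storage.
storage[index] = True
changed_in_place = int(storage.sum())

print(changed_via_copy, changed_in_place)  # prints: 0 10
```

If that assumption holds, tries 1 and 2 modify a temporary object that is discarded, while try 3 writes through to the stored tensor.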

(Incidentally, I also noticed that assigning a whole tensordict back to the buffer works as expected, as shown below.)

# try 4
batch.set_("update", new_update)
buffer[index] = batch  # This updates the stored data, as in try 3

Although workarounds currently exist (tries 3 and 4), this behavior is confusing, and some users may unknowingly reach for tries 1 and 2, as I did.
A possible fix is to make tries 1 and 2 update the stored data as well.
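
Until such a fix exists, the try 4 pattern can be wrapped in a small helper. This is only a sketch: update_entries is a hypothetical name, not a TorchRL function, and it assumes that buffer[index] returns an object with a set_ method and that buffer[index] = value writes back to storage, as try 4 suggests.

```python
def update_entries(buffer, index, key, value):
    """Read-modify-write helper (hypothetical, not part of TorchRL).

    Fetches the entries at `index`, overwrites `key` on the fetched
    object, then assigns the whole object back so the change is stored.
    """
    entries = buffer[index]      # assumed to be a copy of the stored entries
    entries.set_(key, value)     # modify the copy, as in try 4
    buffer[index] = entries      # write the modified copy back to the buffer
    return entries
```

With the reproduction script above, the call would be update_entries(buffer, index, "update", new_update).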

(Finally, my system info is below.)

>>> import torchrl, numpy, sys
>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.7.1 1.26.4 3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0] linux

Labels: bug (Something isn't working)
