
[BUG] An unexpected behavior in replacing/updating a part of data stored in ReplayBuffer #2810

Open
tmparticle opened this issue Feb 25, 2025 · 3 comments
Labels: bug Something isn't working

@tmparticle

I tried to replace/update parts of the data stored in a ReplayBuffer with new tensors.

Although this functionality was added (#2209), I found an unexpected behavior, as follows:

import tensordict
import torch
from torchrl.data import LazyTensorStorage, SliceSampler, ReplayBuffer


def test():
    buffer = ReplayBuffer(
        storage=LazyTensorStorage(50),
        sampler=SliceSampler(
            traj_key="episode",
            slice_len=5,
        ),
        batch_size=10,
    )
    for i in range(50):
        n, m = divmod(i, 10)
        buffer.add(
            tensordict.TensorDict(
                {
                    "episode": torch.tensor(n),
                    "update": torch.tensor(False),
                },
            )
        )
    batch, batch_info = buffer.sample(return_info=True)
    index = batch_info["index"]
    print(f"{index=}")
    new_update = torch.ones_like(batch["update"])
    
    print("before: ", buffer["update"])
    
    # try 1 ------
    buffer[index]["update"] = new_update
    print("try 1 (This doesn't update!) ----")
    print(buffer["update"])
    
    # try 2 ------
    buffer[index].set_("update", new_update)
    print("try 2 (This also doesn't update!) ----")
    print(buffer["update"])
    
    # try 3 ------
    buffer["update"][index] = new_update
    print("try 3 (This updated!)")
    print(buffer["update"])


if __name__ == "__main__":
    test()

This prints like:

index=(tensor([43, 44, 45, 46, 47, 13, 14, 15, 16, 17]),)
before:  tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
try 1 (This doesn't update!) ----
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
try 2 (This also doesn't update!) ----
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
try 3 (This updated!)
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True,  True,  True,  True,  True, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True,  True,  True,  True,  True, False, False])

As described above, when I first retrieved a tensordict by indexing the buffer with index (try 1 & 2), I failed to replace the part of "update".
However, when I first retrieved the tensor by the key "update" and then indexed it (try 3), the replacement succeeded.

(By the way, I also noticed that assigning a whole tensordict back into the buffer worked as expected, as below)

# try 4
batch.set_("update", new_update)
buffer[index] = batch # This replaces the data, as in try 3!

Although there are currently workarounds (try 3/4), this behavior is confusing, and some users may unconsciously reach for try 1/2, as I did.
A possible fix would be to modify the try-1/2 paths so that both also update the data.

(Finally, my system info is like below)

>>> import torchrl, numpy, sys
>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.7.1 1.26.4 3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0] linux
@tmparticle tmparticle added the bug Something isn't working label Feb 25, 2025
@vmoens
Contributor

vmoens commented Feb 25, 2025

Hey!
This is unfortunately kind of expected. It's a PyTorch thing, not really an RL / tensordict artifact. Think of a tensordict as a regular dict:

  • when you do td[index] and index is a tensor, we return a new tensordict with a copy of the data (i.e. the data does not share memory with the original). If the index is a slice or an int, the second example will update (try with index = slice(len(index[0]))). So as soon as you do buffer[index] with a tensor index, we're screwed.
  • There is, however, a set_at_ method in tensordict to handle this:
        buffer[:].set_at_("update", ~new_update, index)
    that will do the trick.
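The underlying behavior can be reproduced with plain tensors, without torchrl or tensordict — a minimal sketch of PyTorch's advanced-indexing semantics, where indexing with a tensor returns a copy but __setitem__ on the original writes through:

```python
import torch

t = torch.zeros(5, dtype=torch.bool)
idx = torch.tensor([1, 3])

# Advanced (tensor) indexing returns a copy, so the in-place write
# lands on a temporary tensor and is lost:
t[idx].fill_(True)
print(t)  # all entries still False

# Writing through __setitem__ on the original tensor persists:
t[idx] = True
print(t)  # entries 1 and 3 are now True
```

This is exactly why buffer[index]["update"] = new_update (try 1) silently does nothing: buffer[index] already made the copy before the assignment ran.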

That being said, I do agree it's confusing, and we should preferably find ways to prevent the confusion. I'm open to suggestions.

What would be most helpful? Better docs? Finding a way to raise an error/warning when people do (1) or (2) (not sure that's achievable)?

@tmparticle
Author

tmparticle commented Feb 25, 2025

Hi vmoens,

Thank you for your kind explanation! Now I understand PyTorch's indexing rules.

I tested int/slice indices for try 1/2 and confirmed that only try 2 worked, and then wondered why an int/slice index for try 1 did not work. If int/slice indexing gives a view of the data, it seems try 1 should also work...?

But anyway, I think adding documentation along the lines of your description would be sufficiently helpful (for example, in the replay buffer tutorial: https://pytorch.org/rl/stable/tutorials/rb_tutorial.html).

Otherwise, I think adding a convenience method to ReplayBuffer would also be helpful, such as:

def update(self, key, value, index):
    self[:].set_at_(key, value, index)

@vmoens
Contributor

vmoens commented Feb 25, 2025

Ah yeah, an update method would indeed be a killer feature!
Didn't really think about that but it makes total sense.
Perhaps set_at_, set_, update_ to keep the tensordict nomenclature?

cc @Darktex
