From 1e7a17b8964b8810ed5b5b3af1419b27bac351f1 Mon Sep 17 00:00:00 2001
From: Zach-Attach
Date: Fri, 8 Nov 2024 14:20:46 -0500
Subject: [PATCH] Updated im_loss mask to work with action spaces > 1

---
 rllte/xplore/reward/e3b.py           | 2 ++
 rllte/xplore/reward/icm.py           | 6 ++++--
 rllte/xplore/reward/pseudo_counts.py | 2 ++
 rllte/xplore/reward/ride.py          | 6 ++++--
 4 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/rllte/xplore/reward/e3b.py b/rllte/xplore/reward/e3b.py
index 7bbeae37..e9e38213 100644
--- a/rllte/xplore/reward/e3b.py
+++ b/rllte/xplore/reward/e3b.py
@@ -219,6 +219,8 @@ def update(self, samples: Dict[str, th.Tensor]) -> None:
         # use a random mask to select a subset of the training data
         mask = th.rand(len(im_loss), device=self.device)
         mask = (mask < self.update_proportion).type(th.FloatTensor).to(self.device)
+        # expand the mask to match action spaces > 1
+        mask = mask.unsqueeze(1).expand_as(im_loss)
         # get the masked loss
         im_loss = (im_loss * mask).sum() / th.max(
             mask.sum(), th.tensor([1], device=self.device, dtype=th.float32)
diff --git a/rllte/xplore/reward/icm.py b/rllte/xplore/reward/icm.py
index 6a315a75..61324743 100644
--- a/rllte/xplore/reward/icm.py
+++ b/rllte/xplore/reward/icm.py
@@ -221,9 +221,11 @@ def update(self, samples: Dict[str, th.Tensor]) -> None:
         # use a random mask to select a subset of the training data
         mask = th.rand(len(im_loss), device=self.device)
         mask = (mask < self.update_proportion).type(th.FloatTensor).to(self.device)
+        # expand the mask to match action spaces > 1
+        im_mask = mask.unsqueeze(1).expand_as(im_loss)
         # get the masked losses
-        im_loss = (im_loss * mask).sum() / th.max(
-            mask.sum(), th.tensor([1], device=self.device, dtype=th.float32)
+        im_loss = (im_loss * im_mask).sum() / th.max(
+            im_mask.sum(), th.tensor([1], device=self.device, dtype=th.float32)
         )
         fm_loss = (fm_loss * mask).sum() / th.max(
             mask.sum(), th.tensor([1], device=self.device, dtype=th.float32)
diff --git a/rllte/xplore/reward/pseudo_counts.py b/rllte/xplore/reward/pseudo_counts.py
index a2b7a8f1..5e3b7986 100644
--- a/rllte/xplore/reward/pseudo_counts.py
+++ b/rllte/xplore/reward/pseudo_counts.py
@@ -300,6 +300,8 @@ def update(self, samples: Dict[str, th.Tensor]) -> None:
         # use a random mask to select a subset of the training data
         mask = th.rand(len(im_loss), device=self.device)
         mask = (mask < self.update_proportion).type(th.FloatTensor).to(self.device)
+        # expand the mask to match action spaces > 1
+        mask = mask.unsqueeze(1).expand_as(im_loss)
         # get the masked loss
         im_loss = (im_loss * mask).sum() / th.max(
             mask.sum(), th.tensor([1], device=self.device, dtype=th.float32)
diff --git a/rllte/xplore/reward/ride.py b/rllte/xplore/reward/ride.py
index e5216218..e6cf9316 100644
--- a/rllte/xplore/reward/ride.py
+++ b/rllte/xplore/reward/ride.py
@@ -338,9 +338,11 @@ def update(self, samples: Dict[str, th.Tensor]) -> None:
         # use a random mask to select a subset of the training data
         mask = th.rand(len(im_loss), device=self.device)
         mask = (mask < self.update_proportion).type(th.FloatTensor).to(self.device)
+        # expand the mask to match action spaces > 1
+        im_mask = mask.unsqueeze(1).expand_as(im_loss)
         # get the masked losses
-        im_loss = (im_loss * mask).sum() / th.max(
-            mask.sum(), th.tensor([1], device=self.device, dtype=th.float32)
+        im_loss = (im_loss * im_mask).sum() / th.max(
+            im_mask.sum(), th.tensor([1], device=self.device, dtype=th.float32)
         )
         fm_loss = (fm_loss * mask).sum() / th.max(
             mask.sum(), th.tensor([1], device=self.device, dtype=th.float32)
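
Note: the following is a minimal standalone sketch of the masking scheme the
patch fixes, not code from the rllte repository. The shapes and the
update_proportion value are assumptions for illustration; the point is that
the per-sample Bernoulli mask has shape (batch,) while the unreduced
intrinsic-model loss has shape (batch, action_dim), so the mask must be
expanded before the two can be multiplied when action_dim > 1.

import torch as th

device = th.device("cpu")
batch, action_dim = 4, 3          # hypothetical sizes for the sketch
update_proportion = 0.5           # fraction of samples used per update

# unreduced loss, one row per sample, one column per action dimension
im_loss = th.rand(batch, action_dim, device=device)

# draw one Bernoulli(update_proportion) keep-flag per sample
mask = th.rand(len(im_loss), device=device)   # shape: (batch,)
mask = (mask < update_proportion).float()

# without this step, (im_loss * mask) cannot broadcast a (batch, action_dim)
# tensor against a (batch,) tensor and raises a shape error for action_dim > 1
mask = mask.unsqueeze(1).expand_as(im_loss)   # shape: (batch, action_dim)

# mean loss over the kept entries; th.max guards against an all-zero mask
im_loss = (im_loss * mask).sum() / th.max(
    mask.sum(), th.tensor(1.0, device=device)
)
print(im_loss)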