From 4702ef03f4337fa9fc77e8108cdf37d597d87375 Mon Sep 17 00:00:00 2001 From: Surya Dheeshjith <41594351+suryadheeshjith@users.noreply.github.com> Date: Wed, 28 Feb 2024 01:25:55 +0530 Subject: [PATCH] Update world_model.py shape --- src/models/world_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/world_model.py b/src/models/world_model.py index 0691034..0aa715e 100644 --- a/src/models/world_model.py +++ b/src/models/world_model.py @@ -97,7 +97,7 @@ def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValue def compute_loss(self, batch: Batch, tokenizer: Tokenizer, **kwargs: Any) -> LossWithIntermediateLosses: with torch.no_grad(): - obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens # (BL, K) + obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens # (B, L, K) act_tokens = rearrange(batch['actions'], 'b l -> b l 1') tokens = rearrange(torch.cat((obs_tokens, act_tokens), dim=2), 'b l k1 -> b (l k1)') # (B, L(K+1))