Thank you for your interest in our work! The current model uses a few history steps as inputs to predict the next step. Depending on the benchmark, the number of history steps varies from 2 to 10. These history steps serve as the input to the model, which can then predict multiple steps ahead in an autoregressive manner.
To achieve this, we use the following rollout function to generate multistep predictions:
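import jax.numpy as jnp
from jax import vmap


def rollout(state, x, coords, prev_steps=2, pred_steps=1, rollout_steps=5):
    # x holds the history frames: (batch, time, height, width, channels)
    b, t, h, w, c = x.shape
    pred_list = []
    for k in range(rollout_steps):
        # Evaluate the model at every query coordinate, given the current history
        pred = vmap(state.apply_fn, (None, None, 0), out_axes=2)(state.params, x, coords[:, None, :])
        pred = pred.reshape(b, pred_steps, h, w, c)
        # Append the prediction along the time axis and keep the last prev_steps frames
        x = jnp.concatenate([x, pred], axis=1)
        x = x[:, -prev_steps:]
        pred_list.append(pred)
    # Stack the predictions from all iterations along the time axis
    pred = jnp.concatenate(pred_list, axis=1)
    return pred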
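At each iteration, the function predicts the next pred_steps step(s), appends the prediction to the input along the time axis, and keeps only the last prev_steps steps as the new history. This repeats rollout_steps times, and the predictions from all iterations are concatenated along the time axis and returned.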
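To make the sliding window concrete, here is a minimal sketch that traces only the time indices, with no model involved. The helper name `rollout_indices` is hypothetical and not part of the repository; it assumes the same append-then-trim logic as the rollout function above.

```python
def rollout_indices(prev_steps=2, pred_steps=1, rollout_steps=5):
    """Trace which time indices feed each rollout iteration (illustration only)."""
    window = list(range(prev_steps))  # indices of the initial history frames
    next_t = prev_steps               # index of the first frame to be predicted
    trace = []
    for _ in range(rollout_steps):
        # Frames predicted in this iteration
        pred = list(range(next_t, next_t + pred_steps))
        trace.append((list(window), pred))
        # Append the prediction, then keep only the newest prev_steps indices
        window = (window + pred)[-prev_steps:]
        next_t += pred_steps
    return trace


for inputs, outputs in rollout_indices():
    print(inputs, "->", outputs)
# [0, 1] -> [2]
# [1, 2] -> [3]
# [2, 3] -> [4]
# [3, 4] -> [5]
# [4, 5] -> [6]
```

With the defaults, the first prediction uses two ground-truth frames; from the third iteration on, the window contains only the model's own predictions.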
Hi,
I would like to understand how the inference process works. Can I use only the initial step to predict multiple steps ahead?
Thank you.
Regards