index and key are not present when using predict #35

Open
vidarsumo opened this issue Aug 7, 2022 · 7 comments
Labels
documentation Improvements or additions to documentation

Comments

@vidarsumo

I have known future values, so I need to pass the data set to new_data in predict(), but it only returns .pred_lower, .pred and .pred_upper. Neither the index (date) nor the key (id) is present in the output, which matters because the data set contains multiple time series.

> forecasts_default <- predict(fitted_default, new_data = test, past_data = train)
> forecasts_default
# A tibble: 520 × 3
   .pred_lower .pred .pred_upper
         <dbl> <dbl>       <dbl>
 1       1498. 2242.       2874.
 2       1373. 1791.       2134.
 3       1259. 1671.       2016.
 4       1142. 1463.       1730.
 5       1619. 2090.       2485.
 6        467. 3079.       5370.
 7       1361. 1743.       2070.
 8       2375. 3815.       5073.
 9       2591. 3305.       3896.
10       5128. 7600.       9728.
# … with 510 more rows

Using generics::forecast() instead and skipping the known future data gives the desired output (i.e. containing date and id):

# A tibble: 520 × 5
   date       id        .pred_lower .pred .pred_upper
 * <date>     <chr>           <dbl> <dbl>       <dbl>
 1 2021-08-09 oes_13078       1498. 2242.       2874.
 2 2021-08-16 oes_13078       1598. 2323.       2971.
 3 2021-08-23 oes_13078       1510. 2379.       3257.
 4 2021-08-30 oes_13078       1572. 2410.       3216.
 5 2021-09-06 oes_13078       1356. 2280.       3259.
 6 2021-09-13 oes_13078       1359. 2090.       2809.
 7 2021-09-20 oes_13078       1383. 2130.       2771.
 8 2021-09-27 oes_13078       1324. 2074.       2847.
 9 2021-10-04 oes_13078       1355. 2254.       2973.
10 2021-10-11 oes_13078       1371. 2068.       2524.
# … with 510 more rows

Do I need to do anything for the predict() function to return the index and the key?

@vidarsumo
Author

I guess my question is how do I use known and static information when creating forecasts?

@cregouby
Collaborator

cregouby commented Aug 17, 2022

Hello @vidarsumo,

generics::forecast() should definitely be used.

To define the covariates, have you used the syntax from the Getting started web page? Does the 'specifying the covariate' section fail in your experiment?

Hope it helps

@vidarsumo
Author

vidarsumo commented Aug 20, 2022

If I'm not mistaken, generics::forecast() does not accept new_data, while predict() does.
The only thing that worked was to use predict() and then bind_cols() to get back the id, Date, etc.

suppressPackageStartupMessages(library(tidymodels))
library(tft)
set.seed(1)
torch::torch_manual_seed(1)



# Preparing data
data_tbl <- timetk::walmart_sales_weekly %>%
  select(id, Dept, Date, Weekly_Sales, IsHoliday) %>% 
  mutate(
    Dept = paste0("Dept_", Dept),
    IsHoliday = ifelse(IsHoliday, "yes", "no"))

fit_data <- data_tbl %>% 
  filter(Date <= "2012-08-03")

future_data <- data_tbl %>% 
  filter(Date > "2012-08-03") %>% 
  mutate(Weekly_Sales = NA_real_)


# TFT
rec <- recipe(Weekly_Sales ~ ., data = fit_data) %>%
  timetk::step_timeseries_signature(Date) %>%
  step_zv(all_predictors()) %>% 
  step_normalize(all_numeric_predictors())


spec <- tft_dataset_spec(rec, fit_data) %>%
  spec_covariate_index(Date) %>%
  spec_covariate_key(id) %>%
  spec_covariate_known(starts_with("Date_"), IsHoliday) %>%
  spec_covariate_static(Dept) %>% 
  spec_time_splits(lookback = 52, horizon = 12) %>%
  prep()

tft_model <- temporal_fusion_transformer(spec)

fitted <- tft_model %>%
  fit(
    transform(spec),
    epochs = 1,
    verbose = TRUE,
    dataloader_options = list(batch_size = 64, num_workers = 4)
  )

# Using forecast() without new_data
generics::forecast(fitted, past_data = fit_data)
Error in `adjust_new_data()`:
! Known or static variable is missing from `new_data`.
✖ Check for `Dept`.

# Using forecast() with new_data
generics::forecast(fitted, past_data = fit_data, new_data = future_data)
Error in forecast.tft_result(fitted, past_data = fit_data, new_data = future_data) : 
  unused argument (new_data = future_data)

# Using predict() (no id, date etc. present)
predict(object = fitted, new_data = future_data, past_data = fit_data)
# A tibble: 84 × 3
   .pred_lower  .pred .pred_upper
         <dbl>  <dbl>       <dbl>
 1      13556. 22793.      41542.
 2      12762. 20972.      40915.
 3      11930. 19535.      39995.
 4      12875. 21650.      43980.
 5      12946. 20197.      34923.
 6      14671. 21054.      36676.
 7      14294. 20423.      36560.
 8      14550. 20447.      38125.
 9      12553. 20296.      28059.
10      13694. 21042.      32854.

# Using predict() with bind_cols() does the trick.
predict(object = fitted, new_data = future_data, past_data = fit_data) %>% 
  bind_cols(future_data %>% select(-Weekly_Sales))
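The bind_cols() workaround only holds if predict() returns exactly one row per row of new_data, in the same order. That is an assumption here, not something this thread confirms about tft. A minimal sketch of the same idea with a row-count guard, using small stand-in tibbles (preds and future_data below are made-up values, not real model output):

```r
library(dplyr)

# Stand-in for the predict() output: only the .pred* columns (values invented).
preds <- tibble::tibble(
  .pred_lower = c(10, 11, 12),
  .pred       = c(20, 21, 22),
  .pred_upper = c(30, 31, 32)
)

# Stand-in for the new_data passed to predict().
future_data <- tibble::tibble(
  id           = c("1_1", "1_1", "1_1"),
  Date         = as.Date(c("2012-08-10", "2012-08-17", "2012-08-24")),
  Weekly_Sales = NA_real_
)

# Assumption: rows of preds line up one-to-one with rows of future_data.
# Fail early if the row counts differ instead of silently misaligning.
stopifnot(nrow(preds) == nrow(future_data))

# Put key and index first, then the prediction columns.
forecast_tbl <- future_data %>%
  select(id, Date) %>%
  bind_cols(preds)

forecast_tbl
```

The row-count check does not prove the order matches; it only catches the case where predict() returns a different number of rows than new_data.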

Am I doing something wrong here?

@cregouby
Collaborator

Hello @vidarsumo

Sorry, my mistake, you are right:

  1. forecast() returns the keys and index, but it is documented as "can only be used if the model object doesn't include known predictors".
  2. predict() uses the known predictors passed in new_data, but the source code drops the key and index variables just before returning the result, at
    dplyr::select(out, dplyr::starts_with(".pred"))

Maybe @dfalbel knows of a design pattern that would prevent us from adding additional variables to the predict() output via a switch parameter like all_vars = FALSE?
Anyway, it would be easy to modify, so maybe you could propose a pull request?
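To make the idea concrete, here is a rough sketch of what such a switch could look like. It is not the actual tft source: finalize_predictions() is a hypothetical helper, all_vars is only the parameter name floated above, and out is a stand-in for the intermediate result that still carries the key and index columns.

```r
library(dplyr)

# Hypothetical helper, not part of tft: decides which columns predict() returns.
finalize_predictions <- function(out, all_vars = FALSE) {
  if (all_vars) {
    # Keep every column, with the non-prediction columns (key, index) first.
    dplyr::select(out, !dplyr::starts_with(".pred"), dplyr::starts_with(".pred"))
  } else {
    # Current behaviour described above: keep only the .pred* columns.
    dplyr::select(out, dplyr::starts_with(".pred"))
  }
}

# Stand-in for the intermediate result (values invented).
out <- tibble::tibble(
  id          = c("1_1", "1_1"),
  Date        = as.Date(c("2012-08-10", "2012-08-17")),
  .pred_lower = c(10, 11),
  .pred       = c(20, 21),
  .pred_upper = c(30, 31)
)

finalize_predictions(out)                   # only .pred_lower, .pred, .pred_upper
finalize_predictions(out, all_vars = TRUE)  # id, Date, then the .pred* columns
```

Defaulting all_vars to FALSE would keep the existing output unchanged for current users.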

@Ujjwal4CULS

Would it be possible to predict on the complete data, like this?
predict(object = fitted, new_data = data_tbl, past_data = fit_data)
If so, please let me know; it would be very helpful for assessing the accuracy of the fitted model. The line above generates an error, and any help with it would be highly appreciated.

@cregouby cregouby added the documentation Improvements or additions to documentation label Mar 28, 2023
@cregouby
Collaborator

Hello @Ujjwal4CULS

I cannot reproduce your issue with the example documented here. Could you please open a dedicated issue with a reproducible example?

@Ujjwal4CULS

Sorry for the incomplete information. For example, with the code above, is it possible to apply predict() to data_tbl like this?

suppressPackageStartupMessages(library(tidymodels))
library(tft)
set.seed(1)
torch::torch_manual_seed(1)

# Preparing data
data_tbl <- timetk::walmart_sales_weekly %>%
  select(id, Dept, Date, Weekly_Sales, IsHoliday) %>%
  mutate(
    Dept = paste0("Dept_", Dept),
    IsHoliday = ifelse(IsHoliday, "yes", "no"))

fit_data <- data_tbl %>%
  filter(Date <= "2012-08-03")

future_data <- data_tbl %>%
  filter(Date > "2012-08-03") %>%
  mutate(Weekly_Sales = NA_real_)

# TFT
rec <- recipe(Weekly_Sales ~ ., data = fit_data) %>%
  timetk::step_timeseries_signature(Date) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors())

spec <- tft_dataset_spec(rec, fit_data) %>%
  spec_covariate_index(Date) %>%
  spec_covariate_key(id) %>%
  spec_covariate_known(starts_with("Date_"), IsHoliday) %>%
  spec_covariate_static(Dept) %>%
  spec_time_splits(lookback = 52, horizon = 12) %>%
  prep()

tft_model <- temporal_fusion_transformer(spec)

fitted <- tft_model %>%
  fit(
    transform(spec),
    epochs = 1,
    verbose = TRUE,
    dataloader_options = list(batch_size = 64, num_workers = 4)
  )

# Predicting on the complete data set (this is the call that errors)
predict(object = fitted, new_data = data_tbl, past_data = fit_data)


3 participants