class LineDataset(torch.utils.data.TensorDataset):
    """Synthetic forecasting dataset made of straight lines.

    Every sample is ``slope * t + intercept`` evaluated at the integer time
    steps ``t = 0 .. seq_len + pred_len - 1``.  The leading ``seq_len``
    points become the input tensor and the trailing ``pred_len`` points the
    target tensor; both carry a trailing feature axis of size 1.

    Args:
        seq_len: Number of time steps in the input window.
        pred_len: Number of time steps in the target window.
        n_samples: Number of lines to generate.
        slope_range: ``(low, high)`` bounds passed to ``uniform_`` for slopes.
        intercept_range: ``(low, high)`` bounds passed to ``uniform_`` for
            intercepts.
        noise_std: Scale applied to standard-normal noise added to every
            point.  Noise is only drawn when this is strictly positive.
        seed: Seed of the private ``torch.Generator`` used for all draws.
    """

    def __init__(self, seq_len=24, pred_len=12, n_samples=1000,
                 slope_range=(0.1, 1.0), intercept_range=(0.0, 1.0),
                 noise_std=0.0, seed=0):
        rng = torch.Generator().manual_seed(seed)
        n_steps = seq_len + pred_len
        time_axis = torch.arange(n_steps, dtype=torch.float32)

        # Draw order (slopes -> intercepts -> noise) defines the random
        # stream for a given seed, so it must not be rearranged.
        slope = torch.empty(n_samples).uniform_(
            slope_range[0], slope_range[1], generator=rng
        )
        intercept = torch.empty(n_samples).uniform_(
            intercept_range[0], intercept_range[1], generator=rng
        )

        # Broadcast (n_samples, 1) against (n_steps,) -> (n_samples, n_steps).
        series = slope.unsqueeze(1) * time_axis + intercept.unsqueeze(1)
        if noise_std > 0:
            jitter = torch.randn(n_samples, n_steps, generator=rng)
            series = series + noise_std * jitter

        # Add the feature axis once, then split into input / target windows.
        series = series.unsqueeze(-1)
        inputs = series[:, :seq_len]
        targets = series[:, seq_len:]
        super().__init__(inputs, targets)

        # Keep the construction parameters available on the instance.
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.n_samples = n_samples
        self.slope_range = slope_range
        self.intercept_range = intercept_range
        self.noise_std = noise_std
        self.seed = seed
        self.forecast_horizon = pred_len
class Line(SequenceDataset):
    """Dataset wrapper that builds train/val/test ``LineDataset`` splits.

    All three splits share the same window lengths, sampling ranges and
    noise level; they differ only in sample count and in the seed, which is
    offset by 0, 1 and 2 for train, val and test respectively.
    """

    _name_ = "line"
    d_input = 1
    d_output = 1

    @property
    def init_defaults(self):
        """Return the default value for every configurable attribute."""
        return dict(
            seq_len=24,
            pred_len=12,
            n_train=1000,
            n_val=200,
            n_test=200,
            slope_range=(0.1, 1.0),
            intercept_range=(0.0, 1.0),
            noise_std=0.0,
            seed=0,
        )

    @property
    def l_output(self):
        """Length of the output sequence, equal to ``pred_len``."""
        return self.pred_len

    def setup(self):
        """Create ``dataset_train``, ``dataset_val`` and ``dataset_test``."""

        def build_split(n_samples, split_seed):
            # One place for the arguments shared by every split.
            return LineDataset(
                self.seq_len, self.pred_len, n_samples,
                self.slope_range, self.intercept_range,
                self.noise_std, seed=split_seed,
            )

        self.dataset_train = build_split(self.n_train, self.seed)
        self.dataset_val = build_split(self.n_val, self.seed + 1)
        self.dataset_test = build_split(self.n_test, self.seed + 2)

        # forecast horizon property used by forecasting task
        for split in (self.dataset_train, self.dataset_val, self.dataset_test):
            split.forecast_horizon = self.pred_len

    def __str__(self):
        """Return a short identifier encoding the window lengths."""
        return f"line{self.seq_len}_{self.pred_len}"