From 6ad5f47914f779235665179e5627827e71a98471 Mon Sep 17 00:00:00 2001 From: konitaro524 <87004518+konitaro524@users.noreply.github.com> Date: Fri, 6 Jun 2025 23:01:22 +0900 Subject: [PATCH 1/3] Add synthetic line forecasting dataset --- configs/dataset/line.yaml | 12 ++++++++ src/dataloaders/__init__.py | 2 +- src/dataloaders/datasets/line.py | 26 +++++++++++++++++ src/dataloaders/line.py | 49 ++++++++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 configs/dataset/line.yaml create mode 100644 src/dataloaders/datasets/line.py create mode 100644 src/dataloaders/line.py diff --git a/configs/dataset/line.yaml b/configs/dataset/line.yaml new file mode 100644 index 00000000..e63ff455 --- /dev/null +++ b/configs/dataset/line.yaml @@ -0,0 +1,12 @@ +_name_: line +seq_len: 24 +pred_len: 12 +n_train: 1000 +n_val: 200 +n_test: 200 +slope_range: [0.1, 1.0] +intercept_range: [0.0, 1.0] +noise_std: 0.0 +seed: 0 +__l_max: ${eval:${.seq_len}+${.pred_len}} + diff --git a/src/dataloaders/__init__.py b/src/dataloaders/__init__.py index 9213cde0..24df09e5 100644 --- a/src/dataloaders/__init__.py +++ b/src/dataloaders/__init__.py @@ -1,2 +1,2 @@ -from . import audio, basic, et, lm, lra, synthetic, ts, vision +from . import audio, basic, et, lm, lra, synthetic, ts, vision, line from .base import SequenceDataset diff --git a/src/dataloaders/datasets/line.py b/src/dataloaders/datasets/line.py new file mode 100644 index 00000000..f47fba17 --- /dev/null +++ b/src/dataloaders/datasets/line.py @@ -0,0 +1,26 @@ +import torch + +class LineDataset(torch.utils.data.TensorDataset): + def __init__(self, seq_len=24, pred_len=12, n_samples=1000, + slope_range=(0.1, 1.0), intercept_range=(0.0, 1.0), + noise_std=0.0, seed=0): + self.seq_len = seq_len + self.pred_len = pred_len + self.n_samples = n_samples + self.slope_range = slope_range + self.intercept_range = intercept_range + self.noise_std = noise_std + self.seed = seed + + generator = torch.Generator().manual_seed(seed) + total_len = seq_len + pred_len + t = torch.arange(total_len, dtype=torch.float32) + slopes = torch.empty(n_samples).uniform_(slope_range[0], slope_range[1], generator=generator) + intercepts = torch.empty(n_samples).uniform_(intercept_range[0], intercept_range[1], generator=generator) + lines = slopes[:, None] * t + intercepts[:, None] + if noise_std > 0: + lines += noise_std * torch.randn(n_samples, total_len, generator=generator) + x = lines[:, :seq_len].unsqueeze(-1) + y = lines[:, seq_len:].unsqueeze(-1) + super().__init__(x, y) + self.forecast_horizon = pred_len diff --git a/src/dataloaders/line.py b/src/dataloaders/line.py new file mode 100644 index 00000000..b60473ae --- /dev/null +++ b/src/dataloaders/line.py @@ -0,0 +1,49 @@ +from src.dataloaders.base import SequenceDataset +from .datasets.line import LineDataset + +class Line(SequenceDataset): + _name_ = "line" + d_input = 1 + d_output = 1 + + @property + def init_defaults(self): + return { + "seq_len": 24, + "pred_len": 12, + "n_train": 1000, + "n_val": 200, + "n_test": 200, + "slope_range": (0.1, 1.0), + "intercept_range": (0.0, 1.0), + "noise_std": 0.0, + "seed": 0, + } + + @property + def l_output(self): + return self.pred_len + + def setup(self): + self.dataset_train = LineDataset( + self.seq_len, self.pred_len, self.n_train, + self.slope_range, self.intercept_range, + self.noise_std, seed=self.seed + ) + self.dataset_val = LineDataset( + self.seq_len, self.pred_len, self.n_val, + self.slope_range, self.intercept_range, + self.noise_std, seed=self.seed + 1 + ) + self.dataset_test = LineDataset( + self.seq_len, self.pred_len, self.n_test, + self.slope_range, self.intercept_range, + self.noise_std, seed=self.seed + 2 + ) + # forecast horizon property used by forecasting task + self.dataset_train.forecast_horizon = self.pred_len + self.dataset_val.forecast_horizon = self.pred_len + self.dataset_test.forecast_horizon = self.pred_len + + def __str__(self): + return f"line{self.seq_len}_{self.pred_len}" From d0fcbab4d6ac2e8a8da246b176e748b6dc762463 Mon Sep 17 00:00:00 2001 From: konitaro524 <87004518+konitaro524@users.noreply.github.com> Date: Sat, 7 Jun 2025 16:50:34 +0900 Subject: [PATCH 2/3] Add linear forecasting demo --- README.md | 1 + line_forecast.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 line_forecast.py diff --git a/README.md b/README.md index fb3269b9..1e6f5483 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,7 @@ See [notebooks/](notebooks/) for visualizations explaining some concepts behind [example.py](example.py) is a self-contained training script for MNIST and CIFAR that imports the standalone S4 file. The default settings `python example.py` reaches 88% accuracy on sequential CIFAR with a very simple S4D model of 200k parameters. This script can be used as an example for using S4 variants in external repositories. +In addition, [`line_forecast.py`](line_forecast.py) provides a minimal demonstration of training S4 on a synthetic linear forecasting task. ### Training with this Repository (Internal Usage) diff --git a/line_forecast.py b/line_forecast.py new file mode 100644 index 00000000..ccff0c11 --- /dev/null +++ b/line_forecast.py @@ -0,0 +1,77 @@ +import torch +import torch.nn as nn +from models.s4.s4d import S4D +from torch.utils.data import DataLoader, Dataset + + +class LineDataset(Dataset): + """Synthetic dataset of linear sequences for forecasting.""" + + def __init__(self, seq_len=10, pred_len=1, size=1000): + super().__init__() + self.seq_len = seq_len + self.pred_len = pred_len + self.size = size + + def __len__(self): + return self.size + + def __getitem__(self, idx): + slope = torch.rand(1) * 2 - 1 # [-1, 1] + intercept = torch.rand(1) * 2 - 1 + t = torch.arange(self.seq_len + self.pred_len, dtype=torch.float) + y = slope * t + intercept + x = y[: self.seq_len].unsqueeze(-1) + target = y[self.seq_len :].unsqueeze(-1) + return x, target + + +class ForecastModel(nn.Module): + def __init__(self, d_model=64, n_layers=2, dropout=0.0): + super().__init__() + self.encoder = nn.Linear(1, d_model) + self.s4_layers = nn.ModuleList( + [S4D(d_model, dropout=dropout, transposed=True) for _ in range(n_layers)] + ) + self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)]) + self.decoder = nn.Linear(d_model, 1) + + def forward(self, x): + x = self.encoder(x) + x = x.transpose(-1, -2) + for layer, norm in zip(self.s4_layers, self.norms): + z, _ = layer(x) + x = norm((x + z).transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(-1, -2) + x_last = x[:, -1] + out = self.decoder(x_last) + return out + + +def train_model(): + device = "cuda" if torch.cuda.is_available() else "cpu" + dataset = LineDataset() + loader = DataLoader(dataset, batch_size=32, shuffle=True) + + model = ForecastModel().to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + criterion = nn.MSELoss() + + for epoch in range(10): + model.train() + total_loss = 0.0 + for x, y in loader: + x = x.to(device) + y = y.to(device) + optimizer.zero_grad() + out = model(x) + loss = criterion(out, y.squeeze(1)) + loss.backward() + optimizer.step() + total_loss += loss.item() * x.size(0) + avg_loss = total_loss / len(loader.dataset) + print(f"Epoch {epoch+1}: loss {avg_loss:.6f}") + + +if __name__ == "__main__": + train_model() From 1bd3c05cce0ec2f49dd1301c2e00c2b61722c41b Mon Sep 17 00:00:00 2001 From: konitaro524 <87004518+konitaro524@users.noreply.github.com> Date: Sat, 7 Jun 2025 16:59:50 +0900 Subject: [PATCH 3/3] Use repository line dataset --- line_forecast.py | 24 ++---------------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/line_forecast.py b/line_forecast.py index ccff0c11..3b941060 100644 --- a/line_forecast.py +++ b/line_forecast.py @@ -1,29 +1,9 @@ import torch import torch.nn as nn from models.s4.s4d import S4D -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader - -class LineDataset(Dataset): - """Synthetic dataset of linear sequences for forecasting.""" - - def __init__(self, seq_len=10, pred_len=1, size=1000): - super().__init__() - self.seq_len = seq_len - self.pred_len = pred_len - self.size = size - - def __len__(self): - return self.size - - def __getitem__(self, idx): - slope = torch.rand(1) * 2 - 1 # [-1, 1] - intercept = torch.rand(1) * 2 - 1 - t = torch.arange(self.seq_len + self.pred_len, dtype=torch.float) - y = slope * t + intercept - x = y[: self.seq_len].unsqueeze(-1) - target = y[self.seq_len :].unsqueeze(-1) - return x, target +from src.dataloaders.datasets.line import LineDataset class ForecastModel(nn.Module):