1 change: 1 addition & 0 deletions README.md
@@ -77,6 +77,7 @@ See [notebooks/](notebooks/) for visualizations explaining some concepts behind

[example.py](example.py) is a self-contained training script for MNIST and CIFAR that imports the standalone S4 file. With the default settings, `python example.py` reaches 88% accuracy on sequential CIFAR with a very simple S4D model of 200k parameters.
This script can be used as an example of how to use S4 variants in external repositories.
In addition, [`line_forecast.py`](line_forecast.py) provides a minimal demonstration of training S4 on a synthetic linear forecasting task.
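
For orientation, a minimal sketch of running it programmatically (equivalent to `python line_forecast.py` from the repository root, using only names defined in that file):

```python
# Trains the small S4D forecaster on the synthetic line dataset (10 epochs by default).
from line_forecast import train_model

train_model()
```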

### Training with this Repository (Internal Usage)

12 changes: 12 additions & 0 deletions configs/dataset/line.yaml
@@ -0,0 +1,12 @@
_name_: line
seq_len: 24
pred_len: 12
n_train: 1000
n_val: 200
n_test: 200
slope_range: [0.1, 1.0]
intercept_range: [0.0, 1.0]
noise_std: 0.0
seed: 0
__l_max: ${eval:${.seq_len}+${.pred_len}}
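
As a sanity check, a minimal sketch of resolving this config with OmegaConf; it assumes an `eval` resolver is registered (as the `${eval:...}` interpolation above implies) and is not itself part of this PR:

```python
# Sketch only: resolve configs/dataset/line.yaml and check the derived sequence length.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval, replace=True)  # assumption: the training entry point registers this
cfg = OmegaConf.load("configs/dataset/line.yaml")
print(cfg["__l_max"])  # expected: 36 == seq_len (24) + pred_len (12)
```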

57 changes: 57 additions & 0 deletions line_forecast.py
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
from models.s4.s4d import S4D
from torch.utils.data import DataLoader

from src.dataloaders.datasets.line import LineDataset


class ForecastModel(nn.Module):
    def __init__(self, d_model=64, n_layers=2, dropout=0.0, pred_len=12):
        super().__init__()
        self.encoder = nn.Linear(1, d_model)
        self.s4_layers = nn.ModuleList(
            [S4D(d_model, dropout=dropout, transposed=True) for _ in range(n_layers)]
        )
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])
        # Predict the full forecast horizon from the last hidden state.
        self.decoder = nn.Linear(d_model, pred_len)

    def forward(self, x):
        x = self.encoder(x)          # (batch, seq_len, 1) -> (batch, seq_len, d_model)
        x = x.transpose(-1, -2)      # (batch, d_model, seq_len) for transposed S4D
        for layer, norm in zip(self.s4_layers, self.norms):
            z, _ = layer(x)          # S4D returns (output, state)
            # Residual connection, then LayerNorm over the feature dimension.
            x = norm((x + z).transpose(-1, -2)).transpose(-1, -2)
        x = x.transpose(-1, -2)      # back to (batch, seq_len, d_model)
        x_last = x[:, -1]            # last time step: (batch, d_model)
        out = self.decoder(x_last)   # (batch, pred_len)
        return out


def train_model():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset = LineDataset()  # defaults: seq_len=24, pred_len=12, n_samples=1000
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    model = ForecastModel(pred_len=dataset.forecast_horizon).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for epoch in range(10):
        model.train()
        total_loss = 0.0
        for x, y in loader:
            x = x.to(device)  # (batch, seq_len, 1)
            y = y.to(device)  # (batch, pred_len, 1)
            optimizer.zero_grad()
            out = model(x)    # (batch, pred_len)
            loss = criterion(out, y.squeeze(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
        avg_loss = total_loss / len(loader.dataset)
        print(f"Epoch {epoch+1}: loss {avg_loss:.6f}")


if __name__ == "__main__":
    train_model()
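
For a quick shape check outside of this script, a minimal sketch using only the classes above (the `pred_len=12` keyword matches the dataset default):

```python
import torch

from line_forecast import ForecastModel
from src.dataloaders.datasets.line import LineDataset

dataset = LineDataset(n_samples=4)     # defaults: seq_len=24, pred_len=12
x, y = dataset[0]                      # x: (24, 1) context, y: (12, 1) targets
model = ForecastModel(pred_len=12)
with torch.no_grad():
    out = model(x.unsqueeze(0))        # (1, 12): one value per forecast step
print(out.shape, y.shape)
```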
2 changes: 1 addition & 1 deletion src/dataloaders/__init__.py
@@ -1,2 +1,2 @@
from . import audio, basic, et, lm, lra, synthetic, ts, vision
from . import audio, basic, et, lm, lra, synthetic, ts, vision, line
from .base import SequenceDataset
26 changes: 26 additions & 0 deletions src/dataloaders/datasets/line.py
@@ -0,0 +1,26 @@
import torch

class LineDataset(torch.utils.data.TensorDataset):
    """Synthetic lines y = slope * t + intercept, split into context and forecast windows."""

    def __init__(self, seq_len=24, pred_len=12, n_samples=1000,
                 slope_range=(0.1, 1.0), intercept_range=(0.0, 1.0),
                 noise_std=0.0, seed=0):
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.n_samples = n_samples
        self.slope_range = slope_range
        self.intercept_range = intercept_range
        self.noise_std = noise_std
        self.seed = seed

        generator = torch.Generator().manual_seed(seed)
        total_len = seq_len + pred_len
        t = torch.arange(total_len, dtype=torch.float32)
        slopes = torch.empty(n_samples).uniform_(slope_range[0], slope_range[1], generator=generator)
        intercepts = torch.empty(n_samples).uniform_(intercept_range[0], intercept_range[1], generator=generator)
        lines = slopes[:, None] * t + intercepts[:, None]
        if noise_std > 0:
            lines += noise_std * torch.randn(n_samples, total_len, generator=generator)
        # First seq_len points are the input, remaining pred_len points are the target.
        x = lines[:, :seq_len].unsqueeze(-1)  # (n_samples, seq_len, 1)
        y = lines[:, seq_len:].unsqueeze(-1)  # (n_samples, pred_len, 1)
        super().__init__(x, y)
        self.forecast_horizon = pred_len
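
A quick check, under the default noiseless settings, that each target window continues its context line (not part of the file above):

```python
import torch

from src.dataloaders.datasets.line import LineDataset

ds = LineDataset(n_samples=2)          # noise_std=0.0 by default
x, y = ds[0]                           # (seq_len, 1) and (pred_len, 1)
slope = x[1, 0] - x[0, 0]              # constant increment of a noiseless line
assert torch.allclose(y[0, 0], x[-1, 0] + slope)
```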
49 changes: 49 additions & 0 deletions src/dataloaders/line.py
@@ -0,0 +1,49 @@
from src.dataloaders.base import SequenceDataset
from .datasets.line import LineDataset

class Line(SequenceDataset):
    _name_ = "line"
    d_input = 1
    d_output = 1

    @property
    def init_defaults(self):
        return {
            "seq_len": 24,
            "pred_len": 12,
            "n_train": 1000,
            "n_val": 200,
            "n_test": 200,
            "slope_range": (0.1, 1.0),
            "intercept_range": (0.0, 1.0),
            "noise_std": 0.0,
            "seed": 0,
        }

    @property
    def l_output(self):
        return self.pred_len

    def setup(self):
        self.dataset_train = LineDataset(
            self.seq_len, self.pred_len, self.n_train,
            self.slope_range, self.intercept_range,
            self.noise_std, seed=self.seed,
        )
        self.dataset_val = LineDataset(
            self.seq_len, self.pred_len, self.n_val,
            self.slope_range, self.intercept_range,
            self.noise_std, seed=self.seed + 1,
        )
        self.dataset_test = LineDataset(
            self.seq_len, self.pred_len, self.n_test,
            self.slope_range, self.intercept_range,
            self.noise_std, seed=self.seed + 2,
        )
        # Forecast horizon attribute used by the forecasting task.
        self.dataset_train.forecast_horizon = self.pred_len
        self.dataset_val.forecast_horizon = self.pred_len
        self.dataset_test.forecast_horizon = self.pred_len

    def __str__(self):
        return f"line{self.seq_len}_{self.pred_len}"