[ENH] xLSTMTime implementation #1709

Open
wants to merge 21 commits into base: main
5 changes: 5 additions & 0 deletions pytorch_forecasting/models/x_lstm_time/__init__.py
@@ -0,0 +1,5 @@
"""xLSTMTime implementation for forecasting."""

from pytorch_forecasting.models.x_lstm_time.x_lstm import xLSTMTime

__all__ = ["xLSTMTime"]
Empty file.
187 changes: 187 additions & 0 deletions pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py
@@ -0,0 +1,187 @@
import math

import torch
import torch.nn as nn


class mLSTMCell(nn.Module):
"""Implements the Matrix Long Short-Term Memory (mLSTM) Cell.

Implements the mLSTM algorithm as described in the paper:
(https://arxiv.org/pdf/2407.10240).

Parameters
----------
input_size : int
Size of the input feature vector.
hidden_size : int
Number of hidden units in the LSTM cell.
dropout : float, optional
Dropout rate applied to inputs and hidden states, by default 0.2.
layer_norm : bool, optional
If True, apply Layer Normalization to gates and interactions, by default True.
device : torch.device, optional
Device for computation (CPU or CUDA), by default uses GPU if available.


Attributes
----------
Wq : nn.Linear
Linear layer for computing the query vector.
Wk : nn.Linear
Linear layer for computing the key vector.
Wv : nn.Linear
Linear layer for computing the value vector.
Wi : nn.Linear
Linear layer for the input gate.
Wf : nn.Linear
Linear layer for the forget gate.
Wo : nn.Linear
Linear layer for the output gate.
dropout : nn.Dropout
Dropout regularization layer.
ln_q, ln_k, ln_v, ln_i, ln_f, ln_o : nn.LayerNorm
Optional layer normalization layers for respective computations.
device : torch.device
Device used for computation.
"""

def __init__(
self, input_size, hidden_size, dropout=0.2, layer_norm=True, device=None
):
super(mLSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.layer_norm = layer_norm

self.device = (
device
if device is not None
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
)

self.Wq = nn.Linear(input_size, hidden_size)
self.Wk = nn.Linear(input_size, hidden_size)
self.Wv = nn.Linear(input_size, hidden_size)

self.Wi = nn.Linear(input_size, hidden_size)
self.Wf = nn.Linear(input_size, hidden_size)
self.Wo = nn.Linear(input_size, hidden_size)

self.Wq.to(self.device)
self.Wk.to(self.device)
self.Wv.to(self.device)
self.Wi.to(self.device)
self.Wf.to(self.device)
self.Wo.to(self.device)

self.dropout = nn.Dropout(dropout)
self.dropout.to(self.device)

if layer_norm:
self.ln_q = nn.LayerNorm(hidden_size)
self.ln_k = nn.LayerNorm(hidden_size)
self.ln_v = nn.LayerNorm(hidden_size)
self.ln_i = nn.LayerNorm(hidden_size)
self.ln_f = nn.LayerNorm(hidden_size)
self.ln_o = nn.LayerNorm(hidden_size)

self.ln_q.to(self.device)
self.ln_k.to(self.device)
self.ln_v.to(self.device)
self.ln_i.to(self.device)
self.ln_f.to(self.device)
self.ln_o.to(self.device)

self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()

def forward(self, x, h_prev, c_prev, n_prev):
"""Compute the next hidden, cell, and normalized states in the mLSTM cell.

Parameters
----------
        x : torch.Tensor
            Input tensor of shape (batch_size, input_size).
h_prev : torch.Tensor
Previous hidden state
c_prev : torch.Tensor
Previous cell state
n_prev : torch.Tensor
Previous normalized state

Returns
-------
tuple of torch.Tensor:
h : torch.Tensor
Current hidden state
c : torch.Tensor
Current cell state
n : torch.Tensor
Current normalized state
"""

x = x.to(self.device)
h_prev = h_prev.to(self.device)
c_prev = c_prev.to(self.device)
n_prev = n_prev.to(self.device)

batch_size = x.size(0)
assert (
x.dim() == 2
), f"Input should be 2D (batch_size, input_size), got {x.dim()}D"
assert h_prev.size() == (
batch_size,
self.hidden_size,
), f"h_prev shape mismatch: {h_prev.size()}"
assert c_prev.size() == (
batch_size,
self.hidden_size,
), f"c_prev shape mismatch: {c_prev.size()}"
assert n_prev.size() == (
batch_size,
self.hidden_size,
), f"n_prev shape mismatch: {n_prev.size()}"

x = self.dropout(x)
h_prev = self.dropout(h_prev)

q = self.Wq(x)
k = self.Wk(x) / math.sqrt(self.hidden_size)
v = self.Wv(x)

if self.layer_norm:
q = self.ln_q(q)
k = self.ln_k(k)
v = self.ln_v(v)

i = self.sigmoid(self.ln_i(self.Wi(x)) if self.layer_norm else self.Wi(x))
f = self.sigmoid(self.ln_f(self.Wf(x)) if self.layer_norm else self.Wf(x))
o = self.sigmoid(self.ln_o(self.Wo(x)) if self.layer_norm else self.Wo(x))

        # per-sample outer product k v^T:
        # (batch, hidden, 1) @ (batch, 1, hidden) -> (batch, hidden, hidden)
        k_expanded = k.unsqueeze(-1)
        v_expanded = v.unsqueeze(-2)

        kv_interaction = k_expanded @ v_expanded

        # sum over the key dimension, collapsing the interaction back to
        # (batch, hidden) so the cell state c stays vector-valued
        kv_sum = kv_interaction.sum(dim=1)

c = f * c_prev + i * kv_sum
n = f * n_prev + i * k

epsilon = 1e-8
normalized_n = n / (torch.norm(n, dim=-1, keepdim=True) + epsilon)
h = o * self.tanh(c * normalized_n)

return h, c, n

def init_hidden(self, batch_size):
"""
Initialize hidden, cell, and normalization states.
"""
shape = (batch_size, self.hidden_size)
return (
torch.zeros(shape, device=self.device),
torch.zeros(shape, device=self.device),
torch.zeros(shape, device=self.device),
)
155 changes: 155 additions & 0 deletions pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py
@@ -0,0 +1,155 @@
import torch
import torch.nn as nn

from pytorch_forecasting.models.x_lstm_time.m_lstm.cell import mLSTMCell


class mLSTMLayer(nn.Module):
"""Implements a mLSTM (Matrix LSTM) layer.

This class stacks multiple mLSTM cells to form a deep recurrent layer.
It supports residual connections, layer normalization, and dropout.

Parameters
----------
input_size : int
The number of features in the input.
hidden_size : int
The number of features in the hidden state.
num_layers : int
The number of mLSTM layers to stack.
dropout : float, optional
Dropout probability applied to the inputs and intermediate layers,
by default 0.2.
layer_norm : bool, optional
Whether to use layer normalization in each mLSTM cell, by default True.
residual_conn : bool, optional
Whether to enable residual connections between layers, by default True.
    device : torch.device, optional
        Device for computation (CPU or CUDA), by default uses GPU if available.

Attributes
----------
cells : nn.ModuleList
A list containing all mLSTM cells in the layer.
dropout : nn.Dropout
Dropout layer applied between layers.

"""

def __init__(
self,
input_size,
hidden_size,
num_layers,
dropout=0.2,
layer_norm=True,
residual_conn=True,
device=None,
):
super(mLSTMLayer, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.layer_norm = layer_norm
self.residual_conn = residual_conn
self.device = device or torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)

self.dropout = nn.Dropout(dropout).to(self.device)

self.cells = nn.ModuleList(
[
mLSTMCell(
input_size if i == 0 else hidden_size,
hidden_size,
dropout,
layer_norm,
self.device,
)
for i in range(num_layers)
]
)

def init_hidden(self, batch_size):
"""
Initialize hidden, cell, and normalization states for all layers.
"""
hidden_states, cell_states, norm_states = zip(
*[self.cells[i].init_hidden(batch_size) for i in range(self.num_layers)]
)

return (
torch.stack(hidden_states).to(self.device),
torch.stack(cell_states).to(self.device),
torch.stack(norm_states).to(self.device),
)

def forward(self, x, h=None, c=None, n=None):
"""Forward pass through the mLSTM layer.

Parameters
----------
        x : torch.Tensor
            Input tensor of shape (seq_len, batch_size, input_size).
h : torch.Tensor, optional
Initial hidden states for all layers
If None, initialized to zeros, by default None.
c : torch.Tensor, optional
Initial cell states for all layers
If None, initialized to zeros, by default None.
n : torch.Tensor, optional
Initial normalized states for all layers
If None, initialized to zeros, by default None.

Returns
-------
        tuple
            output : torch.Tensor
                Output from the last layer at every time step,
                of shape (seq_len, batch_size, hidden_size).
            (h, c, n) : tuple of torch.Tensor
                Final hidden, cell, and normalized states for all layers,
                each of shape (num_layers, batch_size, hidden_size).
"""

x = x.to(self.device).transpose(0, 1)
batch_size, seq_len, _ = x.size()

if h is None or c is None or n is None:
h, c, n = self.init_hidden(batch_size)

outputs = []

        # iterate over time steps; within each step, feed each cell's output
        # to the next cell (with optional residual connections)
        for t in range(seq_len):
layer_input = x[:, t, :]
next_hidden_states = []
next_cell_states = []
next_norm_states = []

for i, cell in enumerate(self.cells):

h_i, c_i, n_i = cell(layer_input, h[i], c[i], n[i])

if self.residual_conn and i > 0:
h_i = h_i + layer_input

layer_input = h_i

next_hidden_states.append(h_i)
next_cell_states.append(c_i)
next_norm_states.append(n_i)

h = torch.stack(next_hidden_states).to(self.device)
c = torch.stack(next_cell_states).to(self.device)
n = torch.stack(next_norm_states).to(self.device)

outputs.append(h[-1])

output = torch.stack(outputs, dim=1)

output = output.transpose(0, 1)

return output, (h, c, n)
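
Similarly, a minimal sketch for mLSTMLayer, assuming the (seq_len, batch_size, input_size) input layout implied by the transpose in forward (toy sizes are illustrative only):

import torch

from pytorch_forecasting.models.x_lstm_time.m_lstm.layer import mLSTMLayer

# toy sizes, purely illustrative
seq_len, batch_size, input_size, hidden_size, num_layers = 12, 4, 8, 16, 2
layer = mLSTMLayer(input_size, hidden_size, num_layers)

x = torch.randn(seq_len, batch_size, input_size)
output, (h, c, n) = layer(x)
# output: (seq_len, batch_size, hidden_size)
# h, c, n: each (num_layers, batch_size, hidden_size)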