diff --git a/pytorch_forecasting/models/x_lstm_time/__init__.py b/pytorch_forecasting/models/x_lstm_time/__init__.py new file mode 100644 index 000000000..fb5bc9892 --- /dev/null +++ b/pytorch_forecasting/models/x_lstm_time/__init__.py @@ -0,0 +1,5 @@ +"""xLSTMTime implementation for forecasting.""" + +from pytorch_forecasting.models.x_lstm_time.x_lstm import xLSTMTime + +__all__ = ["xLSTMTime"] diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/__init__.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py new file mode 100644 index 000000000..311292fd2 --- /dev/null +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/cell.py @@ -0,0 +1,187 @@ +import math + +import torch +import torch.nn as nn + + +class mLSTMCell(nn.Module): + """Implements the Matrix Long Short-Term Memory (mLSTM) Cell. + + Implements the mLSTM algorithm as described in the paper: + (https://arxiv.org/pdf/2407.10240). + + Parameters + ---------- + input_size : int + Size of the input feature vector. + hidden_size : int + Number of hidden units in the LSTM cell. + dropout : float, optional + Dropout rate applied to inputs and hidden states, by default 0.2. + layer_norm : bool, optional + If True, apply Layer Normalization to gates and interactions, by default True. + device : torch.device, optional + Device for computation (CPU or CUDA), by default uses GPU if available. + + + Attributes + ---------- + Wq : nn.Linear + Linear layer for computing the query vector. + Wk : nn.Linear + Linear layer for computing the key vector. + Wv : nn.Linear + Linear layer for computing the value vector. + Wi : nn.Linear + Linear layer for the input gate. + Wf : nn.Linear + Linear layer for the forget gate. + Wo : nn.Linear + Linear layer for the output gate. + dropout : nn.Dropout + Dropout regularization layer. + ln_q, ln_k, ln_v, ln_i, ln_f, ln_o : nn.LayerNorm + Optional layer normalization layers for respective computations. + device : torch.device + Device used for computation. 
+ """ + + def __init__( + self, input_size, hidden_size, dropout=0.2, layer_norm=True, device=None + ): + super(mLSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.layer_norm = layer_norm + + self.device = ( + device + if device is not None + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) + + self.Wq = nn.Linear(input_size, hidden_size) + self.Wk = nn.Linear(input_size, hidden_size) + self.Wv = nn.Linear(input_size, hidden_size) + + self.Wi = nn.Linear(input_size, hidden_size) + self.Wf = nn.Linear(input_size, hidden_size) + self.Wo = nn.Linear(input_size, hidden_size) + + self.Wq.to(self.device) + self.Wk.to(self.device) + self.Wv.to(self.device) + self.Wi.to(self.device) + self.Wf.to(self.device) + self.Wo.to(self.device) + + self.dropout = nn.Dropout(dropout) + self.dropout.to(self.device) + + if layer_norm: + self.ln_q = nn.LayerNorm(hidden_size) + self.ln_k = nn.LayerNorm(hidden_size) + self.ln_v = nn.LayerNorm(hidden_size) + self.ln_i = nn.LayerNorm(hidden_size) + self.ln_f = nn.LayerNorm(hidden_size) + self.ln_o = nn.LayerNorm(hidden_size) + + self.ln_q.to(self.device) + self.ln_k.to(self.device) + self.ln_v.to(self.device) + self.ln_i.to(self.device) + self.ln_f.to(self.device) + self.ln_o.to(self.device) + + self.sigmoid = nn.Sigmoid() + self.tanh = nn.Tanh() + + def forward(self, x, h_prev, c_prev, n_prev): + """Compute the next hidden, cell, and normalized states in the mLSTM cell. + + Parameters + ---------- + x : torch.Tensor + The number of features in the input. + h_prev : torch.Tensor + Previous hidden state + c_prev : torch.Tensor + Previous cell state + n_prev : torch.Tensor + Previous normalized state + + Returns + ------- + tuple of torch.Tensor: + h : torch.Tensor + Current hidden state + c : torch.Tensor + Current cell state + n : torch.Tensor + Current normalized state + """ + + x = x.to(self.device) + h_prev = h_prev.to(self.device) + c_prev = c_prev.to(self.device) + n_prev = n_prev.to(self.device) + + batch_size = x.size(0) + assert ( + x.dim() == 2 + ), f"Input should be 2D (batch_size, input_size), got {x.dim()}D" + assert h_prev.size() == ( + batch_size, + self.hidden_size, + ), f"h_prev shape mismatch: {h_prev.size()}" + assert c_prev.size() == ( + batch_size, + self.hidden_size, + ), f"c_prev shape mismatch: {c_prev.size()}" + assert n_prev.size() == ( + batch_size, + self.hidden_size, + ), f"n_prev shape mismatch: {n_prev.size()}" + + x = self.dropout(x) + h_prev = self.dropout(h_prev) + + q = self.Wq(x) + k = self.Wk(x) / math.sqrt(self.hidden_size) + v = self.Wv(x) + + if self.layer_norm: + q = self.ln_q(q) + k = self.ln_k(k) + v = self.ln_v(v) + + i = self.sigmoid(self.ln_i(self.Wi(x)) if self.layer_norm else self.Wi(x)) + f = self.sigmoid(self.ln_f(self.Wf(x)) if self.layer_norm else self.Wf(x)) + o = self.sigmoid(self.ln_o(self.Wo(x)) if self.layer_norm else self.Wo(x)) + + k_expanded = k.unsqueeze(-1) + v_expanded = v.unsqueeze(-2) + + kv_interaction = k_expanded @ v_expanded + + kv_sum = kv_interaction.sum(dim=1) + + c = f * c_prev + i * kv_sum + n = f * n_prev + i * k + + epsilon = 1e-8 + normalized_n = n / (torch.norm(n, dim=-1, keepdim=True) + epsilon) + h = o * self.tanh(c * normalized_n) + + return h, c, n + + def init_hidden(self, batch_size): + """ + Initialize hidden, cell, and normalization states. 
+ """ + shape = (batch_size, self.hidden_size) + return ( + torch.zeros(shape, device=self.device), + torch.zeros(shape, device=self.device), + torch.zeros(shape, device=self.device), + ) diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py new file mode 100644 index 000000000..728506a3e --- /dev/null +++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/layer.py @@ -0,0 +1,155 @@ +import torch +import torch.nn as nn + +from pytorch_forecasting.models.x_lstm_time.m_lstm.cell import mLSTMCell + + +class mLSTMLayer(nn.Module): + """Implements a mLSTM (Matrix LSTM) layer. + + This class stacks multiple mLSTM cells to form a deep recurrent layer. + It supports residual connections, layer normalization, and dropout. + + Parameters + ---------- + input_size : int + The number of features in the input. + hidden_size : int + The number of features in the hidden state. + num_layers : int + The number of mLSTM layers to stack. + dropout : float, optional + Dropout probability applied to the inputs and intermediate layers, + by default 0.2. + layer_norm : bool, optional + Whether to use layer normalization in each mLSTM cell, by default True. + residual_conn : bool, optional + Whether to enable residual connections between layers, by default True. + device : torch.device, optional + The device to run the computations on + + Attributes + ---------- + cells : nn.ModuleList + A list containing all mLSTM cells in the layer. + dropout : nn.Dropout + Dropout layer applied between layers. + + """ + + def __init__( + self, + input_size, + hidden_size, + num_layers, + dropout=0.2, + layer_norm=True, + residual_conn=True, + device=None, + ): + super(mLSTMLayer, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.layer_norm = layer_norm + self.residual_conn = residual_conn + self.device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + + self.dropout = nn.Dropout(dropout).to(self.device) + + self.cells = nn.ModuleList( + [ + mLSTMCell( + input_size if i == 0 else hidden_size, + hidden_size, + dropout, + layer_norm, + self.device, + ) + for i in range(num_layers) + ] + ) + + def init_hidden(self, batch_size): + """ + Initialize hidden, cell, and normalization states for all layers. + """ + hidden_states, cell_states, norm_states = zip( + *[self.cells[i].init_hidden(batch_size) for i in range(self.num_layers)] + ) + + return ( + torch.stack(hidden_states).to(self.device), + torch.stack(cell_states).to(self.device), + torch.stack(norm_states).to(self.device), + ) + + def forward(self, x, h=None, c=None, n=None): + """Forward pass through the mLSTM layer. + + Parameters + ---------- + x : torch.Tensor + The number of features in the input. + h : torch.Tensor, optional + Initial hidden states for all layers + If None, initialized to zeros, by default None. + c : torch.Tensor, optional + Initial cell states for all layers + If None, initialized to zeros, by default None. + n : torch.Tensor, optional + Initial normalized states for all layers + If None, initialized to zeros, by default None. 
+
+        Returns
+        -------
+        tuple
+            output : torch.Tensor
+                Final output tensor from the last layer.
+            (h, c, n) : tuple of torch.Tensor
+                Final hidden, cell, and normalized states for all layers:
+                - h : torch.Tensor
+                - c : torch.Tensor
+                - n : torch.Tensor
+        """
+
+        x = x.to(self.device).transpose(0, 1)
+        batch_size, seq_len, _ = x.size()
+
+        if h is None or c is None or n is None:
+            h, c, n = self.init_hidden(batch_size)
+
+        outputs = []
+
+        for t in range(seq_len):
+            layer_input = x[:, t, :]
+            next_hidden_states = []
+            next_cell_states = []
+            next_norm_states = []
+
+            for i, cell in enumerate(self.cells):
+
+                h_i, c_i, n_i = cell(layer_input, h[i], c[i], n[i])
+
+                if self.residual_conn and i > 0:
+                    h_i = h_i + layer_input
+
+                layer_input = h_i
+
+                next_hidden_states.append(h_i)
+                next_cell_states.append(c_i)
+                next_norm_states.append(n_i)
+
+            h = torch.stack(next_hidden_states).to(self.device)
+            c = torch.stack(next_cell_states).to(self.device)
+            n = torch.stack(next_norm_states).to(self.device)
+
+            outputs.append(h[-1])
+
+        output = torch.stack(outputs, dim=1)
+
+        output = output.transpose(0, 1)
+
+        return output, (h, c, n)
diff --git a/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py b/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py
new file mode 100644
index 000000000..89e46450b
--- /dev/null
+++ b/pytorch_forecasting/models/x_lstm_time/m_lstm/network.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+
+from pytorch_forecasting.models.x_lstm_time.m_lstm.layer import mLSTMLayer
+
+
+class mLSTMNetwork(nn.Module):
+    """Implements the mLSTM Network, a complete model based on stacked mLSTM layers.
+
+    This network combines stacked mLSTM layers and a fully connected output layer.
+
+    Parameters
+    ----------
+    input_size : int
+        Number of features in the input.
+    hidden_size : int
+        Number of features in the hidden state of each mLSTM layer.
+    num_layers : int
+        Number of mLSTM layers to stack.
+    output_size : int
+        Number of features in the output.
+    dropout : float, optional
+        Dropout probability for the mLSTM layers, by default 0.0.
+    use_layer_norm : bool, optional
+        Whether to use layer normalization in the mLSTM layers, by default True.
+    use_residual : bool, optional
+        Whether to use residual connections in the mLSTM layers, by default True.
+    device : torch.device, optional
+        Device to run the computations on.
+
+    Attributes
+    ----------
+    mlstm_layer : mLSTMLayer
+        Stacked mLSTM layers used for processing input sequences.
+    fc : nn.Linear
+        Fully connected layer to generate final output.
+
+    """
+
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        num_layers,
+        output_size,
+        dropout=0.0,
+        use_layer_norm=True,
+        use_residual=True,
+        device=None,
+    ):
+        super(mLSTMNetwork, self).__init__()
+        self.device = device or torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu"
+        )
+
+        self.mlstm_layer = mLSTMLayer(
+            input_size,
+            hidden_size,
+            num_layers,
+            dropout,
+            use_layer_norm,
+            use_residual,
+            self.device,
+        )
+        self.fc = nn.Linear(hidden_size, output_size)
+
+    def forward(self, x, h=None, c=None, n=None):
+        """Forward pass through the mLSTM Network.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (seq_len, batch_size, input_size).
+        h : torch.Tensor, optional
+            Initial hidden states for all layers.
+            If None, initialized to zeros, by default None.
+        c : torch.Tensor, optional
+            Initial cell states for all layers.
+            If None, initialized to zeros, by default None.
+        n : torch.Tensor, optional
+            Initial normalized states for all layers.
+ If None, initialized to zeros, by default None. + + Returns + ------- + tuple + output : torch.Tensor + Final output tensor from the fully connected layer. + (h, c, n) : tuple of torch.Tensor + Final hidden, cell, and normalized states for all layers: + - h : torch.Tensor + - c : torch.Tensor + - n : torch.Tensor + """ + output, (h, c, n) = self.mlstm_layer(x, h, c, n) + + output = self.fc(output[-1]) + + return output, (h, c, n) + + def init_hidden(self, batch_size): + """Initialize hidden, cell, and normalization states.""" + return self.mlstm_layer.init_hidden(batch_size) diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/__init__.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py new file mode 100644 index 000000000..c4380356f --- /dev/null +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/cell.py @@ -0,0 +1,162 @@ +import math + +import torch +import torch.nn as nn + + +class sLSTMCell(nn.Module): + """Implements the stabilized LSTM cell + + Implements the sLSTM algorithm as described in the paper: + (https://arxiv.org/pdf/2407.10240). + + Parameters + ---------- + input_size : int + Number of input features for the cell. + hidden_size : int + Number of features in the hidden state of the cell. + dropout : float, optional + Dropout probability for the cell's input and hidden state, by default 0.0. + use_layer_norm : bool, optional + Whether to use layer normalization for the cell's internal computations, + by default True. + device : torch.device, optional + The device to run the computations on + + Attributes + ---------- + input_weights : nn.Linear + Linear layer for processing input features into gate computations. + hidden_weights : nn.Linear + Linear layer for processing hidden state features into gate computations. + ln_cell : nn.LayerNorm + Layer normalization for the cell state, applied if use_layer_norm is True. + ln_hidden : nn.LayerNorm + Layer normalization for the output hidden state, + applied if use_layer_norm is True. + ln_input : nn.LayerNorm + Layer normalization for input gates, applied if use_layer_norm is True. + ln_hidden_update : nn.LayerNorm + Layer normalization for hidden state gates, applied if use_layer_norm is True. + dropout_layer : nn.Dropout + Dropout layer applied to inputs and hidden states. + grad_clip : float + Gradient clipping threshold to improve training stability. + eps : float + Small constant for numerical stability in calculations. + tanh : nn.Tanh + Tanh activation function. + sigmoid : nn.Sigmoid + Sigmoid activation function. 
+ """ + + def __init__( + self, input_size, hidden_size, dropout=0.0, use_layer_norm=True, device=None + ): + super(sLSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.dropout = dropout + self.use_layer_norm = use_layer_norm + self.eps = 1e-6 + + self.device = ( + device + if device is not None + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) + + self.input_weights = nn.Linear(input_size, 4 * hidden_size).to(self.device) + self.hidden_weights = nn.Linear(hidden_size, 4 * hidden_size).to(self.device) + + if use_layer_norm: + self.ln_cell = nn.LayerNorm(hidden_size).to(self.device) + self.ln_hidden = nn.LayerNorm(hidden_size).to(self.device) + self.ln_input = nn.LayerNorm(4 * hidden_size).to(self.device) + self.ln_hidden_update = nn.LayerNorm(4 * hidden_size).to(self.device) + + self.dropout_layer = nn.Dropout(dropout).to(self.device) + + self.reset_parameters() + + self.grad_clip = 5.0 + + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + + self.to(self.device) + + def reset_parameters(self): + """Initialize parameters using Xavier/Glorot initialization""" + std = 1.0 / math.sqrt(self.hidden_size) + for weight in self.parameters(): + weight.data.uniform_(-std, std) + + def normalized_exp_gate(self, pre_gate): + """Compute normalized exponential gate activation""" + centered = pre_gate - torch.mean(pre_gate, dim=1, keepdim=True) + exp_val = torch.exp(torch.clamp(centered, min=-5.0, max=5.0)) + normalizer = torch.sum(exp_val, dim=1, keepdim=True) + self.eps + return exp_val / normalizer + + def forward(self, x, h_prev, c_prev): + """Forward pass with stabilized exponential gating. + + Parameters + ---------- + x : torch.Tensor + The number of features in the input. + h_prev : torch.Tensor + Previous hidden state tensor. + c_prev : torch.Tensor + Previous cell state tensor. + + Returns + ------- + h : torch.Tensor + Updated hidden state tensor. + c : torch.Tensor + Updated cell state tensor. + """ + x = x.to(self.device) + h_prev = h_prev.to(self.device) + c_prev = c_prev.to(self.device) + + x = self.dropout_layer(x) + h_prev = self.dropout_layer(h_prev) + + gates_x = self.input_weights(x) + gates_h = self.hidden_weights(h_prev) + + if self.use_layer_norm: + gates_x = self.ln_input(gates_x) + gates_h = self.ln_hidden_update(gates_h) + + gates = gates_x + gates_h + i, f, g, o = gates.chunk(4, dim=1) + + i = self.normalized_exp_gate(i) + f = self.normalized_exp_gate(f) + gate_sum = i + f + i = i / (gate_sum + self.eps) + f = f / (gate_sum + self.eps) + + c_tilde = self.tanh(g) + c = f * c_prev + i * c_tilde + if self.use_layer_norm: + c = self.ln_cell(c) + + o = self.sigmoid(o) + c_out = self.tanh(c) + if self.use_layer_norm: + c_out = self.ln_hidden(c_out) + h = o * c_out + + return h, c + + def init_hidden(self, batch_size): + return ( + torch.zeros(batch_size, self.hidden_size, device=self.device), + torch.zeros(batch_size, self.hidden_size, device=self.device), + ) diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py new file mode 100644 index 000000000..0c842966d --- /dev/null +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/layer.py @@ -0,0 +1,163 @@ +import torch +import torch.nn as nn + +from pytorch_forecasting.models.x_lstm_time.s_lstm.cell import sLSTMCell + + +class sLSTMLayer(nn.Module): + """Implements the sLSTM Layer, which consists of multiple stacked sLSTM cells. 
+
+    This layer is designed for sequence modeling tasks, supporting multiple layers
+    with optional residual connections and layer normalization.
+
+    Parameters
+    ----------
+    input_size : int
+        Number of features in the input.
+    hidden_size : int
+        Number of features in the hidden state of each sLSTM cell.
+    num_layers : int, optional
+        Number of stacked sLSTM layers, by default 1.
+    dropout : float, optional
+        Dropout probability for the input of each sLSTM cell, by default 0.0.
+    use_layer_norm : bool, optional
+        Whether to use layer normalization for each sLSTM cell, by default True.
+    use_residual : bool, optional
+        Whether to use residual connections in each sLSTM layer, by default True.
+    device : torch.device, optional
+        The device to run the computations on.
+
+    Attributes
+    ----------
+    cells : nn.ModuleList
+        List of sLSTMCell objects, one for each layer.
+    input_projection : nn.Linear or None
+        Linear layer for projecting input to match hidden state size,
+        used when residual connections are enabled.
+    layer_norm_layers : nn.ModuleList
+        List of LayerNorm layers, one for each sLSTM layer (if use_layer_norm is True).
+    """
+
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        num_layers=1,
+        dropout=0.0,
+        use_layer_norm=True,
+        use_residual=True,
+        device=None,
+    ):
+        super(sLSTMLayer, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.dropout = dropout
+        self.use_layer_norm = use_layer_norm
+        self.use_residual = use_residual
+        self.device = (
+            device
+            if device
+            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        )
+
+        self.input_projection = None
+        if self.use_residual and input_size != hidden_size:
+            self.input_projection = nn.Linear(input_size, hidden_size, bias=False).to(
+                self.device
+            )
+
+        self.cells = nn.ModuleList(
+            [
+                sLSTMCell(
+                    input_size if layer == 0 else hidden_size,
+                    hidden_size,
+                    dropout=dropout,
+                    use_layer_norm=use_layer_norm,
+                    device=self.device,
+                )
+                for layer in range(num_layers)
+            ]
+        )
+
+        if self.use_layer_norm:
+            self.layer_norm_layers = nn.ModuleList(
+                [nn.LayerNorm(hidden_size).to(self.device) for _ in range(num_layers)]
+            )
+
+    def forward(self, x, h=None, c=None):
+        """Forward pass through the sLSTM Layer.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (seq_len, batch_size, input_size).
+        h : list of torch.Tensor, optional
+            Initial hidden states for each layer.
+            If None, hidden states are initialized to zeros.
+        c : list of torch.Tensor, optional
+            Initial cell states for each layer.
+            If None, cell states are initialized to zeros.
+
+        Returns
+        -------
+        output : torch.Tensor
+            Tensor containing hidden states for each time step.
+        (h, c) : tuple of lists
+            Final hidden and cell states for each layer.
+ """ + seq_len, batch_size, _ = x.size() + + if h is None or c is None: + h, c = self.init_hidden(batch_size) + + x = x.to(self.device) + h = [hi.to(self.device) for hi in h] + c = [ci.to(self.device) for ci in c] + + outputs = [] + + for t in range(seq_len): + input_t = x[t] + layer_input = input_t + + for layer in range(self.num_layers): + h[layer], c[layer] = self.cells[layer](layer_input, h[layer], c[layer]) + + if self.use_residual: + if layer == 0 and self.input_projection is not None: + residual = self.input_projection(layer_input) + else: + residual = ( + layer_input + if (layer_input.size(-1) == self.hidden_size) + else 0 + ) + h[layer] = h[layer] + residual + + if self.use_layer_norm: + h[layer] = self.layer_norm_layers[layer](h[layer]) + + layer_input = h[layer] + + outputs.append(h[-1]) + + output = torch.stack(outputs) + + h = [hi.detach() for hi in h] + c = [ci.detach() for ci in c] + + return output, (h, c) + + def init_hidden(self, batch_size): + """Initialize hidden and cell states for each layer.""" + return ( + [ + torch.zeros(batch_size, self.hidden_size, device=self.device) + for _ in range(self.num_layers) + ], + [ + torch.zeros(batch_size, self.hidden_size, device=self.device) + for _ in range(self.num_layers) + ], + ) diff --git a/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py b/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py new file mode 100644 index 000000000..5f94023c5 --- /dev/null +++ b/pytorch_forecasting/models/x_lstm_time/s_lstm/network.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn + +from pytorch_forecasting.models.x_lstm_time.s_lstm.layer import sLSTMLayer + + +class sLSTMNetwork(nn.Module): + """Implements the Stabilized LSTM Network with multiple sLSTM layers. + + This network combines sLSTM layers with a fully connected output layer for + prediction. + + Parameters + ---------- + input_size : int + Number of features in the input. + hidden_size : int + Number of features in the hidden state of each sLSTM layer. + num_layers : int + Number of stacked sLSTM layers in the network. + output_size : int + Number of features in the output prediction. + dropout : float, optional + Dropout probability for the input of each sLSTM layer, by default 0.0. + use_layer_norm : bool, optional + Whether to use layer normalization in each sLSTM layer, by default True. + device : torch.device, optional + Device to run the computations on + + Attributes + ---------- + slstm_layer : sLSTMLayer + Stacked sLSTM layers used for processing input sequences. + fc : nn.Linear + Fully connected layer to generate the final output predictions. + """ + + def __init__( + self, + input_size, + hidden_size, + num_layers, + output_size, + dropout=0.0, + use_layer_norm=True, + device=None, + ): + super(sLSTMNetwork, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.output_size = output_size + self.dropout = dropout + self.device = ( + device + if device + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) + + self.slstm_layer = sLSTMLayer( + input_size, + hidden_size, + num_layers, + dropout, + use_layer_norm, + device=self.device, + ) + self.fc = nn.Linear(hidden_size, output_size).to(self.device) + + def forward(self, x, h=None, c=None): + """ + Forward pass through the sLSTM network. + + Parameters + ---------- + x : torch.Tensor + The number of features in the input. + h : list of torch.Tensor, optional + Initial hidden states for each layer. 
+            If None, hidden states are initialized to zeros.
+        c : list of torch.Tensor, optional
+            Initial cell states for each layer.
+            If None, cell states are initialized to zeros.
+
+        Returns
+        -------
+        output : torch.Tensor
+            Tensor containing the final output predictions.
+        (h, c) : tuple of lists
+            Final hidden and cell states for each layer.
+        """
+        output, (h, c) = self.slstm_layer(x, h, c)
+        output = self.fc(output[-1])
+        return output, (h, c)
+
+    def init_hidden(self, batch_size):
+        """Initialize hidden and cell states for the entire network."""
+        return self.slstm_layer.init_hidden(batch_size)
diff --git a/pytorch_forecasting/models/x_lstm_time/x_lstm.py b/pytorch_forecasting/models/x_lstm_time/x_lstm.py
new file mode 100644
index 000000000..eb02c7306
--- /dev/null
+++ b/pytorch_forecasting/models/x_lstm_time/x_lstm.py
@@ -0,0 +1,158 @@
+from copy import copy
+from typing import Dict, Literal, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from pytorch_forecasting.metrics import SMAPE, Metric
+from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel
+from pytorch_forecasting.models.x_lstm_time.m_lstm.network import mLSTMNetwork
+from pytorch_forecasting.models.x_lstm_time.s_lstm.network import sLSTMNetwork
+
+
+class SeriesDecomposition(nn.Module):
+    """Implements series decomposition using a moving-average filter."""
+
+    def __init__(self, kernel_size: int):
+        super(SeriesDecomposition, self).__init__()
+        self.kernel_size = kernel_size
+        self.padding = kernel_size // 2
+        self.avg_pool = nn.AvgPool1d(
+            kernel_size=kernel_size, stride=1, padding=self.padding
+        )
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Decomposes input series into trend and seasonal components.
+
+        Args:
+            x: Input tensor of shape (batch_size, seq_len, n_features)
+
+        Returns:
+            Tuple of (trend_component, seasonal_component)
+        """
+        batch_size, seq_len, n_features = x.shape
+        x_reshaped = x.reshape(batch_size * n_features, 1, seq_len)
+        trend = self.avg_pool(x_reshaped)
+        trend = trend.reshape(batch_size, seq_len, n_features)
+        seasonal = x - trend
+
+        return trend, seasonal
+
+
+class xLSTMTime(AutoRegressiveBaseModel):
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        xlstm_type: Literal["slstm", "mlstm"] = "slstm",
+        num_layers: int = 1,
+        decomposition_kernel: int = 25,
+        input_projection_size: Optional[int] = None,
+        dropout: float = 0.1,
+        loss: Metric = SMAPE(),
+        device: Optional[torch.device] = None,
+        **kwargs,
+    ):
+
+        if "target" in kwargs:
+            del kwargs["target"]
+        if "target_lags" in kwargs:
+            del kwargs["target_lags"]
+        self.save_hyperparameters()
+        super().__init__(loss=loss, **kwargs)
+
+        if xlstm_type not in ["slstm", "mlstm"]:
+            raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'")
+
+        self.xlstm_type = xlstm_type
+        self._device = device or torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu"
+        )
+        self.to(self._device)
+
+        self.decomposition = SeriesDecomposition(decomposition_kernel)
+        self.batch_norm = nn.BatchNorm1d(hidden_size)
+
+        self.input_projection_size = input_projection_size or hidden_size
+
+        self.input_linear = nn.Linear(input_size * 2, self.input_projection_size).to(
+            self.device
+        )
+
+        if xlstm_type == "mlstm":
+            self.lstm = mLSTMNetwork(
+                input_size=hidden_size,
+                hidden_size=hidden_size,
+                num_layers=num_layers,
+                output_size=hidden_size,
+                dropout=dropout,
+                device=self.device,
+            )
+        else:  # slstm
+            self.lstm = sLSTMNetwork(
+                input_size=hidden_size,
+                hidden_size=hidden_size,
+                num_layers=num_layers,
+                output_size=hidden_size,
+                dropout=dropout,
+                device=self.device,
+            )
+
+        self.output_linear = nn.Linear(hidden_size, output_size)
+        self.instance_norm = nn.InstanceNorm1d(output_size)
+
+    def forward(
+        self,
+        x: Dict[str, torch.Tensor],
+        hidden_states: Optional[
+            Union[
+                Tuple[torch.Tensor, torch.Tensor],
+                Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+            ]
+        ] = None,
+    ) -> Dict[str, torch.Tensor]:
+        encoder_cont = x["encoder_cont"]
+        batch_size, seq_len, n_features = encoder_cont.shape
+
+        trend, seasonal = self.decomposition(encoder_cont)
+
+        x = torch.cat([trend, seasonal], dim=-1)
+
+        x = self.input_linear(x)
+
+        x = x.transpose(1, 2)
+        x = self.batch_norm(x)
+        x = x.transpose(1, 2)
+
+        if hidden_states is None:
+            hidden_states = self.lstm.init_hidden(batch_size)
+
+        x = x.transpose(0, 1)
+        output, hidden_states = self.lstm(x, *hidden_states)
+
+        if isinstance(output, tuple):
+            output = output[0]
+
+        if output.dim() == 2:
+            output = output.unsqueeze(0)
+
+        output = self.output_linear(output)
+
+        output = output.transpose(1, 2)
+        output = self.instance_norm(output)
+        output = output.transpose(1, 2)
+
+        output = output[0, ..., : self.hparams.output_size]
+        return self.to_network_output(prediction=output)
+
+    @classmethod
+    def from_dataset(cls, dataset, **kwargs):
+        new_kwargs = copy(kwargs)
+        new_kwargs.update(
+            cls.deduce_default_output_parameters(dataset, kwargs, SMAPE())
+        )
+
+        return super().from_dataset(dataset, **new_kwargs)
diff --git a/tests/test_models/test_x_lstm.py b/tests/test_models/test_x_lstm.py
new file mode 100644
index 000000000..1392f7b9e
--- /dev/null
+++ b/tests/test_models/test_x_lstm.py
@@ -0,0 +1,112 @@
+import shutil
+
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks import EarlyStopping
+from lightning.pytorch.loggers import TensorBoardLogger
+import pytest
+
+from pytorch_forecasting.metrics import SMAPE
+from pytorch_forecasting.models.x_lstm_time.x_lstm import xLSTMTime
+
+
+def _integration(
+    dataloaders_fixed_window_without_covariates, tmp_path, xlstm_type="slstm", **kwargs
+):
+
+    train_dataloader = dataloaders_fixed_window_without_covariates["train"]
+    val_dataloader = dataloaders_fixed_window_without_covariates["val"]
+    test_dataloader = dataloaders_fixed_window_without_covariates["test"]
+
+    early_stop_callback = EarlyStopping(
+        monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
+    )
+
+    logger = TensorBoardLogger(tmp_path)
+    trainer = pl.Trainer(
+        max_epochs=3,
+        gradient_clip_val=0.1,
+        callbacks=[early_stop_callback],
+        enable_checkpointing=True,
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        logger=logger,
+    )
+
+    model_kwargs = {
+        "input_size": 1,
+        "output_size": 1,
+        "hidden_size": 32,
+        "xlstm_type": xlstm_type,
+        "learning_rate": 0.01,
+        "loss": SMAPE(),
+    }
+
+    model_kwargs.update(kwargs)
+
+    net = xLSTMTime.from_dataset(train_dataloader.dataset, **model_kwargs)
+
+    try:
+
+        trainer.fit(
+            net,
+            train_dataloaders=train_dataloader,
+            val_dataloaders=val_dataloader,
+        )
+
+        test_outputs = trainer.test(net, dataloaders=test_dataloader)
+        assert len(test_outputs) > 0
+
+        net = xLSTMTime.load_from_checkpoint(
+            trainer.checkpoint_callback.best_model_path
+        )
+
+        net.predict(
+            val_dataloader,
+            fast_dev_run=True,
+            return_index=True,
+            return_decoder_lengths=True,
+        )
+    finally:
+        shutil.rmtree(tmp_path, ignore_errors=True)
+
+    net.predict(
+        val_dataloader,
+        fast_dev_run=True,
+        return_index=True,
+        return_decoder_lengths=True,
+    )
+
+
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        {},
+        {"xlstm_type": "mlstm"},
+        {"num_layers": 2},
+        {"xlstm_type": "slstm", "input_projection_size": 32},
+        {
+            "xlstm_type": "mlstm",
+            "decomposition_kernel": 13,
+            "dropout": 0.2,
+        },
+    ],
+)
+def test_integration(dataloaders_fixed_window_without_covariates, tmp_path, kwargs):
+    _integration(dataloaders_fixed_window_without_covariates, tmp_path, **kwargs)
+
+
+@pytest.fixture(scope="session")
+def model(dataloaders_fixed_window_without_covariates):
+    dataset = dataloaders_fixed_window_without_covariates["train"].dataset
+    net = xLSTMTime.from_dataset(
+        dataset,
+        input_size=1,
+        hidden_size=32,
+        output_size=1,
+        xlstm_type="slstm",
+        learning_rate=0.01,
+        loss=SMAPE(),
+    )
+    return net
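
For reviewers, a minimal usage sketch of the new model follows. It is not part of the diff: it assumes a `TimeSeriesDataSet` named `training` with a single continuous target has already been built elsewhere, and it simply reuses the hyperparameters from the integration test above (`input_size=1`, `hidden_size=32`, `output_size=1`).

# Illustrative usage sketch, not part of the diff. Assumes `training` is an
# existing TimeSeriesDataSet with one continuous target.
import lightning.pytorch as pl

from pytorch_forecasting.metrics import SMAPE
from pytorch_forecasting.models.x_lstm_time import xLSTMTime

# standard pytorch_forecasting dataloader construction
train_dataloader = training.to_dataloader(train=True, batch_size=64)

# hyperparameters mirror the integration test; sizes can also be deduced
# by from_dataset when omitted
model = xLSTMTime.from_dataset(
    training,
    input_size=1,        # number of continuous encoder features
    hidden_size=32,
    output_size=1,       # size of each forecast step
    xlstm_type="slstm",  # or "mlstm"
    learning_rate=0.01,
    loss=SMAPE(),
)

trainer = pl.Trainer(max_epochs=3, gradient_clip_val=0.1)
trainer.fit(model, train_dataloaders=train_dataloader)

Prediction afterwards follows the usual pytorch_forecasting pattern, e.g. model.predict(train_dataloader).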