diff --git a/elman_rnn/Elman_RNN.py b/elman_rnn/Elman_RNN.py index 236cd55..bb8464b 100644 --- a/elman_rnn/Elman_RNN.py +++ b/elman_rnn/Elman_RNN.py @@ -1,13 +1,12 @@ + import torch -from torch.autograd import Variable import numpy as np +import torch.nn as nn import pylab as pl -import torch.nn.init as init - torch.manual_seed(1) -dtype = torch.FloatTensor +dtype = torch.float32 input_size, hidden_size, output_size = 7, 6, 1 epochs = 200 seq_length = 20 @@ -17,15 +16,12 @@ data = np.sin(data_time_steps) data.resize((seq_length + 1, 1)) -x = Variable(torch.Tensor(data[:-1]).type(dtype), requires_grad=False) -y = Variable(torch.Tensor(data[1:]).type(dtype), requires_grad=False) +x = torch.tensor(data[:-1], dtype=dtype, requires_grad=False) +y = torch.tensor(data[1:], dtype=dtype, requires_grad=False) + +w1 = torch.normal(0.0, 0.4, size=(input_size, hidden_size), dtype=dtype, requires_grad=True) +w2 = torch.normal(0.0, 0.3, size=(hidden_size, output_size), dtype=dtype, requires_grad=True) -w1 = torch.FloatTensor(input_size, hidden_size).type(dtype) -init.normal(w1, 0.0, 0.4) -w1 = Variable(w1, requires_grad=True) -w2 = torch.FloatTensor(hidden_size, output_size).type(dtype) -init.normal(w2, 0.0, 0.3) -w2 = Variable(w2, requires_grad=True) def forward(input, context_state, w1, w2): xh = torch.cat((input, context_state), 1) @@ -33,32 +29,31 @@ def forward(input, context_state, w1, w2): out = context_state.mm(w2) return (out, context_state) +criterion = nn.MSELoss() for i in range(epochs): total_loss = 0 - context_state = Variable(torch.zeros((1, hidden_size)).type(dtype), requires_grad=True) + context_state = torch.zeros((1, hidden_size), dtype=dtype) for j in range(x.size(0)): input = x[j:(j+1)] target = y[j:(j+1)] (pred, context_state) = forward(input, context_state, w1, w2) - loss = (pred - target).pow(2).sum()/2 + loss = criterion(pred, target) total_loss += loss - loss.backward() + loss.backward(retain_graph=True) w1.data -= lr * w1.grad.data w2.data -= lr * w2.grad.data w1.grad.data.zero_() w2.grad.data.zero_() - context_state = Variable(context_state.data) if i % 10 == 0: - print("Epoch: {} loss {}".format(i, total_loss.data[0])) + print("Epoch: {} loss {}".format(i, total_loss.item())) -context_state = Variable(torch.zeros((1, hidden_size)).type(dtype), requires_grad=False) +context_state = torch.zeros((1, hidden_size), dtype=dtype) predictions = [] for i in range(x.size(0)): input = x[i:i+1] (pred, context_state) = forward(input, context_state, w1, w2) - context_state = context_state predictions.append(pred.data.numpy().ravel()[0]) @@ -66,5 +61,3 @@ def forward(input, context_state, w1, w2): pl.scatter(data_time_steps[1:], predictions, label="Predicted") pl.legend() pl.show() - -