
Tutorial: Your first learner

Gustavo Rosa edited this page Apr 7, 2021 · 8 revisions

Learning with Learnergy is easy as pie. Every script starts with some imports, right?

import torch
import torchvision

from learnergy.models.bernoulli import RBM

The next step is to load the data, which should be a torch.utils.data.Dataset. In this case, we will load the standard MNIST dataset from the torchvision package.

# Creating training and testing dataset
train = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
test = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor())
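ToTensor converts each 28 × 28 MNIST image with 8-bit pixels into floating-point values in [0.0, 1.0]. The RBM then treats each image as a vector of n_visible = 784 inputs. The NumPy sketch below mimics that preprocessing on a random stand-in image. to_visible_vector is a hypothetical helper and is not part of torchvision or Learnergy.

```python
import numpy as np

def to_visible_vector(image: np.ndarray) -> np.ndarray:
    """Scale an 8-bit image to [0, 1] and flatten it.

    Roughly what ToTensor() followed by a reshape to n_visible amounts to
    (illustrative helper, not library code).
    """
    if image.dtype != np.uint8:
        raise ValueError("expected an 8-bit image")
    scaled = image.astype(np.float32) / 255.0  # uint8 pixels divided by 255
    return scaled.reshape(-1)                  # 28 x 28 -> 784 values

# Hypothetical stand-in for one MNIST digit
rng = np.random.default_rng(0)
image = rng.integers(0, 255, size=(28, 28), dtype=np.uint8, endpoint=True)
v = to_visible_vector(image)
print(v.shape, v.min(), v.max())
```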

Next, we can instantiate the RBM class. Several input arguments can be adjusted when creating it. Note that n_visible must match the number of pixels in each MNIST image (28 × 28 = 784).

# Creating an RBM
model = RBM(n_visible=784, n_hidden=128, steps=1, learning_rate=0.1,
            momentum=0, decay=0, temperature=1,
            use_gpu=torch.cuda.is_available())  # use the GPU only if CUDA is available
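The constructor arguments correspond to the usual Bernoulli RBM training loop. The steps argument is the number of Gibbs sampling steps k in contrastive divergence (CD-k). The learning_rate, momentum and decay arguments shape the weight update, and temperature rescales the unit activations. The NumPy sketch below is a textbook CD-k update under those assumptions. TinyBernoulliRBM is a hypothetical class, not Learnergy's source. Details such as exactly where the temperature and weight-decay terms enter may differ in the library.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class TinyBernoulliRBM:
    """Hypothetical minimal Bernoulli RBM trained with CD-k (illustration only)."""

    def __init__(self, n_visible, n_hidden, steps=1, learning_rate=0.1,
                 momentum=0.0, decay=0.0, temperature=1.0, seed=0):
        if steps < 1:
            raise ValueError("CD-k needs at least one Gibbs step")
        self.rng = np.random.default_rng(seed)
        self.W = self.rng.normal(0.0, 0.01, size=(n_visible, n_hidden))
        self.a = np.zeros(n_visible)  # visible biases
        self.b = np.zeros(n_hidden)   # hidden biases
        self.steps, self.lr = steps, learning_rate
        self.momentum, self.decay, self.T = momentum, decay, temperature
        # Velocities for the momentum update
        self.vW = np.zeros_like(self.W)
        self.va = np.zeros_like(self.a)
        self.vb = np.zeros_like(self.b)

    def hidden_probs(self, v):
        # p(h_j = 1 | v) = sigmoid((b_j + sum_i v_i W_ij) / T)
        return sigmoid((v @ self.W + self.b) / self.T)

    def visible_probs(self, h):
        # p(v_i = 1 | h) = sigmoid((a_i + sum_j W_ij h_j) / T)
        return sigmoid((h @ self.W.T + self.a) / self.T)

    def sample(self, p):
        # Draw binary states with the given probabilities
        return (self.rng.random(p.shape) < p).astype(p.dtype)

    def cd_step(self, v0):
        """One CD-k update on a batch v0 of shape (batch, n_visible); returns the batch MSE."""
        ph0 = self.hidden_probs(v0)
        h = self.sample(ph0)
        for _ in range(self.steps):  # k Gibbs steps starting from the data
            pv = self.visible_probs(h)
            vk = self.sample(pv)
            phk = self.hidden_probs(vk)
            h = self.sample(phk)
        n = v0.shape[0]
        # Positive phase minus negative phase, averaged over the batch
        dW = (v0.T @ ph0 - vk.T @ phk) / n
        da = (v0 - vk).mean(axis=0)
        db = (ph0 - phk).mean(axis=0)
        # Momentum plus L2 weight decay applied to W
        self.vW = self.momentum * self.vW + self.lr * (dW - self.decay * self.W)
        self.va = self.momentum * self.va + self.lr * da
        self.vb = self.momentum * self.vb + self.lr * db
        self.W += self.vW
        self.a += self.va
        self.b += self.vb
        # Mean squared error between the data and the final visible probabilities
        return float(np.mean((v0 - pv) ** 2))

rng = np.random.default_rng(1)
data = (rng.random((64, 784)) < 0.2).astype(np.float64)  # hypothetical binary batch
rbm = TinyBernoulliRBM(784, 128, steps=1, learning_rate=0.1)
for epoch in range(5):
    print(epoch, rbm.cd_step(data))
```

With steps=1 this reduces to the common CD-1 rule. Larger values run a longer Gibbs chain before the negative phase is measured.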

Now, a single call trains the model. It returns the training reconstruction error (mse) and pseudo-likelihood (pl).

# Training an RBM
mse, pl = model.fit(train, batch_size=128, epochs=5)
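The pl value is a pseudo-likelihood estimate. A common stochastic approximation flips one randomly chosen visible unit per sample and compares free energies. scikit-learn's BernoulliRBM.score_samples is one example of this approach. The estimate is PL ≈ n_visible · log σ(F(ṽ) − F(v)), with free energy F(v) = −aᵀv − Σ_j log(1 + exp(b_j + (vW)_j)). The NumPy sketch below implements that approximation at temperature 1. The functions free_energy and pseudo_likelihood are hypothetical, and whether Learnergy computes pl in exactly this way is an assumption.

```python
import numpy as np

def free_energy(v, W, a, b):
    # F(v) = -v.a - sum_j log(1 + exp(b_j + (v W)_j)), computed stably with logaddexp
    return -v @ a - np.logaddexp(0.0, v @ W + b).sum(axis=1)

def pseudo_likelihood(v, W, a, b, rng):
    """Stochastic log pseudo-likelihood of samples v (batch, n_visible), averaged over the batch."""
    v = np.round(v)  # work on binarized samples
    n_samples, n_visible = v.shape
    flipped = v.copy()
    idx = rng.integers(0, n_visible, size=n_samples)  # one random unit per sample
    rows = np.arange(n_samples)
    flipped[rows, idx] = 1.0 - flipped[rows, idx]
    delta = free_energy(flipped, W, a, b) - free_energy(v, W, a, b)
    # n_visible * log(sigmoid(delta)) == -n_visible * log(1 + exp(-delta))
    return float(np.mean(-n_visible * np.logaddexp(0.0, -delta)))

rng = np.random.default_rng(0)
n_visible, n_hidden = 784, 128
W = rng.normal(0.0, 0.01, size=(n_visible, n_hidden))  # hypothetical weights
a, b = np.zeros(n_visible), np.zeros(n_hidden)
batch = (rng.random((32, n_visible)) < 0.2).astype(float)
pl_value = pseudo_likelihood(batch, W, a, b, rng)
print(pl_value)
```

Because log σ(·) is always negative, the estimate is negative. Values closer to zero indicate that the model assigns higher conditional probability to the observed pixels.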

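New reconstructions, for example of the test set, can be obtained with another single call. Here, batch_size=10000 processes the entire 10,000-image MNIST test set as one batch.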

# Reconstructing test set
rec_mse, v = model.reconstruct(test, batch_size=10000)
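Conceptually, reconstruction propagates each image up to the hidden layer and back down, then measures how far the result lies from the input. The NumPy sketch below performs one such up-down pass with hypothetical weights and reports the per-pixel mean squared error. The function reconstruct here is a stand-in written for illustration. Learnergy's own normalization of the error (per pixel or per sample) may differ.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def reconstruct(v, W, a, b, rng):
    """One up-down pass v -> h -> v'; returns (per-pixel MSE, reconstructed probabilities)."""
    ph = sigmoid(v @ W + b)                          # p(h = 1 | v)
    h = (rng.random(ph.shape) < ph).astype(v.dtype)  # sample binary hidden states
    pv = sigmoid(h @ W.T + a)                        # p(v = 1 | h)
    mse = float(np.mean((v - pv) ** 2))
    return mse, pv

rng = np.random.default_rng(0)
W = rng.normal(0.0, 0.01, size=(784, 128))  # hypothetical trained weights
a, b = np.zeros(784), np.zeros(128)
test_batch = rng.random((100, 784))          # stand-in for scaled test images
mse, v_rec = reconstruct(test_batch, W, a, b, rng)
print(mse, v_rec.shape)
```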

The model can also be saved with PyTorch's torch.save function, and its training history can be inspected.

# Saving model
torch.save(model, 'model.pth')

# Checking the model's history
print(model.history)
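Printing model.history shows the raw training record. Assume history is a dictionary that maps metric names (for instance mse and pl) to lists holding one value per epoch. A small helper can then print it as a table. summarize_history is hypothetical and is not part of Learnergy. When loading a model saved with torch.save(model, 'model.pth'), note that PyTorch 2.6 and later default torch.load to weights_only=True. That setting refuses to unpickle a full model object, so pass weights_only=False for a file you trust, with learnergy importable.

```python
def summarize_history(history):
    """Print one row per epoch from a dict of per-epoch lists; returns the rows.

    Assumes every value in `history` is a list of numbers, one per epoch.
    """
    keys = sorted(history)
    n_epochs = min(len(history[k]) for k in keys) if keys else 0
    print("epoch " + " ".join(f"{k:>10}" for k in keys))
    rows = []
    for epoch in range(n_epochs):
        row = [history[k][epoch] for k in keys]
        rows.append(row)
        print(f"{epoch + 1:>5} " + " ".join(f"{x:>10.4f}" for x in row))
    return rows

# Usage with a trained model (assumes the dict-of-lists layout described above):
# summarize_history(model.history)
```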