diff --git a/Julia_MNIST.ipynb b/Julia_MNIST.ipynb new file mode 100644 index 0000000..2279727 --- /dev/null +++ b/Julia_MNIST.ipynb @@ -0,0 +1,1005 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "bcde6ed1", + "metadata": {}, + "source": [ + "# Recognizing handwritten digits using a neural network" + ] + }, + { + "cell_type": "markdown", + "id": "2df55b03", + "metadata": {}, + "source": [ + "We have now reached the point where we can tackle a very interesting task: applying the knowledge we have gained with machine learning in general, and `Flux.jl` in particular, to create a neural network that can recognize handwritten digits! The data are from a data set called MNIST, which has become a classic in the machine learning world.\n", + "\n", + "[We could also try to apply the techniques to the original images of fruit instead. However, the fruit images are much larger than the MNIST images, which would make training a suitable neural network too slow.]" + ] + }, + { + "cell_type": "markdown", + "id": "df4f0499", + "metadata": {}, + "source": [ + "## Data Munging\n" + ] + }, + { + "cell_type": "markdown", + "id": "1b7e07a5", + "metadata": {}, + "source": [ + "As we know, the first difficulty with any new data set is locating it, understanding what format it is stored in, reading it in and decoding it into a useful data structure in Julia.\n", + "\n", + "The original MNIST data is available [here](http://yann.lecun.com/exdb/mnist); see also the [Wikipedia page](https://en.wikipedia.org/wiki/MNIST_database). However, the format that the data is stored in is rather obscure.\n", + "\n", + "Fortunately, various packages in Julia provide nicer interfaces to access it. We will use the one provided by `MLDatasets.jl`.\n", + "\n", + "The data are images of handwritten digits, and the corresponding labels that were determined by hand (i.e. by humans). Our job will be to get the computer to **learn** to recognize digits by learning, as usual, the function that relates the input and output data."
+ ] + }, + { + "cell_type": "markdown", + "id": "f361270a", + "metadata": {}, + "source": [ + "### Loading and examining the data" + ] + }, + { + "cell_type": "markdown", + "id": "601bbe54", + "metadata": {}, + "source": [ + "First we load the required packages:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "0c02cd6d", + "metadata": {}, + "outputs": [], + "source": [ + "using Flux, Images, MLDatasets, OneHotArrays, MLUtils" + ] + }, + { + "cell_type": "markdown", + "id": "70d39836", + "metadata": {}, + "source": [ + "Now we read the data:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dccccbd5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(features = [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; … ;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], targets = [7, 2, 1, 0, 4, 1, 4, 9, 5, 9 … 7, 8, 9, 0, 1, 2, 3, 4, 5, 6])" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xtrain, ytrain = MLDatasets.MNIST(:train)[:]\n", + "xtest, ytest = MLDatasets.MNIST(:test)[:]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7eca5c8d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(28, 28, 60000)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "size(xtrain)" + ] + }, + { + "cell_type": "markdown", + "id": "4acabc60", + "metadata": {}, + "source": [ + "Looking at this we see that there are 60000 images of size 28x28" + ] + }, + { + "cell_type": "markdown", + "id": "81c97ec7", + "metadata": {}, + "source": [ + "#### Creating the features\n", + "\n", + "We want to create a vector of feature vectors, each with the floating point values of the 784 pixels.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e223a927", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "784×10000 Matrix{Float32}:\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " ⋮ ⋮ ⋱ ⋮ \n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0 
0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 … 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xtrain = Flux.flatten(xtrain)\n", + "xtest = Flux.flatten(xtest) #add semicolon to suppress the output" + ] + }, + { + "cell_type": "markdown", + "id": "d73a2dcc", + "metadata": {}, + "source": [ + "#### Getting number of inputs and outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6e76c1c9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_inputs = size(xtrain)[1] # 784\n", + "n_outputs = length(unique(ytrain)) # 10" + ] + }, + { + "cell_type": "markdown", + "id": "c69ae367", + "metadata": {}, + "source": [ + "#### Creating the labels\n", + "\n", + "We can just use `onehotbatch` to create them:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "836e8926", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "60000-element Vector{Int64}:\n", + " 5\n", + " 0\n", + " 4\n", + " 1\n", + " 9\n", + " 2\n", + " 1\n", + " 3\n", + " 1\n", + " 4\n", + " 3\n", + " 5\n", + " 3\n", + " ⋮\n", + " 7\n", + " 8\n", + " 9\n", + " 2\n", + " 9\n", + " 5\n", + " 1\n", + " 8\n", + " 3\n", + " 5\n", + " 6\n", + " 8" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ytrain" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f915f294", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Bool[0 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 0], Bool[0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ebefbbc1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10×60000 OneHotMatrix(::Vector{UInt32}) with eltype Bool:\n", + " ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ … ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n", + " ⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅\n", + " ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n", + " ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅\n", + " ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n", + " 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ … ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ 1 ⋅ ⋅\n", + " ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅\n", + " ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅\n", + " ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ 1\n", + " ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ytrain" + ] + }, + { + "cell_type": "markdown", + "id": "fd7edccc", + "metadata": {}, + "source": [ + "#### Dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "12a3954a", + "metadata": {}, 
+ "outputs": [ + { + "data": { + "text/plain": [ + "trainloader (generic function with 1 method)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "function trainloader(size)\n", + " train_loader = DataLoader((xtrain, ytrain), batchsize = size, shuffle = true)\n", + " return train_loader\n", + "end\n" + ] + }, + { + "cell_type": "markdown", + "id": "bcab0b33", + "metadata": {}, + "source": [ + "### Setting up a neural network\n", + "\n", + "In the previous notebooks, we arranged the input data for Flux as a `Vector` of `Vector`s.\n", + "Now we will use an alternative arrangement, as a matrix, since that allows `Flux` to use matrix operations, which are more efficient.\n", + "\n", + "The column $i$ of the matrix is a vector consisting of the $i$th data point $\\mathbf{x}^{(i)}$. Similarly, the desired outputs are given as a matrix, with the $i$th column being the desired output $\\mathbf{y}^{(i)}$." + ] + }, + { + "cell_type": "markdown", + "id": "5b01b1ee", + "metadata": {}, + "source": [ + "Now we must set up a neural network. Since the data is complicated, we may expect to need several layers.\n", + "But we can start with a single layer.\n", + "\n", + "- The network will take as inputs the vectors $\\mathbf{x}^{(i)}$, so the input layer has $n$ nodes.\n", + "\n", + "- The output will be a one-hot vector encoding the digit from 1 to 9 or 0 that is desired. There are 10 possible categories, so we need an output layer of size 10.\n", + "\n", + "It is then our task as neural network designers to insert layers between these input and output layers, whose weights will be tuned during the learning process. *This is an art, not a science*! But major advances have come from finding a good structure for the network." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a37d1440", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "784" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ccb28e88", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Chain(\n", + " Dense(784 => 10), \u001b[90m# 7_850 parameters\u001b[39m\n", + " NNlib.softmax,\n", + ") " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = Chain(Dense(n_inputs, n_outputs, identity), softmax)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "516ebe48", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10×60000 Matrix{Float32}:\n", + " 0.0584831 0.0791119 0.0740139 … 0.0960662 0.113541 0.0710475\n", + " 0.0986182 0.0318864 0.110993 0.0720082 0.0755236 0.0503368\n", + " 0.150523 0.18543 0.0926961 0.058545 0.0853627 0.157223\n", + " 0.0907781 0.060158 0.131178 0.053839 0.0396145 0.073929\n", + " 0.163655 0.200239 0.102506 0.121097 0.166219 0.0668063\n", + " 0.179576 0.104816 0.107718 … 0.250003 0.114821 0.275188\n", + " 0.0431743 0.0553528 0.0644514 0.0603587 0.052873 0.0510857\n", + " 0.119632 0.121493 0.13477 0.0976854 0.125548 0.0897675\n", + " 0.0501339 0.113836 0.127531 0.070316 0.17275 0.0794764\n", + " 0.0454265 0.0476765 0.0541426 0.120081 0.0537472 0.0851392" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(xtrain)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "73d17ef0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Descent(0.1)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "L(x,y) = Flux.crossentropy(model(x), y)\n", + "ps = Flux.params(model)\n", + "η = 0.1\n", + "opt = Descent(η)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "0a300146", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DataLoader{Tuple{Matrix{Float32}, OneHotMatrix{UInt32, 10, Vector{UInt32}}}, Random._GLOBAL_RNG, Val{nothing}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Bool[0 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 0]), 60000, false, true, true, false, Val{nothing}(), Random._GLOBAL_RNG())" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainbatch = trainloader(60000)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "62300d53", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2.3590982f0" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "L(xtrain, ytrain)" + ] + }, + { + "cell_type": "markdown", + "id": "03603ed6", + "metadata": {}, + "source": [ + "### Training" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "8f5d884e", + "metadata": {}, + "outputs": [], + "source": [ + "Flux.train!(L, ps, trainbatch, opt)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "64c78591", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 1.025552 seconds (149 allocations: 382.960 MiB, 30.70% gc time)\n" + ] + } + ], + "source": [ + "@time Flux.train!(L, ps, 
trainbatch, opt)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "f8def27c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.61136353\n", + "0.49229488\n", + "0.44284844\n", + "0.414302\n", + "0.3951605\n" + ] + } + ], + "source": [ + "for i in 1:5\n", + " for i in 1:100\n", + " Flux.train!(L, ps, trainbatch, opt)\n", + " end\n", + " println(L(xtrain, ytrain))\n", + "end" + ] + }, + { + "cell_type": "markdown", + "id": "52072c06", + "metadata": {}, + "source": [ + "### Testing Phase" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "002031ad", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([5, 0, 4, 1, 9, 2, 1, 3, 1, 4 … 9, 2, 9, 5, 1, 8, 3, 5, 6, 8], [7, 2, 1, 0, 4, 1, 4, 9, 5, 9 … 7, 8, 9, 0, 1, 2, 3, 4, 5, 6])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ytrain, ytest = onecold(ytrain, 0:9), onecold(ytest, 0:9)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "805ecd13", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10-element Vector{Float32}:\n", + " 0.9962936\n", + " 4.6798272f-7\n", + " 0.00013244376\n", + " 0.0002770524\n", + " 2.3846585f-6\n", + " 0.0027600925\n", + " 0.00018166976\n", + " 9.3040995f-5\n", + " 0.00020297637\n", + " 5.6219604f-5" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(xtrain)[11:20]" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "4f6e804b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ytrain[2]" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "71bf056b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "predictiontest (generic function with 1 method)" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predictiontrain(n) = onecold(model(xtrain), 0:9)[n]\n", + "predictiontest(n) = onecold(model(xtest), 0:9)[n]" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "d12e026c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predictiontrain(2)" + ] + }, + { + "cell_type": "markdown", + "id": "e72cd92b", + "metadata": {}, + "source": [ + "### Evaluate\n", + "Here we evaluate the training and test accuracy (this takes a while)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "13277ae5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8938" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainingaccuracy = sum(predictiontrain(i) == ytrain[i] for i in 1:60000)/60000" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "267e4b1a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8999" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "testaccuracy = sum(predictiontest(i) == ytest[i] for i in 1:10000)/10000" + ] + }, + { + "cell_type": "markdown", + "id": "bcb51a51", + "metadata": {}, + "source": [ + "## Improving the prediction\n", + "\n", + "So far we 
have used a single layer. In order to improve the prediction, we probably need to use more layers. Try adding more layers yourself and see how the performance changes." + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "4e570762", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Chain(\n", + " Dense(784 => 32, relu), \u001b[90m# 25_120 parameters\u001b[39m\n", + " Dense(32 => 10), \u001b[90m# 330 parameters\u001b[39m\n", + " NNlib.softmax,\n", + ") \u001b[90m # Total: 4 arrays, \u001b[39m25_450 parameters, 99.664 KiB." + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_hidden = 32\n", + "model2 = Chain(\n", + " Dense(n_inputs, n_hidden, relu),\n", + " Dense(n_hidden, n_outputs, identity),\n", + " softmax)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "242de6a1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Adam(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "L2(x, y) = Flux.crossentropy(model2(x), y)\n", + "ps2 = Flux.params(model2)\n", + "opt2 = Adam(0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "01f179ee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Bool[0 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 0], Bool[0 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0])" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "a8b08d4d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DataLoader{Tuple{Matrix{Float32}, OneHotMatrix{UInt32, 10, Vector{UInt32}}}, Random._GLOBAL_RNG, Val{nothing}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Bool[0 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 0]), 60000, false, true, true, false, Val{nothing}(), Random._GLOBAL_RNG())" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainbatch2 = trainloader(60000)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "f28328bf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.38160208\n", + "0.28114498\n", + "0.23610607\n", + "0.20168594\n", + "0.17544319\n" + ] + } + ], + "source": [ + "for i in 1:5\n", + " for i in 1:100\n", + " Flux.train!(L2, ps2, trainbatch2, opt2)\n", + " end\n", + " println(L2(xtrain, ytrain))\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "914b2913", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([5, 0, 4, 1, 9, 2, 1, 3, 1, 4 … 9, 2, 9, 5, 1, 8, 3, 5, 6, 8], [7, 2, 1, 0, 4, 1, 4, 9, 5, 9 … 7, 8, 9, 0, 1, 2, 3, 4, 5, 6])" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ytrain, ytest = onecold(ytrain, 0:9), onecold(ytest, 0:9)" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "bb1cf763", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "predictiontest2 (generic function with 1 method)" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predictiontrain2(n) = 
onecold(model2(xtrain), 0:9)[n]\n", + "predictiontest2(n) = onecold(model2(xtest), 0:9)[n]" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "f8ff556a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.9513333333333334" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainingaccuracy = sum(predictiontrain2(i) == ytrain[i] for i in 1:60000)/60000" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "546006b1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.9462" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "testaccuracy = sum(predictiontest2(i) == ytest[i] for i in 1:10000)/10000" + ] + }, + { + "cell_type": "markdown", + "id": "6c166639", + "metadata": {}, + "source": [ + "We can also try different model architectures, such as convolutional neural networks (CNNs)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.7.3", + "language": "julia", + "name": "julia-1.7" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/README.md b/README.md index 0a0cfdc..ea4a9ea 100644 --- a/README.md +++ b/README.md @@ -1 +1,26 @@ -# Deep-Learning-with-Flux.jl \ No newline at end of file +# Deep-Learning-with-Flux.jl + +Computer vision is a field of AI that trains computers to understand visual data. In the final lecture of the Deep Learning with Flux.jl course on JuliaAcademy, we build a neural network to classify handwritten digits from the MNIST database.
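+
+A minimal sketch of the notebook's approach: load MNIST, flatten each image into a 784-element column, one-hot encode the labels, and train a single-layer softmax classifier with cross-entropy loss. Package and function names follow the notebook; the ten-pass training loop below is only illustrative (the notebook trains for several hundred passes and then adds a hidden layer).
+
+```julia
+using Flux, MLDatasets, OneHotArrays
+
+# Load the training images and labels, then flatten each 28×28 image into a column
+xtrain, ytrain = MLDatasets.MNIST(:train)[:]
+xtrain = Flux.flatten(xtrain)        # 784×60000 Matrix{Float32}
+ytrain = onehotbatch(ytrain, 0:9)    # 10×60000 one-hot label matrix
+
+# Single-layer classifier: 784 pixel inputs -> 10 class probabilities
+model = Chain(Dense(784, 10), softmax)
+loss(x, y) = Flux.crossentropy(model(x), y)
+
+# A few full-batch gradient-descent passes (illustrative; the notebook trains far longer)
+opt = Descent(0.1)
+for epoch in 1:10
+    Flux.train!(loss, Flux.params(model), [(xtrain, ytrain)], opt)
+end
+```
+
+Adding a hidden layer (`Dense(784, 32, relu)` before the output layer) and training with `Adam` raises test accuracy from roughly 90% to about 95%, as shown at the end of the notebook.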