This repository was archived by the owner on Mar 2, 2025. It is now read-only.

Add CIFAR10 dataset, dataloader and training scripts #76

Open · wants to merge 14 commits into main
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "mimage"]
path = mimage
url = https://github.com/fnands/mimage.git
72 changes: 71 additions & 1 deletion basalt/utils/dataloader.mojo
@@ -4,7 +4,7 @@ from memory import memcpy

from basalt import dtype, nelts
from basalt import Tensor, TensorShape

from basalt.utils.datasets import CIFAR10

@value
struct Batch[dtype: DType](CollectionElement):
@@ -108,3 +108,73 @@ struct DataLoader:
self._data_batch_shape,
self._label_batch_shape,
)

@value
struct CIFARDataLoader:
    var dataset: CIFAR10  # Ideally this would be the BaseDataset trait, but that currently fails with:
    # "dynamic traits not supported yet, please use a compile time generic instead of 'BaseDataset'"
    # (a compile-time generic workaround is sketched after this file's diff)
var batch_size: Int
var _current_index: Int
var _num_batches: Int
#var _data_batch_shape: TensorShape
#var _label_batch_shape: TensorShape

fn __init__(
inout self,
owned dataset: CIFAR10, #BaseDataset,
batch_size: Int,
):
self.dataset = dataset^
self.batch_size = batch_size

        # Number of batches to iterate over. NOTE: the last partial batch (remainder) is ignored for now.
# var remainder = 1 if self.data.dim(0) % self.batch_size != 0 else 0
self._current_index = 0
self._num_batches = len(self.dataset) // self.batch_size # + remainder


@always_inline
fn __len__(self) -> Int:
"""
        Returns the number of batches left in the dataset.
"""
return self._num_batches

fn __iter__(self) -> Self:
        # TODO: Starting the iterator returns (copies!) the whole dataloader, which contains the whole dataset.
        # Does this mean the whole dataset is copied every epoch?!
        # Note: CIFAR10 only stores file paths and labels here (images are read lazily in __getitem__), so the copy is relatively cheap.
return self

    fn __next__(inout self) raises -> Batch[dtype]:
        # Fetch the first item to get the image shape, then allocate the batch tensors.
        var init_data = self.dataset[self._current_index]
        var data = Tensor[dtype](
            self.batch_size,
            init_data[0].shape()[0],
            init_data[0].shape()[1],
            init_data[0].shape()[2],
        )
        var labels = Tensor[dtype](self.batch_size, 1)

        # Copy the first item into the batch.
        var offset = 0
        var imsize = init_data[0].num_elements()
        for i in range(imsize):
            data[i] = init_data[0][i]
        labels[0] = init_data[1]
        offset += imsize

        # Copy the remaining items of the batch.
        for i in range(1, self.batch_size):
            var next_item = self.dataset[self._current_index + i]
            for j in range(imsize):
                data[j + offset] = next_item[0][j]
            labels[i] = next_item[1]
            offset += imsize

        self._current_index += self.batch_size
        self._num_batches -= 1

        return Batch[dtype](
            batch_data=data,
            batch_labels=labels,
        )
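
The dynamic-trait limitation noted in CIFARDataLoader above could be worked around with a compile-time generic: parameterize the loader on the dataset type and constrain the parameter with the BaseDataset trait introduced in basalt/utils/datasets.mojo below. A minimal sketch, under those assumptions (the name GenericDataLoader is hypothetical, and the batching logic of __next__ would carry over unchanged):

@value
struct GenericDataLoader[D: BaseDataset]:
    # D is fixed at compile time, so no dynamic trait dispatch is needed.
    var dataset: D
    var batch_size: Int
    var _current_index: Int
    var _num_batches: Int

    fn __init__(inout self, owned dataset: D, batch_size: Int):
        self.dataset = dataset^
        self.batch_size = batch_size
        self._current_index = 0
        self._num_batches = len(self.dataset) // self.batch_size  # D is Sized via BaseDataset

    fn __len__(self) -> Int:
        return self._num_batches

# Hypothetical usage:
# var training_loader = GenericDataLoader[CIFAR10](dataset=train_data, batch_size=batch_size)
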
76 changes: 74 additions & 2 deletions basalt/utils/datasets.mojo
@@ -1,10 +1,12 @@
from algorithm import vectorize
from math import div
import os
from pathlib import Path
from sys.info import simdwidthof  # assumed import for simdwidthof used in CIFAR10.__getitem__ below

from basalt import dtype
from basalt import Tensor, TensorShape
from basalt.utils.tensorutils import elwise_op, tmean, tstd

from basalt.utils.tensorutils import elwise_op, tmean, tstd, transpose
import mimage

struct BostonHousing:
alias n_inputs = 13
@@ -80,6 +82,76 @@ struct MNIST:
vectorize[vecdiv, nelts](self.data.num_elements())



trait BaseDataset(Sized, Copyable, Movable):
fn __getitem__(self, idx: Int) raises -> Tuple[Tensor[dtype], Int]: ...


from tensor import TensorShape as _TensorShape


struct CIFAR10(BaseDataset):
var labels: List[Int]
var file_paths: List[String]


fn __init__(inout self, image_folder: String, label_file: String) raises:
self.labels = List[Int]()
self.file_paths = List[String]()

var label_dict = Dict[String, Int]()

with open(label_file, 'r') as f:
var label_list = f.read().split("\n")
for i in range(len(label_list)):
label_dict[label_list[i]] = i

var files = os.listdir(image_folder)
var file_dir = Path(image_folder)


for i in range(len(files)):
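            # Filenames are expected to look like <index>_<label>.<ext> (e.g. 123_cat.png);
            # the class name between the underscore and the extension is mapped to an integer via label_dict.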
self.file_paths.append(file_dir / files[i])
self.labels.append(label_dict[files[i].split("_")[1].split(".")[0]])

fn __copyinit__(inout self, other: CIFAR10):
self.labels = other.labels
self.file_paths = other.file_paths

    # The ^ transfers ownership of the fields instead of copying them, which is the intent of a move constructor.
fn __moveinit__(inout self, owned other: CIFAR10):
self.labels = other.labels^
self.file_paths = other.file_paths^

fn __len__(self) -> Int:
return len(self.file_paths)

fn __getitem__(self, idx: Int) raises -> Tuple[Tensor[dtype], Int]:
# Get image and cast to dtype
var img = mimage.imread(self.file_paths[idx])

# Transform UInt8 MojoTensor into f32 BasaltTensor
var basalt_tensor = Tensor(img.astype[dtype]())

# Transpose to channels-first
var data = transpose(basalt_tensor, TensorShape(2, 0, 1))

# Normalize data
alias nelts = simdwidthof[dtype]()

@parameter
fn vecdiv[nelts: Int](vec_index: Int):
data.store[nelts](vec_index, div(data.load[nelts](vec_index), 255.0))

vectorize[vecdiv, nelts](data.num_elements())


return Tuple(data, self.labels[idx])



fn read_file(file_path: String) raises -> String:
var s: String
with open(file_path, "r") as f:
124 changes: 124 additions & 0 deletions examples/cifar10.mojo
@@ -0,0 +1,124 @@
from time.time import now

import basalt.nn as nn
from basalt import Tensor, TensorShape
from basalt import Graph, Symbol, OP, dtype
from basalt.utils.datasets import CIFAR10
from basalt.utils.dataloader import CIFARDataLoader
from basalt.autograd.attributes import AttributeVector, Attribute


# def plot_image(data: Tensor, num: Int):
# from python.python import Python, PythonObject

# np = Python.import_module("numpy")
# plt = Python.import_module("matplotlib.pyplot")

# var pyimage: PythonObject = np.empty((28, 28), np.float64)
# for m in range(28):
# for n in range(28):
# pyimage.itemset((m, n), data[num * 28 * 28 + m * 28 + n])

# plt.imshow(pyimage)
# plt.show()


fn create_CNN(batch_size: Int) -> Graph:
var g = Graph()
var x = g.input(TensorShape(batch_size, 3, 32, 32))

var x1 = nn.Conv2d(g, x, out_channels=16, kernel_size=5, padding=2)
var x2 = nn.ReLU(g, x1)
var x3 = nn.MaxPool2d(g, x2, kernel_size=2)
var x4 = nn.Conv2d(g, x3, out_channels=32, kernel_size=5, padding=2)
var x5 = nn.ReLU(g, x4)
var x6 = nn.MaxPool2d(g, x5, kernel_size=2)
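    # Two 2x2 max-pools reduce the 32x32 input to 8x8 spatially, so the reshape below flattens x6 to 32 * 8 * 8 = 2048 features.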
var x7 = g.op(
OP.RESHAPE,
x6,
attributes=AttributeVector(
Attribute(
"shape",
TensorShape(x6.shape[0], x6.shape[1] * x6.shape[2] * x6.shape[3]),
)
),
)
var out = nn.Linear(g, x7, n_outputs=10)
g.out(out)

var y_true = g.input(TensorShape(batch_size, 10))
var loss = nn.CrossEntropyLoss(g, out, y_true)
# var loss = nn.MSELoss(g, out, y_true)
g.loss(loss)

return g^


fn main() raises:
alias num_epochs = 2
alias batch_size = 8
alias learning_rate = 1e-3

alias graph = create_CNN(batch_size)

# try: graph.render("operator")
# except: print("Could not render graph")

var model = nn.Model[graph]()
var optim = nn.optim.Adam[graph](Reference(model.parameters), lr=learning_rate)

print("Loading data ...")
var train_data: CIFAR10
try:
        # Training on the test set, as it is smaller
train_data = CIFAR10(image_folder="./examples/data/cifar/train/", label_file="./examples/data/cifar/labels.txt")
# _ = plot_image(train_data.data, 1)
except e:
print("Could not load data")
print(e)
return

var training_loader = CIFARDataLoader(dataset=train_data, batch_size=batch_size)

print("Training started/")
var start = now()

for epoch in range(num_epochs):
var num_batches: Int = 0
var epoch_loss: Float32 = 0.0
var epoch_start = now()
for batch in training_loader:
            # One-hot encode the labels: row bb gets a 1.0 at column batch.labels[bb] (flat index bb * 10 + label).
var labels_one_hot = Tensor[dtype](batch.labels.dim(0), 10)
for bb in range(batch.labels.dim(0)):
labels_one_hot[int((bb * 10 + batch.labels[bb]))] = 1.0

# Forward pass
var loss = model.forward(batch.data, labels_one_hot)

# Backward pass
optim.zero_grad()
model.backward()
optim.step()

epoch_loss += loss[0]
num_batches += 1
if num_batches % 100 == 0:
print(
"Epoch [",
epoch + 1,
"/",
num_epochs,
"],\t Step [",
num_batches,
"/",
len(train_data) // batch_size,
"],\t Loss:",
epoch_loss / num_batches,
)

print("Epoch time: ", (now() - epoch_start) / 1e9, "seconds")

print("Training finished: ", (now() - start) / 1e9, "seconds")

model.print_perf_metrics("ms", True)
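
Presumably the example is run with mojo examples/cifar10.mojo after the mimage submodule has been fetched (git submodule update --init). Based on the CIFAR10 loader above, ./examples/data/cifar/train/ is expected to contain one image file per sample, named <index>_<label>.<ext>, and ./examples/data/cifar/labels.txt to list the ten class names, one per line.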