diff --git a/mnistazure/network.py b/mnistazure/network.py index a057d07..cf80878 100644 --- a/mnistazure/network.py +++ b/mnistazure/network.py @@ -20,7 +20,7 @@ def __init__(self, height, width, channels, labels, seed=0): self.height = height self.width = width self.channels = channels - self.classes = labels + self.labels = labels self.seed = seed def inference(self, input_layer, is_training): @@ -64,6 +64,9 @@ def inference(self, input_layer, is_training): predictions = {'classes': tf.argmax(input=logits, axis=1), 'probabilities': tf.nn.softmax(logits, name="softmax_tensor")} + # Check output dimensions + assert net.shape[1] == self.labels + return logits, predictions def create_placeholders(self):