Commit
Added train.py main function.
Seb-Good committed Apr 30, 2019
1 parent 936c279 commit 87bcf22
Showing 6 changed files with 122 additions and 96 deletions.
7 changes: 7 additions & 0 deletions .amlignore
@@ -0,0 +1,7 @@
.ipynb_checkpoints
azureml-logs
.azureml
.git
outputs
azureml-setup
docs
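
Note: Azure ML reads `.amlignore` much like Git reads `.gitignore` — the paths listed above (checkpoints, logs, `outputs`, `docs`, and so on) are excluded from the snapshot uploaded with each submitted run.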
18 changes: 1 addition & 17 deletions mnistazure/generator.py
@@ -68,7 +68,7 @@ def _import_image(self, file_path, label):
image_string = tf.read_file(filename=file_path)

# Decode JPG image
-image_decoded = tf.image.decode_jpeg(contents=image_string, channels=3)
+image_decoded = tf.image.decode_jpeg(contents=image_string, channels=1)

# Normalize RGB values between 0 and 1
image_normalized = tf.image.convert_image_dtype(image=image_decoded, dtype=tf.float32)
@@ -99,19 +99,3 @@ def _get_dataset(self):
.batch(batch_size=self.batch_size)
.prefetch(buffer_size=self.prefetch_buffer)
)

-def _import_images(self, file_path, label):
-    """Import and decode image files from file path strings."""
-    # Get image file name as string
-    image_string = tf.read_file(filename=file_path)
-
-    # Decode JPG image
-    image_decoded = tf.image.decode_jpeg(contents=image_string, channels=3)
-
-    # Normalize RGB values between 0 and 1
-    image_normalized = tf.image.convert_image_dtype(image=image_decoded, dtype=tf.float32)
-
-    # Set tensor shape
-    image = tf.reshape(tensor=image_normalized, shape=self.shape)
-
-    return image, label
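
The `channels` fix above is what resolves the `InvalidArgumentError` visible in the notebook diff below: decoding the MNIST JPEGs with `channels=3` yields 28 * 28 * 3 = 2352 values per image, which cannot be reshaped to the expected 28 * 28 * 1 = 784. A minimal standalone sketch of the corrected import path — the free function and its `shape` default are illustrative assumptions; the committed `_import_image` method reads the shape from `self.shape`:

import tensorflow as tf  # TF 1.x API, matching requirements.txt (tensorflow==1.8.0)

def import_image(file_path, label, shape=(28, 28, 1)):
    # Read the raw JPEG bytes from disk
    image_string = tf.read_file(filename=file_path)
    # Decode as a single-channel (grayscale) image; channels=3 would
    # produce 2352 values per image instead of the 784 the network expects
    image_decoded = tf.image.decode_jpeg(contents=image_string, channels=1)
    # Scale pixel values to [0, 1] as float32
    image_normalized = tf.image.convert_image_dtype(image=image_decoded, dtype=tf.float32)
    # Enforce the static tensor shape expected downstream
    image = tf.reshape(tensor=image_normalized, shape=shape)
    return image, label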
Empty file removed mnistazure/train.py
94 changes: 16 additions & 78 deletions notebooks/train_model.ipynb
@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -32,7 +32,7 @@
"import matplotlib.pylab as plt\n",
"\n",
"# Import local Libraries\n",
"sys.path.insert(0, r'C:\\Users\\sebig\\Documents\\Code\\mnist-azure')\n",
"sys.path.insert(0, '/home/sebastiangoodfellow/Documents/Code/mnist-azure')\n",
"from mnistazure.config import DATA_PATH, TENSORBOARD_PATH\n",
"from mnistazure.generator import DataGenerator\n",
"from mnistazure.graph import Graph\n",
@@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -58,9 +58,6 @@
"# Maximum number of checkpoints to keep\n",
"max_to_keep = 1\n",
"\n",
"# Experiment name\n",
"name = 'test'\n",
"\n",
"# Random seed\n",
"seed = 0\n",
"\n",
@@ -77,20 +77,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<PrefetchDataset shapes: ((?, 28, 28, 1), (?,)), types: (tf.float32, tf.int32)>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Initialize generator\n",
"generator = DataGenerator(path=DATA_PATH, mode='train', shape=image_shape, batch_size=32, \n",
@@ -109,31 +95,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "InvalidArgumentError",
"evalue": "Input to reshape is a tensor with 2352 values, but the requested shape has 784\n\t [[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32](convert_image, Reshape/shape)]]\n\t [[Node: next_batch/IteratorGetNext = IteratorGetNext[output_shapes=[[?,28,28,1], [?]], output_types=[DT_FLOAT, DT_INT32], _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](iterator/IteratorFromStringHandle)]]\n\t [[Node: gradients/ConvNet/flatten/Reshape_grad/Shape-1-1-VecPermuteNCHWToNHWC-LayoutOptimizer/_23 = _HostRecv[client_terminated=false, recv_device=\"/job:localhost/replica:0/task:0/device:GPU:0\", send_device=\"/job:localhost/replica:0/task:0/device:CPU:0\", send_device_incarnation=1, tensor_name=\"edge_180_g...tOptimizer\", tensor_type=DT_INT32, _device=\"/job:localhost/replica:0/task:0/device:GPU:0\"]()]]",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mInvalidArgumentError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m~\\Anaconda3\\envs\\mnist-azure\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[1;34m(self, fn, *args)\u001b[0m\n\u001b[0;32m 1321\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1322\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1323\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\envs\\mnist-azure\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[1;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[0;32m 1306\u001b[0m return self._call_tf_sessionrun(\n\u001b[1;32m-> 1307\u001b[1;33m options, feed_dict, fetch_list, target_list, run_metadata)\n\u001b[0m\u001b[0;32m 1308\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\envs\\mnist-azure\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_call_tf_sessionrun\u001b[1;34m(self, options, feed_dict, fetch_list, target_list, run_metadata)\u001b[0m\n\u001b[0;32m 1408\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moptions\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1409\u001b[1;33m run_metadata)\n\u001b[0m\u001b[0;32m 1410\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mInvalidArgumentError\u001b[0m: Input to reshape is a tensor with 2352 values, but the requested shape has 784\n\t [[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32](convert_image, Reshape/shape)]]\n\t [[Node: next_batch/IteratorGetNext = IteratorGetNext[output_shapes=[[?,28,28,1], [?]], output_types=[DT_FLOAT, DT_INT32], _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](iterator/IteratorFromStringHandle)]]\n\t [[Node: gradients/ConvNet/flatten/Reshape_grad/Shape-1-1-VecPermuteNCHWToNHWC-LayoutOptimizer/_23 = _HostRecv[client_terminated=false, recv_device=\"/job:localhost/replica:0/task:0/device:GPU:0\", send_device=\"/job:localhost/replica:0/task:0/device:CPU:0\", send_device_incarnation=1, tensor_name=\"edge_180_g...tOptimizer\", tensor_type=DT_INT32, _device=\"/job:localhost/replica:0/task:0/device:GPU:0\"]()]]",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[1;31mInvalidArgumentError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-4-da1d95ac3061>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 48\u001b[0m feed_dict={graph.batch_size: batch_size, graph.is_training: True,\n\u001b[0;32m 49\u001b[0m \u001b[0mgraph\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mlearning_rate\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 50\u001b[1;33m graph.mode_handle: handle_train})\n\u001b[0m\u001b[0;32m 51\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 52\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\envs\\mnist-azure\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 898\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 899\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[1;32m--> 900\u001b[1;33m run_metadata_ptr)\n\u001b[0m\u001b[0;32m 901\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 902\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\envs\\mnist-azure\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run\u001b[1;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 1133\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1134\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[1;32m-> 1135\u001b[1;33m feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[0;32m 1136\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1137\u001b[0m \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\envs\\mnist-azure\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_run\u001b[1;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m 1314\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1315\u001b[0m return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[1;32m-> 1316\u001b[1;33m run_metadata)\n\u001b[0m\u001b[0;32m 1317\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1318\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\envs\\mnist-azure\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[1;34m(self, fn, *args)\u001b[0m\n\u001b[0;32m 1333\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1334\u001b[0m \u001b[1;32mpass\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1335\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mtype\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnode_def\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mop\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1336\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1337\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_extend_graph\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mInvalidArgumentError\u001b[0m: Input to reshape is a tensor with 2352 values, but the requested shape has 784\n\t [[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32](convert_image, Reshape/shape)]]\n\t [[Node: next_batch/IteratorGetNext = IteratorGetNext[output_shapes=[[?,28,28,1], [?]], output_types=[DT_FLOAT, DT_INT32], _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](iterator/IteratorFromStringHandle)]]\n\t [[Node: gradients/ConvNet/flatten/Reshape_grad/Shape-1-1-VecPermuteNCHWToNHWC-LayoutOptimizer/_23 = _HostRecv[client_terminated=false, recv_device=\"/job:localhost/replica:0/task:0/device:GPU:0\", send_device=\"/job:localhost/replica:0/task:0/device:CPU:0\", send_device_incarnation=1, tensor_name=\"edge_180_g...tOptimizer\", tensor_type=DT_INT32, _device=\"/job:localhost/replica:0/task:0/device:GPU:0\"]()]]"
]
}
],
"outputs": [],
"source": [
"# Initialize network\n",
"network = Network(height=image_shape[0], width=image_shape[1], \n",
@@ -147,10 +111,10 @@
"learning_rate = 1e-3\n",
"\n",
"# Number of epochs\n",
"epochs = 100\n",
"epochs = 5\n",
"\n",
"# Batch size\n",
"batch_size = 32\n",
"batch_size = 128\n",
"\n",
"with tf.Session() as sess:\n",
" \n",
@@ -180,44 +144,18 @@
" # Loop through train dataset batches\n",
" for batch in range(steps_per_epoch):\n",
" \n",
" _, _, _, _ = sess.run(fetches=[graph.train_op, graph.update_metrics_op,\n",
" graph.train_summary_metrics_op, graph.global_step],\n",
" feed_dict={graph.batch_size: batch_size, graph.is_training: True,\n",
" graph.learning_rate: learning_rate,\n",
" graph.mode_handle: handle_train})\n",
" loss, accuracy, _, _, _, _ = sess.run(fetches=[graph.loss, graph.accuracy, graph.train_op, \n",
" graph.update_metrics_op, graph.train_summary_metrics_op, \n",
" graph.global_step],\n",
" feed_dict={graph.batch_size: batch_size, graph.is_training: True,\n",
" graph.learning_rate: learning_rate,\n",
" graph.mode_handle: handle_train})\n",
" print(loss, accuracy)\n",
"\n",
" \n",
" # Initialize the train dataset iterator at the end of each epoch\n",
" sess.run(fetches=[graph.generator_train.iterator.initializer],\n",
" feed_dict={graph.batch_size: batch_size})"
]
},
-{
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
-  {
-   "data": {
-    "text/plain": [
-     "25088"
-    ]
-   },
-   "execution_count": 5,
-   "metadata": {},
-   "output_type": "execute_result"
-  }
- ],
- "source": [
-  "28*28*32"
- ]
-},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
...
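
Both this notebook and `train.py` below select which dataset feeds the graph by passing an iterator's string handle through `graph.mode_handle`. A minimal sketch of that feedable-iterator pattern in TensorFlow 1.8 — the placeholder name and the wiring inside `Graph` are assumptions for illustration, while the shapes and dtypes match the generator's `(?, 28, 28, 1)` float32 images and `(?,)` int32 labels:

import tensorflow as tf

# Placeholder that selects which dataset iterator feeds the graph
# (assumed to be roughly what Graph exposes as mode_handle)
mode_handle = tf.placeholder(dtype=tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    string_handle=mode_handle,
    output_types=(tf.float32, tf.int32),
    output_shapes=((None, 28, 28, 1), (None,)))
images, labels = iterator.get_next()

# At run time, resolve the training iterator to a concrete handle once,
# then feed it so get_next() pulls batches from that dataset:
#   handle_train = sess.run(train_iterator.string_handle())
#   sess.run([images, labels], feed_dict={mode_handle: handle_train})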
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,4 +1,5 @@
tensorflow==1.8.0
numpy==1.16.3
matplotlib==3.0.3
-opencv==3.4.2
\ No newline at end of file
+opencv==3.4.2
+azureml-core==1.0.33.1
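
Pinning `azureml-core` alongside the existing dependencies presumably keeps the Azure ML SDK version consistent between the local environment and submitted training runs.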
96 changes: 96 additions & 0 deletions train.py
@@ -0,0 +1,96 @@
"""
train.py
--------
By: Sebastian D. Goodfellow, Ph.D., 2019
"""

# 3rd party imports
import numpy as np
import tensorflow as tf
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

# Local imports
from mnistazure.graph import Graph
from mnistazure.network import Network
from mnistazure.config import DATA_PATH, TENSORBOARD_PATH


def main(args):
    """Train the MNIST classifier and report batch loss and accuracy."""
    # Image shape
    image_shape = (28, 28, 1)

    # Number of unique labels
    num_labels = 10

    # Initialize network
    network = Network(height=image_shape[0], width=image_shape[1],
                      channels=image_shape[2], num_labels=num_labels, seed=args.seed)

    # Initialize graph
    graph = Graph(network=network, save_path=TENSORBOARD_PATH,
                  data_path=DATA_PATH, max_to_keep=args.max_to_keep)

    with tf.Session() as sess:

        # Initialize variables
        sess.run(graph.init_global)

        # Get number of training batches
        num_train_batches = graph.generator_train.num_batches.eval(
            feed_dict={graph.batch_size: args.batch_size})

        # Get number of batch steps per epoch
        steps_per_epoch = int(np.ceil(num_train_batches / 1))

        # Get mode handle for training
        handle_train = sess.run(graph.generator_train.iterator.string_handle())

        # Initialize the train dataset iterator before the first epoch
        sess.run(fetches=[graph.generator_train.iterator.initializer],
                 feed_dict={graph.batch_size: args.batch_size})

        # Loop through epochs
        for epoch in range(args.epochs):

            # Initialize metrics
            sess.run(fetches=[graph.init_metrics_op])

            # Loop through train dataset batches
            for batch in range(steps_per_epoch):
                loss, accuracy, _, _, _, _ = sess.run(
                    fetches=[graph.loss, graph.accuracy, graph.train_op, graph.update_metrics_op,
                             graph.train_summary_metrics_op, graph.global_step],
                    feed_dict={graph.batch_size: args.batch_size, graph.is_training: True,
                               graph.learning_rate: args.learning_rate, graph.mode_handle: handle_train}
                )

                if batch % 100 == 0:
                    print('Epoch: {}, Batch: {}, Loss: {}, Accuracy: {}'.format(epoch, batch, loss, accuracy))

            # Re-initialize the train dataset iterator at the end of each epoch
            sess.run(fetches=[graph.generator_train.iterator.initializer],
                     feed_dict={graph.batch_size: args.batch_size})


def get_parser():
    """Get parser object for script train.py."""
    # Initialize parser
    parser = ArgumentParser(description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter)

    # Setup arguments
    parser.add_argument("--batch_size", dest="batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", dest="learning_rate", type=float, default=1e-3)
    parser.add_argument("--epochs", dest="epochs", type=int, default=10)
    parser.add_argument("--max_to_keep", dest="max_to_keep", type=int, default=1)
    parser.add_argument("--seed", dest="seed", type=int, default=0)

    return parser


if __name__ == "__main__":

# Parse arguments
arguments = get_parser().parse_args()

# Run main function
main(args=arguments)
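
With the defaults above, the script can be invoked as, for example, `python train.py --batch_size 128 --learning_rate 1e-3 --epochs 5` to mirror the notebook's hyperparameters; any flag omitted falls back to its declared default.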
