Skip to content

Commit

Permalink
Merge pull request #1 from Seb-Good/branch1
Browse files Browse the repository at this point in the history
Branch1
  • Loading branch information
Seb-Good authored Apr 29, 2019
2 parents 462d42f + d155a8a commit f762987
Show file tree
Hide file tree
Showing 14 changed files with 951 additions and 0 deletions.
169 changes: 169 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Jupyter Notebook
.ipynb_checkpoints/*

# PyCharm
.idea/

# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# .idea/vcs.xml

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# MNIST-Azure
##### Sebastian D. Goodfellow, Ph.D.

## Description
Productionization of an MNIST TensorFlow model using the Azure Machine Learning service.
2 changes: 2 additions & 0 deletions data/images/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
!.gitignore
2 changes: 2 additions & 0 deletions data/labels/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
!.gitignore
File renamed without changes.
19 changes: 19 additions & 0 deletions mnistazure/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
config.py
By: Sebastian D. Goodfellow, Ph.D., 2019
"""

# 3rd party imports
import os

# Root working directory: two levels up from this file's resolved location
# (i.e. the repository root, since this module lives in mnistazure/).
WORKING_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

# Absolute path to the project's data directory
DATA_PATH = os.path.join(WORKING_DIR, 'data')
117 changes: 117 additions & 0 deletions mnistazure/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""
generator.py
By: Sebastian D. Goodfellow, Ph.D., 2019
"""

# 3rd party imports
import os
import json
import tensorflow as tf


class DataGenerator(object):
    """Batched MNIST image/label feed built on the TensorFlow 1.x tf.data API.

    Loads a lookup JSON from <path>/labels/labels.json shaped as
    {mode: [[file_name, label], ...]}, resolves image paths under
    <path>/images/<file_name>.jpg, and builds a tf.data pipeline that
    decodes, normalizes (to [0, 1] float32) and batches the images.
    'train' mode shuffles and repeats; any other mode repeats without
    shuffling.
    """

    def __init__(self, path, mode, shape, batch_size, prefetch_buffer=1, seed=0, num_parallel_calls=1):
        """Initialize the generator and build the dataset and iterator.

        Args:
            path (str): Root data directory containing 'images' and 'labels'.
            mode (str): Top-level key into the lookup JSON (e.g. 'train').
            shape: Target tensor shape each decoded image is reshaped to.
            batch_size (int): Number of samples per batch.
            prefetch_buffer (int): Number of batches to prefetch.
            seed (int): Stored for bookkeeping.
                NOTE(review): not currently passed to shuffle() — confirm intent.
            num_parallel_calls (int): Parallelism for the map() decode step.
        """
        # Set parameters
        self.path = path
        self.mode = mode
        self.shape = shape
        self.batch_size = batch_size
        self.prefetch_buffer = prefetch_buffer
        self.seed = seed
        self.num_parallel_calls = num_parallel_calls

        # Derived attributes (order matters: labels before num_samples)
        self.lookup_dict = self._get_lookup_dict()
        self.file_names = self._get_file_names()
        self.labels = self._get_labels()
        self.num_samples = len(self.labels)
        self.file_paths = self._get_file_paths()
        # Ceiling division so a final partial batch is counted
        self.num_batches = (self.num_samples + self.batch_size - 1) // self.batch_size
        self.current_seed = 0

        # Map functions for tf.data.Dataset.map(). Train and val share the
        # same decode logic, so both names reference the same bound method
        # (replaces two identical pass-through lambdas).
        self.import_image_train_fn = self._import_image
        self.import_image_val_fn = self._import_image

        # Build the tf.data pipeline
        self.dataset = self._get_dataset()

        # TF1-style initializable iterator; callers must run
        # iterator.initializer in a session before pulling batches.
        self.iterator = self.dataset.make_initializable_iterator()

    def _get_next_seed(self):
        """Increment the running seed counter and return its new value."""
        self.current_seed += 1
        return self.current_seed

    def _get_lookup_dict(self):
        """Load the lookup dict {mode: [[file_name, label], ...]} from labels.json."""
        # Use a context manager so the file handle is always closed
        # (the previous json.load(open(...)) leaked it).
        with open(os.path.join(self.path, 'labels', 'labels.json')) as json_file:
            return json.load(json_file)

    def _get_file_names(self):
        """Get the list of image file names for the current mode."""
        return [val[0] for val in self.lookup_dict[self.mode]]

    def _get_labels(self):
        """Get the list of labels for the current mode."""
        return [val[1] for val in self.lookup_dict[self.mode]]

    def _get_file_paths(self):
        """Convert file names to full file paths with a .jpg extension."""
        return [os.path.join(self.path, 'images', '{}.jpg'.format(file_name)) for file_name in self.file_names]

    def _import_image(self, file_path, label):
        """Import and decode one image from its file path string.

        Returns a (image, label) pair where image is a float32 tensor in
        [0, 1] reshaped to self.shape.
        """
        # Read the raw file contents as a string tensor (TF1 API)
        image_string = tf.read_file(filename=file_path)

        # Decode JPG image; channels=3 forces an RGB output
        image_decoded = tf.image.decode_jpeg(contents=image_string, channels=3)

        # Normalize RGB values between 0 and 1
        image_normalized = tf.image.convert_image_dtype(image=image_decoded, dtype=tf.float32)

        # Set tensor shape
        image = tf.reshape(tensor=image_normalized, shape=self.shape)

        return image, label

    def _get_dataset(self):
        """Build the tf.data.Dataset for the current mode.

        Training shuffles over the full sample count each epoch; other
        modes keep deterministic order. Both repeat indefinitely, batch,
        and prefetch.
        """
        if self.mode == 'train':
            return (
                tf.data.Dataset.from_tensor_slices(tensors=(tf.constant(value=self.file_paths),
                                                            tf.reshape(tensor=tf.constant(self.labels), shape=[-1])))
                .shuffle(buffer_size=self.num_samples, reshuffle_each_iteration=True)
                .map(map_func=self.import_image_train_fn, num_parallel_calls=self.num_parallel_calls)
                .repeat()
                .batch(batch_size=self.batch_size)
                .prefetch(buffer_size=self.prefetch_buffer)
            )
        else:
            return (
                tf.data.Dataset.from_tensor_slices(tensors=(tf.constant(value=self.file_paths),
                                                            tf.reshape(tensor=tf.constant(self.labels), shape=[-1])))
                .map(map_func=self.import_image_val_fn, num_parallel_calls=self.num_parallel_calls)
                .repeat()
                .batch(batch_size=self.batch_size)
                .prefetch(buffer_size=self.prefetch_buffer)
            )
Loading

0 comments on commit f762987

Please sign in to comment.