-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from Seb-Good/branch1
Branch1
- Loading branch information
Showing
14 changed files
with
951 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# Jupyter Notebook | ||
.ipynb_checkpoints/* | ||
|
||
# PyCharm | ||
.idea/ | ||
|
||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm | ||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 | ||
|
||
# User-specific stuff | ||
.idea/**/workspace.xml | ||
.idea/**/tasks.xml | ||
.idea/**/usage.statistics.xml | ||
.idea/**/dictionaries | ||
.idea/**/shelf | ||
|
||
# Sensitive or high-churn files | ||
.idea/**/dataSources/ | ||
.idea/**/dataSources.ids | ||
.idea/**/dataSources.local.xml | ||
.idea/**/sqlDataSources.xml | ||
.idea/**/dynamic.xml | ||
.idea/**/uiDesigner.xml | ||
.idea/**/dbnavigator.xml | ||
|
||
# Gradle | ||
.idea/**/gradle.xml | ||
.idea/**/libraries | ||
|
||
# Gradle and Maven with auto-import | ||
# When using Gradle or Maven with auto-import, you should exclude module files, | ||
# since they will be recreated, and may cause churn. Uncomment if using | ||
# auto-import. | ||
# .idea/modules.xml | ||
# .idea/*.iml | ||
# .idea/modules | ||
# .idea/vcs.xml | ||
|
||
# CMake | ||
cmake-build-*/ | ||
|
||
# Mongo Explorer plugin | ||
.idea/**/mongoSettings.xml | ||
|
||
# File-based project format | ||
*.iws | ||
|
||
# IntelliJ | ||
out/ | ||
|
||
# mpeltonen/sbt-idea plugin | ||
.idea_modules/ | ||
|
||
# JIRA plugin | ||
atlassian-ide-plugin.xml | ||
|
||
# Cursive Clojure plugin | ||
.idea/replstate.xml | ||
|
||
# Crashlytics plugin (for Android Studio and IntelliJ) | ||
com_crashlytics_export_strings.xml | ||
crashlytics.properties | ||
crashlytics-build.properties | ||
fabric.properties | ||
|
||
# Editor-based Rest Client | ||
.idea/httpRequests | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
env/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# dotenv | ||
.env | ||
|
||
# virtualenv | ||
.venv | ||
venv/ | ||
ENV/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# MNIST-Azure | ||
##### Sebastian D. Goodfellow, Ph.D. | ||
|
||
## Description | ||
Productionization of MNIST Tensorflow model using Azure Machine Learning Service. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
* | ||
!.gitignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
* | ||
!.gitignore |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
"""
config.py
By: Sebastian D. Goodfellow, Ph.D., 2019
"""

# 3rd party imports
import os

# Root working directory: two levels up from this file's real location
WORKING_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

# Projects path
DATA_PATH = os.path.join(WORKING_DIR, 'data')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
""" | ||
generator.py | ||
By: Sebastian D. Goodfellow, Ph.D., 2019 | ||
""" | ||
|
||
# 3rd party imports | ||
import os | ||
import json | ||
import tensorflow as tf | ||
|
||
|
||
class DataGenerator(object):
    """Stream (image, label) batches from disk as a TensorFlow Dataset.

    Expects the following layout under ``path`` (inferred from the lookup
    and file-path construction below — confirm against the data pipeline):
        labels/labels.json   -> {mode: [[file_name, label], ...], ...}
        images/<file_name>.jpg

    Parameters
    ----------
    path : str
        Root data directory containing 'labels' and 'images' sub-directories.
    mode : str
        Split key into labels.json (e.g. 'train'); 'train' enables shuffling.
    shape : sequence of int
        Shape every decoded image tensor is reshaped to.
    batch_size : int
        Number of samples per batch.
    prefetch_buffer : int, optional
        Number of batches to prefetch (default 1).
    seed : int, optional
        Base random seed (default 0); NOTE(review): stored but not currently
        passed to any random op — confirm intended use.
    num_parallel_calls : int, optional
        Parallelism of the decode/map stage (default 1).
    """

    def __init__(self, path, mode, shape, batch_size, prefetch_buffer=1, seed=0, num_parallel_calls=1):

        # Set parameters
        self.path = path
        self.mode = mode
        self.shape = shape
        self.batch_size = batch_size
        self.prefetch_buffer = prefetch_buffer
        self.seed = seed
        self.num_parallel_calls = num_parallel_calls

        # Set attributes
        self.lookup_dict = self._get_lookup_dict()
        self.file_names = self._get_file_names()
        self.labels = self._get_labels()
        self.num_samples = len(self.labels)
        self.file_paths = self._get_file_paths()
        # Ceiling division so a final partial batch is still counted
        self.num_batches = (self.num_samples + self.batch_size - 1) // self.batch_size
        self.current_seed = 0

        # Map functions for the train/val pipelines. The original wrapped
        # _import_image in two identical lambdas; binding the method
        # directly is equivalent and avoids the indirection.
        self.import_image_train_fn = self._import_image
        self.import_image_val_fn = self._import_image

        # Get dataset
        self.dataset = self._get_dataset()

        # Get iterator (TF1-style initializable iterator; kept as-is to
        # preserve the required TensorFlow version)
        self.iterator = self.dataset.make_initializable_iterator()

    def _get_next_seed(self):
        """Increment and return the current seed."""
        self.current_seed += 1
        return self.current_seed

    def _get_lookup_dict(self):
        """Load lookup dictionary {mode: [[file_name, label], ...]}.

        Uses a context manager so the file handle is closed promptly
        (the original `json.load(open(...))` leaked the handle).
        """
        with open(os.path.join(self.path, 'labels', 'labels.json')) as json_file:
            return json.load(json_file)

    def _get_file_names(self):
        """Get list of image file names for the active mode."""
        return [val[0] for val in self.lookup_dict[self.mode]]

    def _get_labels(self):
        """Get list of labels for the active mode."""
        return [val[1] for val in self.lookup_dict[self.mode]]

    def _get_file_paths(self):
        """Convert file names to full file paths with .jpg extension."""
        return [os.path.join(self.path, 'images', '{}.jpg'.format(file_name)) for file_name in self.file_names]

    def _import_image(self, file_path, label):
        """Import and decode one image file from its path string.

        Returns a (float32 image tensor reshaped to ``self.shape``, label)
        pair; RGB values are normalized to [0, 1] by convert_image_dtype.
        """
        # Read raw file contents as a string tensor
        image_string = tf.read_file(filename=file_path)

        # Decode JPG image to a 3-channel uint8 tensor
        image_decoded = tf.image.decode_jpeg(contents=image_string, channels=3)

        # Normalize RGB values between 0 and 1
        image_normalized = tf.image.convert_image_dtype(image=image_decoded, dtype=tf.float32)

        # Set tensor shape
        image = tf.reshape(tensor=image_normalized, shape=self.shape)

        return image, label

    def _get_dataset(self):
        """Build the tf.data.Dataset pipeline for the active mode.

        'train' shuffles (re-shuffled each epoch over the full sample
        count); all other modes keep deterministic order. Both pipelines
        repeat indefinitely, so consumers must bound iteration themselves.
        """
        if self.mode == 'train':
            return (
                tf.data.Dataset.from_tensor_slices(tensors=(tf.constant(value=self.file_paths),
                                                            tf.reshape(tensor=tf.constant(self.labels), shape=[-1])))
                .shuffle(buffer_size=self.num_samples, reshuffle_each_iteration=True)
                .map(map_func=self.import_image_train_fn, num_parallel_calls=self.num_parallel_calls)
                .repeat()
                .batch(batch_size=self.batch_size)
                .prefetch(buffer_size=self.prefetch_buffer)
            )
        else:
            return (
                tf.data.Dataset.from_tensor_slices(tensors=(tf.constant(value=self.file_paths),
                                                            tf.reshape(tensor=tf.constant(self.labels), shape=[-1])))
                .map(map_func=self.import_image_val_fn, num_parallel_calls=self.num_parallel_calls)
                .repeat()
                .batch(batch_size=self.batch_size)
                .prefetch(buffer_size=self.prefetch_buffer)
            )

    def _import_images(self, file_path, label):
        """Deprecated duplicate of :meth:`_import_image`.

        The original body was a byte-for-byte copy; kept only for
        backward compatibility and now delegates to the single
        implementation.
        """
        return self._import_image(file_path=file_path, label=label)
Oops, something went wrong.