Skip to content

Commit

Permalink
-initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Seb-Good committed May 2, 2019
1 parent 2d1e5b0 commit 11b9cfc
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 9 deletions.
13 changes: 6 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# Local imports
from mnistazure.graph import Graph
from mnistazure.network import Network
from mnistazure.config import DATA_PATH, TENSORBOARD_PATH


def main(args):
def train(args):
"""Train MNIST tensorflow model."""
# Image shape
image_shape = (28, 28, 1)
Expand All @@ -27,10 +26,8 @@ def main(args):
network = Network(height=image_shape[0], width=image_shape[1],
channels=image_shape[2], num_labels=num_labels, seed=0)

DATA_PATH = '/home/sebastiangoodfellow/Documents/Code/mnist-azure/data'

# Initialize graph
graph = Graph(network=network, save_path=TENSORBOARD_PATH, data_path=DATA_PATH, max_to_keep=args.max_to_keep)
graph = Graph(network=network, save_path=args.log_dir, data_path=args.data_dir, max_to_keep=args.max_to_keep)

with tf.Session() as sess:

Expand Down Expand Up @@ -83,9 +80,11 @@ def get_parser():
parser = ArgumentParser(description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter)

# Setup arguments
parser.add_argument("--data_dir", dest="data_dir", type=str)
parser.add_argument("--log_dir", dest="log_dir", type=str)
parser.add_argument("--batch_size", dest="batch_size", type=int, default=32)
parser.add_argument("--learning_rate", dest="learning_rate", type=float, default=1e-3)
parser.add_argument("--epochs", dest="epochs", type=int, default=10)
parser.add_argument("--epochs", dest="epochs", type=int, default=5)
parser.add_argument("--max_to_keep", dest="max_to_keep", type=int, default=1)
parser.add_argument("--seed", dest="seed", type=int, default=0)

Expand All @@ -98,4 +97,4 @@ def get_parser():
arguments = get_parser().parse_args()

# Run main function
main(args=arguments)
train(args=arguments)
72 changes: 72 additions & 0 deletions train_azureml_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
train_azureml_local.py
----------------------
By: Sebastian D. Goodfellow, Ph.D., 2019
"""

# 3rd party imports
from azureml.train.dnn import TensorFlow
from azureml.core import Workspace, Datastore, Experiment, RunConfiguration
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter


def train(args):
"""Train MNIST tensorflow model."""
# Set input parameters
script_params = {'--data_dir': r'C:\Users\sebig\Documents\Code\mnist-azure\data',
'--log_dir': './logs',
'--batch_size': 32,
'--learning_rate': 1e-3,
'--epochs': 5,
'--max_to_keep': 1,
'--seed': 0}

# create local compute target
run_local = RunConfiguration()
run_local.environment.python.user_managed_dependencies = True

# Get workspace
ws = Workspace(subscription_id=args.subscription_id, resource_group=args.resource_group,
workspace_name='mnist-azure')

print(ws.name)
print(ws.get_details())
print(ws.experiments)
print('')
for ct in ws.compute_targets:
print(ct.name, ct.type)

print(ws.get_default_compute_target(type='GPU'))

# Get data store
ds = Datastore.get(ws, datastore_name='workspacefilestore')

# Define experiment
ex = Experiment(workspace=ws, name='Experiment_2')

tf_estimator = TensorFlow(source_directory='./', compute_target='local',
entry_script='train.py', script_params=script_params)

run = ex.submit(tf_estimator)
print(run.get_details())


def get_parser():
"""Get parser object for script train.py."""
# Initialize parser
parser = ArgumentParser(description=__doc__, formatter_class=ArgumentDefaultsHelpFormatter)

# Setup arguments
parser.add_argument("--subscription_id", dest="subscription_id", type=str)
parser.add_argument("--resource_group", dest="resource_group", type=str)

return parser


if __name__ == "__main__":

# Parse arguments
arguments = get_parser().parse_args()

# Run main function
train(args=arguments)
3 changes: 1 addition & 2 deletions upload_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def main(args):
"""Upload MNIST dataset to Azure Workspace data store."""
# Get workspace
ws = Workspace(subscription_id=args.subscription_id, resource_group=args.resource_group,
workspace_name=args.workspace_name)
workspace_name='mnist-azure')

# Get data store
ds = Datastore.get(ws, datastore_name='workspacefilestore')
Expand All @@ -33,7 +33,6 @@ def get_parser():
# Setup arguments
parser.add_argument("--subscription_id", dest="subscription_id", type=str)
parser.add_argument("--resource_group", dest="resource_group", type=str)
parser.add_argument("--workspace_name", dest="workspace_name", type=str)

return parser

Expand Down

0 comments on commit 11b9cfc

Please sign in to comment.