Skip to content

Commit

Permalink
Softmax override so that base models can be used directly for classif…
Browse files Browse the repository at this point in the history
…ication
  • Loading branch information
dominiek committed Mar 8, 2017
1 parent 46e9d1d commit c3e0968
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 4 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ make test

## Todo

* Classification: Allow base model classification passthru
* Classification: Refactor bottleneck creation phase
* Object Detection: Clean up settings
* Object Detection: Add non-face example
Expand Down
2 changes: 1 addition & 1 deletion models/inception_v3/labels.json
Original file line number Diff line number Diff line change
Expand Up @@ -998,5 +998,5 @@
{"level": 0, "parents": [{"level": 1, "id": "/wordnet/n12992868", "expanded": true, "name": "fungus"}, {"level": 2, "id": "/wordnet/n4475", "expanded": true, "name": "organism, being"}, {"level": 3, "id": "/wordnet/n4258", "expanded": true, "name": "living thing, animate thing"}, {"level": 4, "id": "/wordnet/n3553", "expanded": true, "name": "whole, unit"}, {"level": 5, "id": "/wordnet/n2684", "expanded": true, "name": "object, physical object"}, {"level": 6, "id": "/wordnet/n1930", "expanded": true, "name": "physical entity"}, {"level": 7, "id": "/wordnet/n1740", "expanded": true, "name": "entity"}], "id": "/wordnet/n13052670", "node_id": 812, "name": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa"},
{"level": 0, "parents": [{"level": 1, "id": "/wordnet/n12992868", "expanded": true, "name": "fungus"}, {"level": 2, "id": "/wordnet/n4475", "expanded": true, "name": "organism, being"}, {"level": 3, "id": "/wordnet/n4258", "expanded": true, "name": "living thing, animate thing"}, {"level": 4, "id": "/wordnet/n3553", "expanded": true, "name": "whole, unit"}, {"level": 5, "id": "/wordnet/n2684", "expanded": true, "name": "object, physical object"}, {"level": 6, "id": "/wordnet/n1930", "expanded": true, "name": "physical entity"}, {"level": 7, "id": "/wordnet/n1740", "expanded": true, "name": "entity"}], "id": "/wordnet/n13054560", "node_id": 981, "name": "bolete"},
{"level": 0, "parents": [{"level": 1, "id": "/wordnet/n13134947", "expanded": true, "name": "fruit"}, {"level": 2, "id": "/wordnet/n11675842", "expanded": true, "name": "reproductive structure"}, {"level": 3, "id": "/wordnet/n13087625", "expanded": true, "name": "plant organ"}, {"level": 4, "id": "/wordnet/n13086908", "expanded": true, "name": "plant part, plant structure"}, {"level": 5, "id": "/wordnet/n19128", "expanded": true, "name": "natural object"}, {"level": 6, "id": "/wordnet/n3553", "expanded": true, "name": "whole, unit"}, {"level": 7, "id": "/wordnet/n2684", "expanded": true, "name": "object, physical object"}, {"level": 8, "id": "/wordnet/n1930", "expanded": true, "name": "physical entity"}, {"level": 9, "id": "/wordnet/n1740", "expanded": true, "name": "entity"}], "id": "/wordnet/n13133613", "node_id": 329, "name": "ear, spike, capitulum"},
{"level": 0, "parents": [{"level": 1, "id": "/wordnet/n15074962", "expanded": true, "name": "tissue, tissue paper"}, {"level": 2, "id": "/wordnet/n14974264", "expanded": true, "name": "paper"}, {"level": 3, "id": "/wordnet/n14580897", "expanded": true, "name": "material, stuff"}, {"level": 4, "id": "/wordnet/n19613", "expanded": true, "name": "substance"}, {"level": 5, "id": "/wordnet/n20827", "expanded": true, "name": "matter"}, {"level": 6, "id": "/wordnet/n1930", "expanded": true, "name": "physical entity"}, {"level": 7, "id": "/wordnet/n1740", "expanded": true, "name": "entity"}], "id": "/wordnet/n15075141", "node_id": 889, "name": "toilet tissue, toilet paper, bathroom tissue"},
{"level": 0, "parents": [{"level": 1, "id": "/wordnet/n15074962", "expanded": true, "name": "tissue, tissue paper"}, {"level": 2, "id": "/wordnet/n14974264", "expanded": true, "name": "paper"}, {"level": 3, "id": "/wordnet/n14580897", "expanded": true, "name": "material, stuff"}, {"level": 4, "id": "/wordnet/n19613", "expanded": true, "name": "substance"}, {"level": 5, "id": "/wordnet/n20827", "expanded": true, "name": "matter"}, {"level": 6, "id": "/wordnet/n1930", "expanded": true, "name": "physical entity"}, {"level": 7, "id": "/wordnet/n1740", "expanded": true, "name": "entity"}], "id": "/wordnet/n15075141", "node_id": 889, "name": "toilet tissue, toilet paper, bathroom tissue"}
]}
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def run(self):

setup(
name="transferflow",
version="0.1.2",
version="0.1.3",
description='Transfer learning for Tensorflow',
url='https://github.com/dominiek/transferflow',
cmdclass={'install': install},
Expand Down
5 changes: 5 additions & 0 deletions test/classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class ClassificationTest(unittest.TestCase):
def setUp(self):
pass

def test_1_run_base_model(self):
runner = Runner(base_models_dir + '/inception_v3', softmax_layer='softmax:0')
labels = runner.run(test_dir + '/fixtures/images/lake.jpg')
self.assertEqual(labels[0]['name'], 'boathouse')

def test_2_train_scene_type(self):
scaffold_dir = test_dir + '/fixtures/scaffolds/scene_type'
output_model_path = test_dir + '/fixtures/tmp/scene_type_test'
Expand Down
Binary file modified test/fixtures/.DS_Store
Binary file not shown.
5 changes: 4 additions & 1 deletion transferflow/classification/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def __init__(self, model_dir, softmax_layer='retrained_layer:0', namespace='clas
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name=namespace)
self.softmax_tensor = self.sess.graph.get_tensor_by_name(namespace + '/' + softmax_layer)
softmax_path = softmax_layer
if namespace:
softmax_path = namespace + '/' + softmax_path
self.softmax_tensor = self.sess.graph.get_tensor_by_name(softmax_path)
labels = load_labels(model_dir)
self.labels_by_node_id = {}
for label_id in labels:
Expand Down

0 comments on commit c3e0968

Please sign in to comment.