Skip to content

Commit

Permalink
Add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
iconix committed Aug 21, 2018
1 parent 5ac6f88 commit da0fb21
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions model/serve.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
# serve.py - launch a simple PyTorch model server with Flask

from flask import Flask, jsonify, request
import os
import torch

from pytorchtextvae import generate  # https://github.com/iconix/pytorch-text-vae

MODEL_DIR = '.'

### Load my pre-trained PyTorch model from another package

print('Loading model')
DEVICE = torch.device('cpu')  # CPU inference
# TODO: load model from Quilt
# load_model returns the VAE plus everything generate.generate() needs:
# vocab sides, training pairs, dataset object, embedding size, and RNG state.
vae, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state = generate.load_model(
    'reviews_and_metadata_5yrs_state.pt',
    'reviews_and_metadata_5yrs_stored_info.pkl',
    '.', None, DEVICE)
# Generation defaults: 1 sample, max 50 tokens, temperature 0.75, no latent-z printing.
num_sample, max_length, temp, print_z = 1, 50, 0.75, False

### Setup Flask app

app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    """Generate text sample(s) from the VAE and return them as JSON.

    Returns a JSON object with:
      - gens: the generated text sample(s), stringified
      - zs: the latent vectors used for generation, stringified
      - conditions: decoded genre conditions for the first sample
    """
    gens, zs, conditions = generate.generate(
        vae, num_sample, max_length, temp, print_z,
        input_side, output_side, pairs, dataset,
        EMBED_SIZE, random_state, DEVICE)
    return jsonify({
        'gens': str(gens),
        'zs': str(zs),
        'conditions': str(dataset.decode_genres(conditions[0])),
    })


### App error handling

@app.errorhandler(400)
def handle_bad_request(error):
Expand All @@ -34,6 +34,7 @@ def handle_internal_server(error):
response = jsonify({'error': str(error)})
return response, 500

### Run app

if __name__ == '__main__':
    # Bind to all interfaces so the server is reachable from outside a container/VM.
    app.run(host='0.0.0.0', port=4444)

0 comments on commit da0fb21

Please sign in to comment.