From 856b830bb320be0038289c9bc1d26c3c7c2b4524 Mon Sep 17 00:00:00 2001
From: iconix
Date: Sat, 25 Aug 2018 21:59:40 -0700
Subject: [PATCH] Initialize apis, twitter_worker; adjust model.py to model
 updates; update software list

---
 .gitignore                   |   1 +
 README.md                    |  12 ++-
 apis/index.js                |  22 +++++
 apis/package.json            |  24 +++++
 apis/sheets.js               | 166 +++++++++++++++++++++++++++++++++++
 apis/spotify.js              |  94 ++++++++++++++++++++
 model/{serve.py => model.py} |  10 +--
 workers/package.json         |  23 +++++
 workers/twitter_worker.js    | 109 +++++++++++++++++++++++
 9 files changed, 454 insertions(+), 7 deletions(-)
 create mode 100644 apis/index.js
 create mode 100644 apis/package.json
 create mode 100644 apis/sheets.js
 create mode 100644 apis/spotify.js
 rename model/{serve.py => model.py} (60%)
 create mode 100644 workers/package.json
 create mode 100644 workers/twitter_worker.js

diff --git a/.gitignore b/.gitignore
index 407f5f3..a835c8a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 __pycache__
 .vscode/
+node_modules/
diff --git a/README.md b/README.md
index 820b089..7576340 100644
--- a/README.md
+++ b/README.md
@@ -54,10 +54,18 @@ Once generations for a new proposed tweet are available, an email will be sent t
 - Sohn, K., Yan, X., Lee, H. Learning Structured Output Representation using Deep Conditional Generative Models [CVAE [paper](http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models.pdf)] -[](http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models.pdf)#vae
 - Bernardo, F., Zbyszynski, M., Fiebrink, R., Grierson, M. (2016). Interactive Machine Learning for End-User Innovation [[paper](http://research.gold.ac.uk/19767/)] - #onlinelearning
 - https://devcenter.heroku.com/articles/getting-started-with-python
+- https://blog.miguelgrinberg.com/post/the-flask-mega-tutorial-part-xviii-deployment-on-heroku
+- https://developer.spotify.com/documentation/web-api/quick-start/
 
 **Software…**
+- [PyTorch](https://pytorch.org/) for deep learning
 - [Quilt](https://quiltdata.com/) for versioning and deploying data
-- [Tweepy](https://github.com/tweepy/tweepy) or [Twython](https://github.com/ryanmcgrath/twython) for Python Twitter API access
+- [Conda](https://conda.io/docs/) and [npm](https://www.npmjs.com/) for package and environment management in Python and JavaScript
+- [Flask](http://flask.pocoo.org/) for a lightweight Python web (model) server
+- [Express.js](https://expressjs.com/) for a lightweight Node.js web (API) server
+- [Twit](https://github.com/ttezel/twit) for Node.js Twitter API access
+- [Spotify Web API Node](https://github.com/thelinmichael/spotify-web-api-node) for Node.js Spotify Web API access
+- [Node Google Spreadsheet](https://github.com/theoephraim/node-google-spreadsheet) for Node.js Google Sheets API access
 - [Heroku](https://www.heroku.com/) free tier for deployments
 
 ## Timeline
@@ -116,4 +124,4 @@ _Mentor: [Natasha Jaques](https://twitter.com/natashajaques)_
 - Assistance in debugging model training
 - Suggestions for model enhancement
 
-### _Follow my progress this summer with this blog's [#openai](https://iconix.github.io/tags/openai) tag, or on [GitHub](https://github.com/iconix/openai)._
+### _Follow my progress this summer with my blog's [#openai](https://iconix.github.io/tags/openai) tag, or on [GitHub](https://github.com/iconix/openai)._
diff --git a/apis/index.js b/apis/index.js
new file mode 100644
index 0000000..95b1fc3
--- /dev/null
+++ b/apis/index.js
@@ -0,0 +1,22 @@
+// index.js - launch API endpoints that communicate with Spotify, Google Sheets, and Twitter
+var express = require('express');
+
+var spotify = require('./spotify.js');
+var sheets = require('./sheets.js');
+
+var app = express();
+app.use(express.json());
+
+// takes in a string query, returns associated genres
+app.post('/get_genres', spotify.get_genres);
+
+// takes in a tweet num, returns tweet at that index in Google Sheets
+// TODO: replace with getting tweets from Twitter
+app.post('/get_tweet', sheets.get_tweet);
+// TODO: batching - /save_gens
+app.post('/save_gen', sheets.save_gen);
+
+
+var port_num = 8888;
+
+app.listen(port_num, () => console.log(`Listening on ${port_num}`));
diff --git a/apis/package.json b/apis/package.json
new file mode 100644
index 0000000..3ad799e
--- /dev/null
+++ b/apis/package.json
@@ -0,0 +1,24 @@
+{
+  "name": "deephypebot_apis",
+  "version": "0.0.1",
+  "description": "APIs for pulling Twitter activity, requesting Spotify attributes, and saving generated tweets to Google Sheets",
+  "main": "index.js",
+  "scripts": {
+    "test": "echo \"Error: no test specified\" && exit 1"
+  },
+  "repository": {
+    "type": "git",
+    "url": "git+https://github.com/iconix/deephypebot.git"
+  },
+  "author": "Nadja Rhodes",
+  "license": "MIT",
+  "bugs": {
+    "url": "https://github.com/iconix/deephypebot/issues"
+  },
+  "homepage": "https://github.com/iconix/deephypebot#readme",
+  "dependencies": {
+    "express": "^4.16.3",
+    "google-spreadsheet": "^2.0.5",
+    "spotify-web-api-node": "^3.1.1"
+  }
+}
diff --git a/apis/sheets.js b/apis/sheets.js
new file mode 100644
index 0000000..960a33b
--- /dev/null
+++ b/apis/sheets.js
@@ -0,0 +1,166 @@
+var GoogleSpreadsheet = require('google-spreadsheet');
+var async = require('async');
+
+// spreadsheet key is the long id in the sheets URL
+var doc = new GoogleSpreadsheet(process.env.DEEPHYPEBOT_SHEETS_ID);
+var sheet, sheet_title, col_name, num, gen; // module-level state shared by the async.series steps below
+
+var get_tweet = (req, res) => {
+    if (!req || !req.body) {
+        return res.status(400).send('Request body required');
+    }
+
+    var key = 'tweet_num';
+    var tweet_num = req.body[key];
+    if (tweet_num == undefined) {
+        return res.status(400).send(`Request JSON must contain "${key}" as a key`);
+    }
+
+    return get_tweet_from_sheet(res, tweet_num);
+}
+
+var save_gen = (req, res) => {
+    if (!req || !req.body) {
+        return res.status(400).send('Request body required');
+    }
+
+    var gen_key = 'gen';
+    var gen = req.body[gen_key];
+    if (gen == undefined) {
+        return res.status(400).send(`Request JSON must contain "${gen_key}" as a key`);
+    }
+
+    var num_key = 'tweet_num';
+    var num = req.body[num_key];
+    if (num == undefined) {
+        return res.status(400).send(`Request JSON must contain "${num_key}" as a key`);
+    }
+
+    return save_gen_to_sheet(res, num, gen);
+}
+
+// TODO: cache tweet list (although this code should all go away)
+function get_tweet_from_sheet(res, num) {
+    sheet_title = 'gens';
+    col_name = 'tweet';
+
+    async.series([
+        set_auth,
+        get_worksheet,
+        read_column
+    ], (err, results) => {
+        if (err) {
+            res.status(500).send(`error: ${err}`);
+        } else {
+            results = results.reduce((acc, val) => acc.concat(val), []).filter(Boolean);
+
+            if (num >= results.length) {
+                res.status(400).send(`tweet_num (${num}) must be less than number of tweets available (${results.length})`);
+            } else {
+                res.send(results[num]);
+            }
+        }
+    });
+}
+
+function save_gen_to_sheet(res, num_param, gen_param) {
+    sheet_title = 'gens';
+    col_name = 'commentary';
+    num = num_param;
+    gen = gen_param;
+
+    async.series([
+        set_auth,
+        get_worksheet,
+        write_column
+    ], (err, results) => {
+        if (err) {
+            res.status(500).send(`error: ${err}`);
+        } else {
+            results = results.reduce((acc, val) => acc.concat(val), []).filter(Boolean);
+
+            if (results.length === 0) {
+                res.status(500).send(`failed to save gen for tweet_num (${num})`);
+            } else {
+                res.send(results[0]);
+            }
+        }
+    });
+}
+
+var set_auth = (step) => {
+    var creds_json = {
+        client_email: process.env.DEEPHYPEBOT_SHEETS_CLIENT_EMAIL,
+        private_key: process.env.DEEPHYPEBOT_SHEETS_PRIVATE_KEY
+    }
+
+    doc.useServiceAccountAuth(creds_json, step);
+}
+
+var get_worksheet = (step) => {
+    doc.getInfo((err, info) => {
+        if (err) {
+            return step(err);
+        }
+
+        console.log(`loaded doc: '${info.title}' by ${info.author.email}`);
+
+        for (var ws of info.worksheets) {
+            if (ws.title == sheet_title) {
+                sheet = ws;
+                break;
+            }
+        }
+
+        console.log(`found sheet '${sheet.title}' ${sheet.rowCount} x ${sheet.colCount}`);
+        step(null);
+    });
+}
+
+var read_column = (step) => {
+    // google provides some query options
+    sheet.getRows({
+        offset: 1
+    }, (err, rows) => {
+        if (err) {
+            return step(err);
+        }
+
+        console.log(`read ${rows.length} rows`);
+
+        var vals = [];
+
+        rows.forEach(row => {
+            vals.push(row[col_name]);
+        });
+
+        step(null, vals);
+    });
+}
+
+var write_column = (step) => {
+    // google provides some query options
+    sheet.getRows({
+        offset: 1
+    }, (err, rows) => {
+        if (err) {
+            return step(err);
+        }
+
+        console.log(`read ${rows.length} rows`);
+
+        rows[num][col_name] = gen;
+
+        rows[num].save((err) => {
+            if (err) {
+                step(err);
+            }
+            else {
+                console.log('saved!');
+                step(null, rows[num][col_name]);
+            }
+        });
+    });
+}
+
+module.exports = { get_tweet: get_tweet, save_gen: save_gen };
diff --git a/apis/spotify.js b/apis/spotify.js
new file mode 100644
index 0000000..87f62ee
--- /dev/null
+++ b/apis/spotify.js
@@ -0,0 +1,94 @@
+var spotify_web_api = require('spotify-web-api-node');
+
+var client_id = process.env.DEEPHYPEBOT_SPOTIFY_CLIENT_ID,
+client_secret = process.env.DEEPHYPEBOT_SPOTIFY_CLIENT_SECRET;
+
+// create the api object with the credentials
+var spotify_api = new spotify_web_api({
+    clientId: client_id,
+    clientSecret: client_secret
+});
+
+var expiry_date;
+
+// search tracks whose name, album or artist contains the query
+function get_genres_by_query(query) {
+    return spotify_api.searchTracks(query).then((data) => {
+        if (data.body && data.body['tracks']['items'].length) {
+            var track = data.body['tracks']['items'][0]; // always get the first track result
+            var artist_ids = [];
+            for (var artist of track['artists']) {
+                artist_ids.push(artist['id']);
+            }
+
+            // get multiple artists
+            return spotify_api.getArtists(artist_ids).then((data) => {
+                var genres = [];
+                for (var artist of data.body['artists']) {
+                    genres.push(...artist['genres']);
+                }
+
+                // unique genre list
+                genres = [...new Set(genres)];
+
+                return Promise.resolve(genres);
+            }, (err) => {
+                return Promise.reject(err);
+            });
+        }
+    }, (err) => {
+        return Promise.reject(err);
+    });
+}
+
+var get_genres = (req, res) => {
+    if (!req || !req.body) {
+        return res.status(400).send('Request body required');
+    }
+
+    var key = 'q';
+    var q = req.body[key];
+    if (q == undefined) {
+        return res.status(400).send(`Request JSON must contain "${key}" as a key`);
+    }
+
+    var now = new Date();
+
+    if (spotify_api.getAccessToken() && expiry_date > now) {
+        get_genres_by_query(q).then((search_res) => {
+            res.send(search_res);
+        }, (err) => {
+            if (err.statusCode) { // Spotify Web API errors carry an HTTP status code
+                res.status(err.statusCode).send(err.message);
+            } else {
+                res.status(500).send(`error: ${err}`);
+            }
} + }); + } else { + // retrieve an access token + spotify_api.clientCredentialsGrant().then((data) => { + expiry_date = new Date(now.getTime() + data.body['expires_in'] * 1000); + + console.log(`The access token expires at ${expiry_date}`); + console.log(`The access token is ${data.body['access_token']}`); + + // save the access token so that it's used in future calls + spotify_api.setAccessToken(data.body['access_token']); + + get_genres_by_query(q).then((search_res) => { + res.send(search_res); + }, (err) => { + if (err instanceof WebapiError) { + res.status(err.statusCode).send(err.message); + } else { + res.status(500).send(`error: ${err}`); + } + }); + }, + (err) => { + res.status(500).send(`Something went wrong when retrieving an access token: ${err}`); + }); + } +} + +module.exports = {get_genres: get_genres}; diff --git a/model/serve.py b/model/model.py similarity index 60% rename from model/serve.py rename to model/model.py index fda1964..9cb020c 100644 --- a/model/serve.py +++ b/model/model.py @@ -1,17 +1,17 @@ -# serve.py - launch a simple PyTorch model server with Flask +# model.py - launch a simple PyTorch model server with Flask from flask import Flask, jsonify, request import torch from pytorchtextvae import generate # https://github.com/iconix/pytorch-text-vae -### Load my pre-trained PyTorch model from another package +### Load my pre-trained PyTorch model from another package (TODO: slow) print('Loading model') DEVICE = torch.device('cpu') # CPU inference # TODO: load model from Quilt -vae, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state = generate.load_model('reviews_and_metadata_5yrs_state.pt', 'reviews_and_metadata_5yrs_stored_info.pkl', '.', None, DEVICE) -num_sample, max_length, temp, print_z = 1, 50, 0.75, False +vae, input_side, output_side, pairs, dataset, Z_SIZE, random_state = generate.load_model('reviews_and_metadata_5yrs_state.pt', 'reviews_and_metadata_5yrs_stored_info.pkl', DEVICE, cache_path='.') +num_sample = 1 ### Setup Flask app @@ -19,7 +19,7 @@ @app.route('/predict', methods=['POST']) def predict(): - gens, zs, conditions = generate.generate(vae, num_sample, max_length, temp, print_z, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state, DEVICE) + gens, zs, conditions = generate.generate(vae, input_side, output_side, pairs, dataset, Z_SIZE, random_state, DEVICE, genres=request.json['genres'], num_sample=1) return jsonify({'gens': str(gens), 'zs': str(zs), 'conditions': str(dataset.decode_genres(conditions[0]))}) ### App error handling diff --git a/workers/package.json b/workers/package.json new file mode 100644 index 0000000..ac2d066 --- /dev/null +++ b/workers/package.json @@ -0,0 +1,23 @@ +{ + "name": "deephypebot_workers", + "version": "0.0.1", + "description": "Worker agents for monitoring Twitter and Google Sheets", + "main": "twitter_worker.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/iconix/deephypebot.git" + }, + "author": "Nadja Rhodes", + "license": "MIT", + "bugs": { + "url": "https://github.com/iconix/deephypebot/issues" + }, + "homepage": "https://github.com/iconix/deephypebot#readme", + "dependencies": { + "async": "^2.6.1", + "request": "^2.88.0" + } +} diff --git a/workers/twitter_worker.js b/workers/twitter_worker.js new file mode 100644 index 0000000..9be6134 --- /dev/null +++ b/workers/twitter_worker.js @@ -0,0 +1,109 @@ +var request = require('request'); +var async = require('async'); + +var gen; 
+var genres;
+var q, tweet, tweet_num;
+
+var get_tweet = (step) => {
+    // grab tweet
+    tweet_num = 2;
+
+    var get_tweet_opts = {
+        uri: 'http://localhost:8888/get_tweet', // Express API server (apis/index.js)
+        json: {
+            'tweet_num': tweet_num
+        }
+    };
+
+    request.post(get_tweet_opts, function (error, response, body) {
+        if (!error && response.statusCode == 200) {
+            tweet = body;
+
+            // parse song title + artist
+            const regex = /\"(.*)\" by ([^ http]*)|(.*)'s \"(.*)\"/gm;
+            var res = regex.exec(tweet);
+            if (res && res[1]) {
+                q = `${res[1]} ${res[2]}`;
+            } else if (res && res[3]) {
+                q = `${res[3]} ${res[4]}`;
+            } // TODO: else?? this could be better
+            step();
+        } else {
+            step(error);
+        }
+    });
+}
+
+var get_genres = (step) => {
+    var get_spotify_opts = {
+        uri: 'http://localhost:8888/get_genres',
+        json: {
+            'q': q
+        }
+    };
+
+    request.post(get_spotify_opts, function (error, response, body) {
+        if (!error && response.statusCode == 200) {
+            genres = body;
+            step();
+        } else {
+            step(error);
+        }
+    });
+}
+
+var generate = (step) => {
+    var get_gen_opts = {
+        uri: 'http://localhost:4444/predict', // Flask model server (model/model.py)
+        json: {
+            'genres': genres
+        }
+    };
+
+    request.post(get_gen_opts, function (error, response, body) {
+        if (!error && response.statusCode == 200) {
+            gen = JSON.stringify(JSON.parse(body['gens'].replace(/'/g, '"'))[0]);
+            step();
+        } else {
+            step(error);
+        }
+    });
+}
+
+var save_gen = (step) => {
+    var get_save_opts = {
+        uri: 'http://localhost:8888/save_gen',
+        json: {
+            'tweet_num': tweet_num,
+            'gen': gen
+        }
+    };
+
+    request.post(get_save_opts, function (error, response, body) {
+        if (!error && response.statusCode == 200) {
+            step();
+        } else {
+            step(error);
+        }
+    });
+}
+
+function loop(){
+    async.series([
+        get_tweet,
+        get_genres,
+        generate,
+        save_gen
+    ], (err) => {
+        if (err) {
+            console.log(`error ${err}`);
+        } else {
+            console.log(`saved ${gen}`);
+        }
+    });
+}
+
+setInterval(function(){
+    loop();
+}, 60*1000);
\ No newline at end of file
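Not part of the patch above, but handy while reviewing it: a minimal sketch for smoke-testing the two new servers by hand. It assumes dependencies have been installed with npm in apis/ and workers/, that apis/index.js is listening on port 8888 and the Flask server from model/model.py on port 4444 (the ports hard-coded in workers/twitter_worker.js), and that the DEEPHYPEBOT_* environment variables are set; the file name and query string below are placeholders. It reuses the request package already declared in workers/package.json.

// smoke_test.js - hypothetical helper, not included in the patch.
// Posts a query to the Express /get_genres endpoint, then feeds the returned
// genre list to the Flask /predict endpoint, mirroring twitter_worker.js.
var request = require('request');

var query = 'song title by artist name'; // placeholder query string

request.post({
    uri: 'http://localhost:8888/get_genres',
    json: { 'q': query }
}, function (error, response, genres) {
    if (error || response.statusCode != 200) {
        return console.log(`get_genres failed: ${error || response.statusCode}`);
    }
    console.log(`genres for '${query}': ${genres}`);

    request.post({
        uri: 'http://localhost:4444/predict',
        json: { 'genres': genres }
    }, function (error, response, body) {
        if (error || response.statusCode != 200) {
            return console.log(`predict failed: ${error || response.statusCode}`);
        }
        console.log(`gens: ${body['gens']}`);
    });
});

This is the same get_genres -> predict flow that twitter_worker.js automates on a timer, minus the Google Sheets read/write steps.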