-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_predictions.py
36 lines (31 loc) · 1.22 KB
/
get_predictions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import gradio as gd
from transformers import BertTokenizer
from transformers import TFBertForSequenceClassification
import tensorflow as tf
import numpy as np
from os import getcwd
# Module-level singletons shared by every prediction call in this file.
# NOTE: `global` statements were dropped here; at module scope they have no
# effect because every top-level assignment is already a module global.

# The fine-tuned weights are read from `model.h5` in the current working
# directory (not relative to this file), so the process must be started there.
model_path = getcwd() + '/model.h5'

# Build the 3-label BERT classifier from the base checkpoint, then replace its
# weights with the ones stored at `model_path`.
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)
model.load_weights(model_path)

# Tokenizer matching the base checkpoint; input text is lower-cased.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
def get_result(text, model=model, tokenizer=tokenizer):
    """Classify the sentiment of a single piece of text.

    Args:
        text: The text to classify. The arg-max below is taken over the
            flattened logits, so one text per call is expected.
        model: Sequence-classification model to run. Defaults to the
            module-level ``model``.
        tokenizer: Tokenizer used to encode ``text``. Defaults to the
            module-level ``tokenizer``.

    Returns:
        ``'Positive'`` when the highest-scoring class index is 2,
        ``'Neutral'`` when it is 1, and ``'Negative'`` otherwise.
    """
    # Truncate to the tokenizer's configured maximum length so that long
    # inputs (e.g. full article bodies) do not exceed what the model accepts.
    inputs = tokenizer(text, return_tensors='tf', truncation=True)

    # No labels are supplied: only the logits are needed for inference, and
    # without labels the first element of the model output is the logits.
    outputs = model(inputs)
    logits = outputs[0]

    sentiment = int(np.argmax(np.array(logits)))
    if sentiment == 2:
        return 'Positive'
    if sentiment == 1:
        return 'Neutral'
    return 'Negative'
def predict_on_dataset(data, company, predict_on):
    """Add a BERT sentiment label to every row belonging to one company.

    Args:
        data: DataFrame with a ``'tickers'`` column and the column named by
            ``predict_on``.
        company: Value matched against ``data['tickers']`` to select rows.
        predict_on: Which column to classify, either ``'text'`` or
            ``'title'``.

    Returns:
        A copy of the rows of ``data`` whose ``'tickers'`` equals
        ``company``. For a supported ``predict_on`` the copy has an extra
        ``'BERT_sentiment'`` column holding the result of ``get_result`` for
        each row. For any other value a ``UserWarning`` is emitted and the
        copy is returned without that column.
    """
    import warnings

    # Work on a copy so the caller's frame is never modified.
    company_data = data[data['tickers'] == company].copy()

    if predict_on not in ('text', 'title'):
        # Previously this case returned silently; keep the return value but
        # make the missing column visible to the caller.
        warnings.warn(
            f"predict_on must be 'text' or 'title', got {predict_on!r}; "
            "no 'BERT_sentiment' column was added",
            UserWarning,
            stacklevel=2,
        )
        return company_data

    company_data['BERT_sentiment'] = company_data[predict_on].apply(
        get_result, args=(model, tokenizer)
    )
    return company_data