forked from dusty-nv/jetson-voice
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nlp.py
executable file
·345 lines (269 loc) · 12.6 KB
/
nlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
#!/usr/bin/env python3
# coding: utf-8
from jetson_voice.utils import load_resource
def NLP(resource, *args, **kwargs):
    """
    Factory that automatically loads an NLP model or service.

    Parameters:
      resource -- the model/service resource to load. It is forwarded, together
                  with any extra positional/keyword arguments, to
                  jetson_voice.auto.AutoModel with domain='nlp'.

    Returns an instance of one of:
      - IntentSlotService
      - QuestionAnswerService
      - TextClassificationService
      - TokenClassificationService
    """
    # Imported lazily at call time. Presumably this avoids a circular import
    # with jetson_voice.auto -- NOTE(review): confirm.
    from jetson_voice.auto import AutoModel

    # Keyword placed after *args for readability; argument binding is identical.
    return AutoModel(resource, *args, domain='nlp', **kwargs)
def IntentSlot(resource, *args, **kwargs):
    """
    Load an NLP joint intent/slot classification service or model.

    Parameters:
      resource -- the resource to load. It is passed, along with any extra
                  positional/keyword arguments, to load_resource.

    Returns whatever load_resource produces for this resource. Implementations
    follow the call signature documented on IntentSlotService.
    """
    engine = 'jetson_voice.models.nlp.IntentSlotEngine'

    # Both supported backends resolve to the same engine class.
    backends = dict.fromkeys(('tensorrt', 'onnxruntime'), engine)

    return load_resource(resource, backends, *args, **kwargs)
class IntentSlotService:
    """
    Base class for joint intent/slot classification services.

    Concrete implementations provide __call__. This base class only keeps a
    reference to the configuration it was created with.
    """

    def __init__(self, config, *args, **kwargs):
        """
        Store the service configuration.

        Parameters:
          config -- configuration for the service, kept as self.config.

        Extra positional/keyword arguments are accepted, so subclasses can share
        this signature, but they are ignored here.
        """
        self.config = config

    def __call__(self, query):
        """
        Classify the intent of a text query and extract its slots.

        Parameters:
          query (string) -- the text query, for example:
                            'What is the weather in San Francisco tomorrow?'

        Returns a dict with the keys:
          'intent' (string)    -- the predicted intent label
          'score' (float)      -- probability of that intent, in [0,1]
          'slots' (list[dict]) -- one dict per detected slot, each with:
               'slot' (string) -- the slot label
               'text' (string) -- the slot text taken from the query
               'score' (float) -- probability of the slot, in [0,1]

        This base implementation does nothing and returns None.
        """
        return None
def QuestionAnswer(resource, *args, **kwargs):
    """
    Load an NLP question-answering service or model.

    Parameters:
      resource -- the resource to load. It is passed, along with any extra
                  positional/keyword arguments, to load_resource.

    Returns whatever load_resource produces for this resource. Implementations
    follow the call signature documented on QuestionAnswerService.
    """
    engine = 'jetson_voice.models.nlp.QuestionAnswerEngine'

    # Both supported backends resolve to the same engine class.
    backends = dict.fromkeys(('tensorrt', 'onnxruntime'), engine)

    return load_resource(resource, backends, *args, **kwargs)
class QuestionAnswerService:
    """
    Base class for question-answering services.

    Concrete implementations provide __call__. This base class only keeps a
    reference to the configuration it was created with.
    """

    def __init__(self, config, *args, **kwargs):
        """
        Store the service configuration.

        Parameters:
          config -- configuration for the service, kept as self.config.

        Extra positional/keyword arguments are accepted, so subclasses can share
        this signature, but they are ignored here.
        """
        self.config = config

    def __call__(self, query, top_k=1):
        """
        Answer a question using the given context passage.

        Parameters:
          query (dict or tuple) -- either a dict with 'question' and 'context'
                                   keys, or a (question, context) tuple.
          top_k (int) -- how many of the highest-scoring answers to return,
                         sorted by score. With the default (top_k=1) only the
                         best answer is returned. With top_k > 1 a list of
                         answers is returned.

        Returns:
          dict(s) with the keys:
            'answer' (string) -- the answer text
            'score' (float)   -- the probability, in [0,1]
            'start' (int)     -- starting character index of the answer in the context
            'end' (int)       -- ending character index of the answer in the context

          If top_k == 1, the single top-scoring dict is returned.
          If top_k > 1, a list of the top_k dicts is returned.

        This base implementation does nothing and returns None.
        """
        return None
def TextClassification(resource, *args, **kwargs):
    """
    Load an NLP text classification service or model.

    Parameters:
      resource -- the resource to load. It is passed, along with any extra
                  positional/keyword arguments, to load_resource.

    Returns whatever load_resource produces for this resource. Implementations
    follow the call signature documented on TextClassificationService.
    """
    engine = 'jetson_voice.models.nlp.TextClassificationEngine'

    # Both supported backends resolve to the same engine class.
    backends = dict.fromkeys(('tensorrt', 'onnxruntime'), engine)

    return load_resource(resource, backends, *args, **kwargs)
class TextClassificationService:
    """
    Base class for text classification services.

    Concrete implementations provide __call__. This base class only keeps a
    reference to the configuration it was created with.
    """

    def __init__(self, config, *args, **kwargs):
        """
        Store the service configuration.

        Parameters:
          config -- configuration for the service, kept as self.config.

        Extra positional/keyword arguments are accepted, so subclasses can share
        this signature, but they are ignored here.
        """
        self.config = config

    def __call__(self, query):
        """
        Classify a piece of text.

        Parameters:
          query (string) -- the text query, for example:
                            'Today was warm, sunny and beautiful out.'

        Returns a dict with the keys:
          'class' (int)    -- index of the predicted class
          'label' (string) -- label of the predicted class (str(class) when the
                              model has no labels)
          'score' (float)  -- the classification probability, in [0,1]

        This base implementation does nothing and returns None.
        """
        return None
def TokenClassification(resource, *args, **kwargs):
    """
    Load an NLP token classification (aka Named Entity Recognition) service or model.

    Parameters:
      resource -- the resource to load. It is passed, along with any extra
                  positional/keyword arguments, to load_resource.

    Returns whatever load_resource produces for this resource. Implementations
    follow the call signature documented on TokenClassificationService.
    """
    engine = 'jetson_voice.models.nlp.TokenClassificationEngine'

    # Both supported backends resolve to the same engine class.
    backends = dict.fromkeys(('tensorrt', 'onnxruntime'), engine)

    return load_resource(resource, backends, *args, **kwargs)
class TokenClassificationService:
    """
    Token classification (aka Named Entity Recognition) service base class.

    Concrete implementations provide __call__. This base class stores the
    configuration and provides the tag_string() formatting helper.
    """

    def __init__(self, config, *args, **kwargs):
        """
        Create service instance.

        Parameters:
          config -- configuration for the service, kept as self.config.

        Extra positional/keyword arguments are accepted, so subclasses can share
        this signature, but they are ignored here.
        """
        self.config = config

    def __call__(self, query):
        """
        Perform token classification (NER) on the input query and return tagged entities.

        Parameters:
          query (string) -- The text query, for example:
                            "Ben is from Chicago, a city in the state of Illinois, US"

        Returns a list[dict] of tagged entities with the following dictionary keys:
          'class' (int)    -- the entity class index
          'label' (string) -- the entity class label
          'score' (float)  -- the classification probability [0,1]
          'text' (string)  -- the corresponding text from the input query
          'start' (int)    -- the starting character index of the text
          'end' (int)      -- the ending character index of the text

        This base implementation does nothing and returns None.
        """
        pass

    @staticmethod
    def tag_string(query, tags, scores=False):
        """
        Returns a string with the tags inserted inline with the query. For example:

          "Ben[B-PER] is from Chicago[B-LOC], a city in the state of Illinois[B-LOC], US[B-LOC]"

        Each tag is inserted at character offset tag['end'] of the ORIGINAL query,
        so query[:tag['end']] precedes it. Tags may be given in any order. They are
        applied in ascending order of 'end', and tags that share the same 'end'
        keep their input order.

        Parameters:
          query (string) -- The original query string.
          tags (list[dict]) -- The tags predicted by the model. Each dict needs
                               'label' and 'end' keys, plus 'score' when
                               scores is true.
          scores (bool) -- If true, the probabilities will be added inline.
                           If false (default), only the tag labels will be added.

        Returns:
          string -- the query with the tag markers inserted.
        """
        pieces = []
        cursor = 0  # index into the original query up to which text has been emitted

        # Sorting (stable) makes the result independent of the order the tags
        # arrive in. The previous running-offset approach silently misplaced
        # labels when the tags were not already sorted by 'end'.
        for tag in sorted(tags, key=lambda t: t['end']):
            if scores:
                tag_str = f"[{tag['label']} {tag['score']:.3}]"
            else:
                tag_str = f"[{tag['label']}]"

            pieces.append(query[cursor:tag['end']])
            pieces.append(tag_str)
            cursor = tag['end']

        pieces.append(query[cursor:])
        return ''.join(pieces)
if __name__ == "__main__":
    from jetson_voice import ConfigArgParser
    import pprint

    def _demo_intent_slot(model_name):
        """Run joint intent/slot classification over a few sample commands."""
        model = IntentSlot(model_name)

        samples = [
            'Set alarm for Seven Thirty AM',
            'Please increase the volume',
            'What is my schedule for tomorrow',
            'Place an order for a large pepperoni pizza from Dominos'
        ]

        for text in samples:
            output = model(text)
            print('\n')
            print('query:', text)
            print('')
            pprint.pprint(output)

    def _demo_question_answer(model_name):
        """Ask several questions against two context passages, printing the top 5 answers."""
        model = QuestionAnswer(model_name)

        pi_context = (
            "Some people have said that Pi is tasty but there should be a value for Pi, and the value for Pi is around 3.14. "
            "Pi is the ratio of a circle's circumference to it's diameter. The constant Pi was first calculated by Archimedes "
            "in ancient Greece around the year 250 BC."
        )

        amazon_context = (
            "The Amazon rainforest is a moist broadleaf forest that covers most of the Amazon basin of South America. "
            "This basin encompasses 7,000,000 square kilometres (2,700,000 sq mi), of which 5,500,000 square kilometres "
            "(2,100,000 sq mi) are covered by the rainforest. The majority of the forest is contained within Brazil, "
            "with 60% of the rainforest, followed by Peru with 13%, and Colombia with 10%."
        )

        samples = [
            {"question": "What is the value of Pi?", "context": pi_context},
            {"question": "Who discovered Pi?", "context": pi_context},
            {"question": "Which nation contains the majority of the Amazon forest?", "context": amazon_context},
            {"question": "How large is the Amazon rainforest?", "context": amazon_context},
        ]

        for sample in samples:
            answers = model(sample, top_k=5)
            print('\n')
            print('context:', sample['context'])
            print('')
            print('question:', sample['question'])

            for answer in answers:
                print('')
                print('answer: ', answer['answer'])
                print('score: ', answer['score'])

    def _demo_text_classification(model_name):
        """Classify a handful of sentences (the samples target sentiment models)."""
        model = TextClassification(model_name)

        samples = [
            "By the end of no such thing the audience, like beatrice, has a watchful affection for the monster.",
            "Director Rob Marshall went out gunning to make a great one.",
            "Uneasy mishmash of styles and genres.",
            "I love exotic science fiction / fantasy movies but this one was very unpleasant to watch. I gave it 4 / 10 since some special effects were nice.",
            "Today was cold and rainy and not very nice.",
            "Today was warm, sunny and beautiful out.",
        ]

        for text in samples:
            output = model(text)
            print('\nquery:', text)
            pprint.pprint(output)

    def _demo_token_classification(model_name):
        """Tag named entities in sample sentences and print them inline with scores."""
        model = TokenClassification(model_name)

        samples = [
            "But candidate Charles Baker, who has about eight percent of the vote, has called for an investigation into reports of people voting multiple times.",
            "Analysts say Mr. Chung's comments may be part of efforts by South Korea to encourage North Korea to resume bilateral talks.",
            "The 63-year-old Daltrey walked offstage during the first song; guitarist Pete Townshend later told the crowd he was suffering from bronchitis and could barely speak.",
            "The Who is currently touring in support of Endless Wire, its first album since 1982.",
            "Meanwhile, Iowa is cleaning up after widespread flooding inundated homes, destroyed crops and cut off highways and bridges.",
            "At the White House Tuesday, U.S. President George Bush expressed concern for the flood victims.",
            "Ben is from Chicago, a city in the state of Illinois, US with a population of 2.7 million people.",
            "Lisa's favorite place to climb in the summer is El Capitan in Yosemite National Park in California, U.S."
        ]

        for text in samples:
            tags = model(text)
            print(f'\n{model.tag_string(text, tags, scores=True)}')

    parser = ConfigArgParser()

    parser.add_argument('--model', default='distilbert_intent', type=str)
    parser.add_argument('--type', default='intent_slot', type=str)

    args = parser.parse_args()
    args.type = args.type.lower()

    print(args)

    # --type value -> demo routine ('qa' is a short alias for question answering)
    demos = {
        'intent_slot': _demo_intent_slot,
        'question_answer': _demo_question_answer,
        'qa': _demo_question_answer,
        'text_classification': _demo_text_classification,
        'token_classification': _demo_token_classification,
    }

    demo = demos.get(args.type)

    if demo is None:
        raise ValueError(f"invalid --type argument ({args.type})")

    demo(args.model)