Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoquan committed Aug 16, 2018
0 parents commit 437477a
Show file tree
Hide file tree
Showing 10 changed files with 2,469 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
models
__pycache__
data/*.dat
16 changes: 16 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
help:
@echo " train-nlu"
@echo " Train the natural language understanding using Rasa NLU."
@echo " train-core"
@echo " Train a dialogue model using Rasa core."
@echo " run"
@echo " Runs the bot on the command line."

run:
python bot.py run

train-nlu:
python bot.py train-nlu

train-core:
python bot.py train-dialogue
Empty file added __init__.py
Empty file.
118 changes: 118 additions & 0 deletions bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import logging
import warnings

from policy import MobilePolicy
from rasa_core import utils
from rasa_core.actions import Action
from rasa_core.agent import Agent
from rasa_core.channels.console import ConsoleInputChannel
from rasa_core.interpreter import RasaNLUInterpreter
from rasa_core.policies.memoization import MemoizationPolicy

logger = logging.getLogger(__name__)

support_search = ["消费", "流量"]

def extract_item(item):
if item is None:
return None
for name in support_search:
if name in item:
return name
return None


class ActionSearchConsume(Action):
def name(self):
return 'action_search_consume'

def run(self, dispatcher, tracker, domain):
item = tracker.get_slot("item")
item = extract_item(item)
if item is None:
dispatcher.utter_message("您好,我现在只会查话费和流量")
dispatcher.utter_message("你可以这样问我:“帮我查话费”")
return []

time = tracker.get_slot("time")
if time is None:
dispatcher.utter_message("您想查询哪个月的消费?")
return []
# query database here using item and time as key. but you may normalize time format first.
dispatcher.utter_message("好,请稍等")
if item == "流量":
dispatcher.utter_message(
"您好,您{}共使用{}二百八十兆,剩余三十兆。".format(time, item))
else:
dispatcher.utter_message("您好,您{}共消费二十八元。".format(time))
return []


def train_dialogue(domain_file="mobile_domain.yml",
model_path="models/dialogue",
training_data_file="data/mobile_story.md"):
agent = Agent(domain_file,
policies=[MemoizationPolicy(max_history=3),
MobilePolicy()])

training_data = agent.load_data(training_data_file)
agent.train(
training_data,
epochs=400,
batch_size=100,
validation_split=0.2
)

agent.persist(model_path)
return agent


def train_nlu():
from rasa_nlu.training_data import load_data
from rasa_nlu import config
from rasa_nlu.model import Trainer

training_data = load_data('data/mobile_nlu_data.json')
trainer = Trainer(config.load("mobile_nlu_model_config.yml"))
trainer.train(training_data)
model_directory = trainer.persist('models/',
project_name="nlu",
fixed_model_name="current")

return model_directory


def run(serve_forever=True):
interpreter = RasaNLUInterpreter("models/nlu/current")
agent = Agent.load("models/dialogue", interpreter=interpreter)

if serve_forever:
agent.handle_channel(ConsoleInputChannel())
return agent


if __name__ == '__main__':
utils.configure_colored_logging(loglevel="INFO")

parser = argparse.ArgumentParser(
description='starts the bot')

parser.add_argument(
'task',
choices=["train-nlu", "train-dialogue", "run"],
help="what the bot should do - e.g. run or train?")
task = parser.parse_args().task

# decide what to do based on first parameter of the script
if task == "train-nlu":
train_nlu()
elif task == "train-dialogue":
train_dialogue()
elif task == "run":
run()
Loading

0 comments on commit 437477a

Please sign in to comment.