Skip to content

Commit

Permalink
great progress in showing output
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Aug 24, 2022
1 parent 35f0167 commit 1434337
Show file tree
Hide file tree
Showing 9 changed files with 1,189 additions and 174 deletions.
178 changes: 68 additions & 110 deletions Rock_fact_checker.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,50 @@
import streamlit as st

import random
import time
import logging
from json import JSONDecodeError
# from markdown import markdown
# from annotated_text import annotation
# from urllib.parse import unquote
import random

import streamlit as st
import pandas as pd
import plotly.express as px

from app_utils.backend_utils import load_statements, query
from app_utils.frontend_utils import set_state_if_absent, reset_results, entailment_html_messages
from app_utils.frontend_utils import (
set_state_if_absent,
reset_results,
entailment_html_messages,
create_df_for_relevant_snippets,
)
from app_utils.config import RETRIEVER_TOP_K


def main():


statements = load_statements()

# Persistent state
set_state_if_absent('statement', "Elvis Presley is alive")
set_state_if_absent('answer', '')
set_state_if_absent('results', None)
set_state_if_absent('raw_json', None)
set_state_if_absent('random_statement_requested', False)
set_state_if_absent("statement", "Elvis Presley is alive")
set_state_if_absent("answer", "")
set_state_if_absent("results", None)
set_state_if_absent("raw_json", None)
set_state_if_absent("random_statement_requested", False)


## MAIN CONTAINER
st.write("# Fact checking 🎸 Rocks!")
st.write()
st.markdown("""
st.markdown(
"""
##### Enter a factual statement about [Rock music](https://en.wikipedia.org/wiki/List_of_mainstream_rock_performers) and let the AI check it out for you...
""")
"""
)
# Search bar
statement = st.text_input("", value=st.session_state.statement,
max_chars=100, on_change=reset_results)
statement = st.text_input(
"", value=st.session_state.statement, max_chars=100, on_change=reset_results
)
col1, col2 = st.columns(2)
col1.markdown(
"<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
"<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True
)
col2.markdown(
"<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
"<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True
)
# Run button
run_pressed = col1.button("Run")
# Random statement button
Expand All @@ -54,12 +58,15 @@ def main():
st.session_state.random_statement_requested = True
# Re-runs the script setting the random statement as the textbox value
# Unfortunately necessary as the Random statement button is _below_ the textbox
# raise st.script_runner.RerunException(
# st.script_request_queue.RerunData(None))
# Adapted for Streamlit>=1.12
raise st.runtime.scriptrunner.script_runner.RerunException(
st.runtime.scriptrunner.script_requests.RerunData("")
)
else:
st.session_state.random_statement_requested = False
run_query = (run_pressed or statement != st.session_state.statement) \
and not st.session_state.random_statement_requested
run_query = (
run_pressed or statement != st.session_state.statement
) and not st.session_state.random_statement_requested

# Get results for query
if run_query and statement:
Expand All @@ -68,14 +75,14 @@ def main():
st.session_state.statement = statement
with st.spinner("🧠 &nbsp;&nbsp; Performing neural search on documents..."):
try:
st.session_state.results = query(
statement, RETRIEVER_TOP_K)
st.session_state.results = query(statement, RETRIEVER_TOP_K)
time_end = time.time()
print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
print(f'elapsed time: {time_end - time_start}')
print(f"elapsed time: {time_end - time_start}")
except JSONDecodeError as je:
st.error(
"👓 &nbsp;&nbsp; An error occurred reading the results. Is the document store working?")
"👓 &nbsp;&nbsp; An error occurred reading the results. Is the document store working?"
)
return
except Exception as e:
logging.exception(e)
Expand All @@ -85,85 +92,36 @@ def main():
# Display results
if st.session_state.results:
results = st.session_state.results
docs, agg_entailment_info = results['documents'], results['agg_entailment_info']
print(results)

docs, agg_entailment_info = results["documents"], results["agg_entailment_info"]

# show different messages depending on entailment results
max_key = max(agg_entailment_info, key=agg_entailment_info.get)
message = entailment_html_messages[max_key]
st.markdown(f'<h4>{message}</h4>', unsafe_allow_html=True)
st.markdown(f'###### Aggregate entailment information:')
st.write(results['agg_entailment_info'])
st.markdown(f'###### Relevant snippets:')

# colms = st.columns((2, 5, 1, 1, 1, 1))
# fields = ["Page title",'Content', 'Relevance', 'contradiction', 'neutral', 'entailment']
# for col, field_name in zip(colms, fields):
# # header
# col.write(field_name)
df = []
for doc in docs:
# col1, col2, col3, col4, col5, col6 = st.columns((2, 5, 1, 1, 1, 1))
# col1.write(f"[{doc.meta['name']}]({doc.meta['url']})")
# col2.write(f"{doc.content}")
# col3.write(f"{doc.score:.3f}")
# col4.write(f"{doc.meta['entailment_info']['contradiction']:.2f}")
# col5.write(f"{doc.meta['entailment_info']['neutral']:.2f}")
# col6.write(f"{doc.meta['entailment_info']['entailment']:.2f}")

# 'con': f"{doc.meta['entailment_info']['contradiction']:.2f}",
# 'neu': f"{doc.meta['entailment_info']['neutral']:.2f}",
# 'ent': f"{doc.meta['entailment_info']['entailment']:.2f}",
# # 'url': doc.meta['url'],
# 'Content': doc.content}
#
#
#
row = {'Title': doc.meta['name'],
'Relevance': f"{doc.score:.3f}",
'con': f"{doc.meta['entailment_info']['contradiction']:.2f}",
'neu': f"{doc.meta['entailment_info']['neutral']:.2f}",
'ent': f"{doc.meta['entailment_info']['entailment']:.2f}",
# 'url': doc.meta['url'],
'Content': doc.content}
df.append(row)
st.dataframe(pd.DataFrame(df))#.style.apply(highlight))


# if len(st.session_state.results['answers']) == 0:
# st.info("""🤔 &nbsp;&nbsp; Haystack is unsure whether any of
# the documents contain an answer to your question. Try to reformulate it!""")

# for result in st.session_state.results['answers']:
# result = result.to_dict()
# if result["answer"]:
# if alert_irrelevance and result['score'] < LOW_RELEVANCE_THRESHOLD:
# alert_irrelevance = False
# st.write("""
# <h4 style='color: darkred'>Attention, the
# following answers have low relevance:</h4>""",
# unsafe_allow_html=True)

# answer, context = result["answer"], result["context"]
# start_idx = context.find(answer)
# end_idx = start_idx + len(answer)
# # Hack due to this bug: https://github.com/streamlit/streamlit/issues/3190
# st.write(markdown("- ..."+context[:start_idx] +
# str(annotation(answer, "ANSWER", "#3e1c21", "white")) +
# context[end_idx:]+"..."), unsafe_allow_html=True)
# source = ""
# name = unquote(result['meta']['name']).replace('_', ' ')
# url = result['meta']['url']
# source = f"[{name}]({url})"
# st.markdown(
# f"**Score:** {result['score']:.2f} - **Source:** {source}")

# def make_pretty(styler):
# styler.set_caption("Weather Conditions")
# # styler.format(rain_condition)
# styler.format_con(lambda v: v.float(v))
# styler.background_gradient(axis=None, vmin=0, vmax=1, cmap="YlGnBu")
# return styler

def highlight(s):
return ['background-color: red']*5
main()
st.markdown(f"<br/><h4>{message}</h4>", unsafe_allow_html=True)

st.markdown(f"###### Aggregate entailment information:")
col1, col2 = st.columns([2, 1])
df_agg_entailment_info = pd.DataFrame([results["agg_entailment_info"]])
fig = px.scatter_ternary(
df_agg_entailment_info,
a="contradiction",
b="neutral",
c="entailment",
size="contradiction",
)
with col1:
st.plotly_chart(fig, use_container_width=True)
with col2:
st.write(results["agg_entailment_info"])

st.markdown(f"###### Relevant snippets:")
df, urls = create_df_for_relevant_snippets(docs)
st.dataframe(df)

str_wiki_pages = "Wikipedia source pages: "
for doc, url in urls.items():
str_wiki_pages += f"[{doc}]({url}) "
st.markdown(str_wiki_pages)


main()
91 changes: 55 additions & 36 deletions app_utils/backend_utils.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,61 @@
import shutil

from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.pipelines import Pipeline

import streamlit as st

from app_utils.entailment_checker import EntailmentChecker
from app_utils.config import (
STATEMENTS_PATH,
INDEX_DIR,
RETRIEVER_MODEL,
RETRIEVER_MODEL_FORMAT,
NLI_MODEL,
)


@st.cache()
def load_statements():
    """Read the demo statements file, skipping '#'-prefixed comment lines."""
    with open(STATEMENTS_PATH) as fin:
        # strip trailing newlines/whitespace; drop commented-out entries
        return [line.strip() for line in fin if not line.startswith("#")]

from app_utils.config import STATEMENTS_PATH, INDEX_DIR, RETRIEVER_MODEL, RETRIEVER_MODEL_FORMAT, NLI_MODEL

# cached to make index and models load only at start
@st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True)
@st.cache(
hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True
)
def start_haystack():
"""
load document store, retriever, reader and create pipeline
"""
shutil.copy(f'{INDEX_DIR}/faiss_document_store.db', '.')
shutil.copy(f"{INDEX_DIR}/faiss_document_store.db", ".")
document_store = FAISSDocumentStore(
faiss_index_path=f'{INDEX_DIR}/my_faiss_index.faiss',
faiss_config_path=f'{INDEX_DIR}/my_faiss_index.json')
print(f'Index size: {document_store.get_document_count()}')

faiss_index_path=f"{INDEX_DIR}/my_faiss_index.faiss",
faiss_config_path=f"{INDEX_DIR}/my_faiss_index.json",
)
print(f"Index size: {document_store.get_document_count()}")

retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model=RETRIEVER_MODEL,
model_format=RETRIEVER_MODEL_FORMAT
model_format=RETRIEVER_MODEL_FORMAT,
)

entailment_checker = EntailmentChecker(model_name_or_path=NLI_MODEL,
use_gpu=False)


entailment_checker = EntailmentChecker(model_name_or_path=NLI_MODEL, use_gpu=False)

pipe = Pipeline()
pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
pipe.add_node(component=entailment_checker, name="ec", inputs=["retriever"])
return pipe


pipe = start_haystack()

# the pipeline is not included as parameter of the following function,
# because it is difficult to cache
@st.cache(persist=True, allow_output_mutation=True)
Expand All @@ -45,28 +64,28 @@ def query(statement: str, retriever_top_k: int = 5):
params = {"retriever": {"top_k": retriever_top_k}}
results = pipe.run(statement, params=params)

scores, agg_con, agg_neu, agg_ent = 0,0,0,0
for doc in results['documents']:
scores+=doc.score
ent_info=doc.meta['entailment_info']
con,neu,ent = ent_info['contradiction'], ent_info['neutral'], ent_info['entailment']
agg_con+=con*doc.score
agg_neu+=neu*doc.score
agg_ent+=ent*doc.score

results['agg_entailment_info'] = {
'contradiction': round(agg_con/scores, 2),
'neutral': round(agg_neu/scores, 2),
'entailment': round(agg_ent/scores, 2)}

return results
scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
for i, doc in enumerate(results["documents"]):
scores += doc.score
ent_info = doc.meta["entailment_info"]
con, neu, ent = (
ent_info["contradiction"],
ent_info["neutral"],
ent_info["entailment"],
)
agg_con += con * doc.score
agg_neu += neu * doc.score
agg_ent += ent * doc.score

@st.cache()
def load_statements():
    """Load the demo statements from STATEMENTS_PATH, ignoring '#' comment lines."""
    with open(STATEMENTS_PATH) as fin:
        raw_lines = fin.readlines()
    statements = []
    for line in raw_lines:
        # lines beginning with '#' are comments in the statements file
        if not line.startswith("#"):
            statements.append(line.strip())
    return statements
# if the first 3 documents already provide strong evidence of entailment/contradiction,
# there is no need to consider less relevant documents
if i == 2 and max(agg_con, agg_ent) / scores > 0.5:
results["documents"] = results["documents"][: i + 1]
break


results["agg_entailment_info"] = {
"contradiction": round(agg_con / scores, 2),
"neutral": round(agg_neu / scores, 2),
"entailment": round(agg_ent / scores, 2),
}
return results
5 changes: 2 additions & 3 deletions app_utils/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

INDEX_DIR = 'data/index'
STATEMENTS_PATH = 'data/statements.txt'
INDEX_DIR = "data/index"
STATEMENTS_PATH = "data/statements.txt"

RETRIEVER_MODEL = "sentence-transformers/msmarco-distilbert-base-tas-b"
RETRIEVER_MODEL_FORMAT = "sentence_transformers"
Expand Down
Loading

0 comments on commit 1434337

Please sign in to comment.