Skip to content

Commit

Permalink
little code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Aug 28, 2022
1 parent 55e565f commit a147158
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 20 deletions.
5 changes: 1 addition & 4 deletions Rock_fact_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@
entailment_html_messages,
create_df_for_relevant_snippets,
create_ternary_plot,
build_sidebar
build_sidebar,
)
from app_utils.config import RETRIEVER_TOP_K


def main():

statements = load_statements()

build_sidebar()

# Persistent state
Expand Down Expand Up @@ -120,7 +118,6 @@ def main():
st.markdown(f"###### Most Relevant snippets:")
df, urls = create_df_for_relevant_snippets(docs)
st.dataframe(df)

str_wiki_pages = "Wikipedia source pages: "
for doc, url in urls.items():
str_wiki_pages += f"[{doc}]({url}) "
Expand Down
10 changes: 4 additions & 6 deletions app_utils/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,19 @@ def load_statements():
)
def start_haystack():
"""
load document store, retriever, reader and create pipeline
load document store, retriever, entailment checker and create pipeline
"""
shutil.copy(f"{INDEX_DIR}/faiss_document_store.db", ".")
document_store = FAISSDocumentStore(
faiss_index_path=f"{INDEX_DIR}/my_faiss_index.faiss",
faiss_config_path=f"{INDEX_DIR}/my_faiss_index.json",
)
print(f"Index size: {document_store.get_document_count()}")

retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model=RETRIEVER_MODEL,
model_format=RETRIEVER_MODEL_FORMAT,
)

entailment_checker = EntailmentChecker(model_name_or_path=NLI_MODEL, use_gpu=False)

pipe = Pipeline()
Expand Down Expand Up @@ -84,8 +82,8 @@ def query(statement: str, retriever_top_k: int = 5):
break

results["agg_entailment_info"] = {
"contradiction": float(round(agg_con / scores, 2)),
"neutral": float(round(agg_neu / scores, 2)),
"entailment": float(round(agg_ent / scores, 2)),
"contradiction": round(agg_con / scores, 2),
"neutral": round(agg_neu / scores, 2),
"entailment": round(agg_ent / scores, 2),
}
return results
25 changes: 17 additions & 8 deletions app_utils/frontend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
"neutral": 'The knowledge base is <span style="color:darkgray">neutral</span> about your statement',
}


def build_sidebar():
sidebar="""
sidebar = """
<h1 style='text-align: center'>Fact Checking 🎸 Rocks!</h1>
<div style='text-align: center'>
<i>Fact checking baseline combining dense retrieval and textual entailment</i>
Expand All @@ -20,6 +21,7 @@ def build_sidebar():
"""
st.sidebar.markdown(sidebar, unsafe_allow_html=True)


def set_state_if_absent(key, value):
if key not in st.session_state:
st.session_state[key] = value
Expand All @@ -33,6 +35,9 @@ def reset_results(*args):


def create_ternary_plot(entailment_data):
"""
Create a Plotly ternary plot for the given entailment dict.
"""
hover_text = ""
for label, value in entailment_data.items():
hover_text += f"{label}: {value}<br>"
Expand Down Expand Up @@ -83,14 +88,11 @@ def makeAxis(title, tickangle):
}


def highlight_cols(s):
coldict = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
if s.name in coldict.keys():
return ["background-color: {}".format(coldict[s.name])] * len(s)
return [""] * len(s)


def create_df_for_relevant_snippets(docs):
"""
Create a dataframe that contains all relevant snippets.
Also returns the URLs
"""
rows = []
urls = {}
for doc in docs:
Expand All @@ -106,3 +108,10 @@ def create_df_for_relevant_snippets(docs):
rows.append(row)
df = pd.DataFrame(rows).style.apply(highlight_cols)
return df, urls


def highlight_cols(s):
coldict = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
if s.name in coldict.keys():
return ["background-color: {}".format(coldict[s.name])] * len(s)
return [""] * len(s)
5 changes: 3 additions & 2 deletions pages/Info.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import streamlit as st

from app_utils.frontend_utils import build_sidebar

build_sidebar()

with open('README.md','r') as fin:
readme = fin.read().rpartition('---')[-1]
with open("README.md", "r") as fin:
readme = fin.read().rpartition("---")[-1]

st.markdown(readme, unsafe_allow_html=True)

0 comments on commit a147158

Please sign in to comment.