diff --git a/Rock_fact_checker.py b/Rock_fact_checker.py index a57ef87..ab1225b 100644 --- a/Rock_fact_checker.py +++ b/Rock_fact_checker.py @@ -12,15 +12,13 @@ entailment_html_messages, create_df_for_relevant_snippets, create_ternary_plot, - build_sidebar + build_sidebar, ) from app_utils.config import RETRIEVER_TOP_K def main(): - statements = load_statements() - build_sidebar() # Persistent state @@ -120,7 +118,6 @@ def main(): st.markdown(f"###### Most Relevant snippets:") df, urls = create_df_for_relevant_snippets(docs) st.dataframe(df) - str_wiki_pages = "Wikipedia source pages: " for doc, url in urls.items(): str_wiki_pages += f"[{doc}]({url}) " diff --git a/app_utils/backend_utils.py b/app_utils/backend_utils.py index c7e03af..ec3c388 100644 --- a/app_utils/backend_utils.py +++ b/app_utils/backend_utils.py @@ -31,7 +31,7 @@ def load_statements(): ) def start_haystack(): """ - load document store, retriever, reader and create pipeline + load document store, retriever, entailment checker and create pipeline """ shutil.copy(f"{INDEX_DIR}/faiss_document_store.db", ".") document_store = FAISSDocumentStore( @@ -39,13 +39,11 @@ def start_haystack(): faiss_config_path=f"{INDEX_DIR}/my_faiss_index.json", ) print(f"Index size: {document_store.get_document_count()}") - retriever = EmbeddingRetriever( document_store=document_store, embedding_model=RETRIEVER_MODEL, model_format=RETRIEVER_MODEL_FORMAT, ) - entailment_checker = EntailmentChecker(model_name_or_path=NLI_MODEL, use_gpu=False) pipe = Pipeline() @@ -84,8 +82,8 @@ def query(statement: str, retriever_top_k: int = 5): break results["agg_entailment_info"] = { - "contradiction": float(round(agg_con / scores, 2)), - "neutral": float(round(agg_neu / scores, 2)), - "entailment": float(round(agg_ent / scores, 2)), + "contradiction": round(agg_con / scores, 2), + "neutral": round(agg_neu / scores, 2), + "entailment": round(agg_ent / scores, 2), } return results diff --git a/app_utils/frontend_utils.py b/app_utils/frontend_utils.py index 896be42..34f8dcf 100644 --- a/app_utils/frontend_utils.py +++ b/app_utils/frontend_utils.py @@ -9,8 +9,9 @@ "neutral": 'The knowledge base is neutral about your statement', } + def build_sidebar(): - sidebar=""" + sidebar = """

Fact Checking 🎸 Rocks!

Fact checking baseline combining dense retrieval and textual entailment @@ -20,6 +21,7 @@ def build_sidebar(): """ st.sidebar.markdown(sidebar, unsafe_allow_html=True) + def set_state_if_absent(key, value): if key not in st.session_state: st.session_state[key] = value @@ -33,6 +35,9 @@ def reset_results(*args): def create_ternary_plot(entailment_data): + """ + Create a Plotly ternary plot for the given entailment dict. + """ hover_text = "" for label, value in entailment_data.items(): hover_text += f"{label}: {value}
" @@ -83,14 +88,11 @@ def makeAxis(title, tickangle): } -def highlight_cols(s): - coldict = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"} - if s.name in coldict.keys(): - return ["background-color: {}".format(coldict[s.name])] * len(s) - return [""] * len(s) - - def create_df_for_relevant_snippets(docs): + """ + Create a dataframe that contains all relevant snippets. + Also returns the URLs + """ rows = [] urls = {} for doc in docs: @@ -106,3 +108,10 @@ def create_df_for_relevant_snippets(docs): rows.append(row) df = pd.DataFrame(rows).style.apply(highlight_cols) return df, urls + + +def highlight_cols(s): + coldict = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"} + if s.name in coldict.keys(): + return ["background-color: {}".format(coldict[s.name])] * len(s) + return [""] * len(s) diff --git a/pages/Info.py b/pages/Info.py index 37dab90..fdbb222 100644 --- a/pages/Info.py +++ b/pages/Info.py @@ -1,9 +1,10 @@ import streamlit as st + from app_utils.frontend_utils import build_sidebar build_sidebar() -with open('README.md','r') as fin: - readme = fin.read().rpartition('---')[-1] +with open("README.md", "r") as fin: + readme = fin.read().rpartition("---")[-1] st.markdown(readme, unsafe_allow_html=True)