Skip to content

Commit

Permalink
flexible snowflake authenticator
Browse files Browse the repository at this point in the history
  • Loading branch information
julianteichgraber committed Jan 24, 2024
1 parent a36dcd3 commit 2a0ac2c
Show file tree
Hide file tree
Showing 6 changed files with 449 additions and 444 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dist/
**/__pycache__/

# only to be modified locally:
tw_experimentation/snowflake_config.py
**/snowflake_config.json

# User-specific stuff:
.idea/**/workspace.xml
Expand Down
720 changes: 356 additions & 364 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tw-experimentation"
version = "0.1.2.2"
version = "0.1.2.3"
description = "Wise AB platform"
authors = ["Wise"]
readme = "README.md"
Expand Down
12 changes: 0 additions & 12 deletions tw_experimentation/snowflake_config.py

This file was deleted.

73 changes: 47 additions & 26 deletions tw_experimentation/streamlit/pages_wrap/page1_Data_Loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
swap_checkbox_state,
cols_to_select,
exp_config_to_json,
SnowflakeConnection,
SnowflakeIndividualCredentials,
)
from tw_experimentation.streamlit.streamlit_utils import generate_experiment_output
from tw_experimentation.utils import ExperimentDataset
Expand All @@ -15,7 +15,7 @@
import copy


def page_1_data_loading(snowflake_connector=SnowflakeConnection()):
def page_1_data_loading(snowflake_connector=SnowflakeIndividualCredentials()):
st.session_state.update(st.session_state)

DATA_PAGE = "Data Loading"
Expand Down Expand Up @@ -53,12 +53,6 @@ def page_1_data_loading(snowflake_connector=SnowflakeConnection()):
if snowflake_import == "query":
st.text_input("SQL query", key="query")

snowflake_pull_kwargs = {
"sql_query": st.session_state["query"],
"user": st.session_state["snowflake_username"],
"restart_engine": restart_snowflake,
}

if snowflake_import == "table name":
col11, col12, col13 = st.columns(3)

Expand All @@ -74,33 +68,61 @@ def page_1_data_loading(snowflake_connector=SnowflakeConnection()):
)
with col13:
st.text_input("Table", key="table")
snowflake_pull_kwargs = dict(
source_database=st.session_state["warehouse"],
source_schema=st.session_state["schema"],
source_table=st.session_state["table"],
user=st.session_state["snowflake_username"],
restart_engine=restart_snowflake,
)

enter_username = False
enter_credentials = False
if not st.session_state.has_snowflake_connection:
enter_username = True
enter_credentials = True
else:
restart_snowflake = st.checkbox("Restart snowflake connection")

st.session_state["snowflake_username"] = st.text_input(
"Snowflake username",
st.session_state["snowflake_username"],
disabled=not (enter_username or restart_snowflake),
)
if len(st.session_state["snowflake_connection"].input_configs) > 0:
with st.expander("Load snowflake configuration from json"):
st.write(
"""
You can load a snowflake configuration from a json file.
"""
)
st.write(
"""
You can upload a json file of the format, e.g.,
{
"user" = "USERNAME",
"region" = "REGION"
}
"""
)
config_json_snowflake = st.file_uploader(
"Upload a json config file for snowflake", type="json"
)

if config_json_snowflake is not None:
try:
json_content_snowflake = config_json_snowflake.getvalue()
exp_config = json.loads(json_content_snowflake)
st.write(exp_config)
for config, value in exp_config.items():
st.session_state["snowflake_" + config] = value
except json.JSONDecodeError:
st.error("Invalid JSON file. Please upload a valid JSON file.")
for config_variable in st.session_state["snowflake_connection"].input_configs:
initalise_session_states({"snowflake_" + config_variable: ""})
st.text_input(
config_variable,
st.session_state["snowflake_" + config_variable],
disabled=not (enter_credentials or restart_snowflake),
)

if st.button("Fetch data from snowflake"):
account_configs = {
config_variable: st.session_state["snowflake_" + config_variable]
for config_variable in st.session_state[
"snowflake_connection"
].input_configs
}
st.session_state["snowflake_connection"].account_config = account_configs
st.session_state["snowflake_connection"].connect(
restart_engine=restart_snowflake
)
# st.session_state["data_loader"].pull_snowflake_table(
# **snowflake_pull_kwargs
# )
if snowflake_import == "query":
st.session_state["df_temp"] = st.session_state[
"snowflake_connection"
Expand All @@ -112,7 +134,6 @@ def page_1_data_loading(snowflake_connector=SnowflakeConnection()):
source_database=st.session_state["warehouse"],
source_schema=st.session_state["schema"],
source_table=st.session_state["table"],
user=st.session_state["snowflake_username"],
)
st.session_state.has_snowflake_connection = True

Expand Down
84 changes: 44 additions & 40 deletions tw_experimentation/streamlit/streamlit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
NormalityChecks,
)
from tw_experimentation.bayes.bayes_test import BayesTest

from tw_experimentation.constants import (
COLORSCALES,
ACCOUNT,
REGION,
USERNAME,
Expand All @@ -44,6 +42,7 @@
RESULT_DATABASE,
RESULT_SCHEMA,
RESULT_TABLE,
COLORSCALES,
)

from abc import ABC, abstractmethod
Expand All @@ -55,24 +54,26 @@
from sqlalchemy import create_engine
import json

from typing import Optional, List, Union

import streamlit as st

import pandas as pd
from scipy.stats import chi2_contingency
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine
import json

from typing import Optional, List, Union
from typing import Optional, List, Union, Dict


class SnowflakeConnection(ABC):
def __init__(self):
self.connection = None
self.engine = None
self._config_variables = []
self._account_config = dict()

@abstractmethod
def connect(self, restart_engine=False, **kwargs):
pass

@property
@abstractmethod
def input_configs(self):
"""List of variables to configure for snowflake connection"""
return []

def load_table(
self,
sql_query=None,
Expand All @@ -94,38 +95,37 @@ def close_connection(self):


class SnowflakeIndividualCredentials(SnowflakeConnection):
def __init__(self, username: str, password: str):
self.username = username
self.password = password

def connect(
self,
restart_engine=False,
user=USERNAME,
account=ACCOUNT,
region=REGION,
authenticator=AUTHENTICATOR,
database=DATABASE,
warehouse=WAREHOUSE,
):
credentials = [user, account, region, authenticator, database, warehouse]
assert all(
[False if cred is None else True for cred in credentials]
), "Please specify your snowflake credentials in "
"tw_experimentation/snowflake_config.py"
engine_kwargs = self._account_config
if self.engine is None or self.connection is None or restart_engine:
self.engine = create_engine(
URL(
account=account,
region=region,
user=user,
authenticator=authenticator,
database=database,
warehouse=warehouse,
)
)
self.engine = create_engine(URL(**engine_kwargs))
self.connection = self.engine.connect()
return self.conection
return self.connection

@property
def input_configs(self):
self.config_variables = [
"user",
"account",
"region",
"authenticator",
"database",
"warehouse",
]
return self.config_variables

@property
def account_config(self):
return self._account_config

@account_config.setter
def account_config(self, value):
# for config_variable in value.keys():
# assert config_variable in self._config_variables
self._account_config = value

def dispose_engine(self):
if hasattr(self, "engine"):
Expand Down Expand Up @@ -153,7 +153,7 @@ def exp_config_to_json():
return config_json


def initalise_session_states():
def initalise_session_states(additional_params: Optional[Dict] = dict()):
STATE_VARS = {
"last_page": "Main",
"output_loaded": None,
Expand Down Expand Up @@ -207,6 +207,10 @@ def initalise_session_states():
if k not in st.session_state:
st.session_state[k] = v

for k, v in additional_params.items():
if k not in st.session_state:
st.session_state[k] = v


def load_data(path: str):
df = pd.read_csv(path)
Expand Down

0 comments on commit 2a0ac2c

Please sign in to comment.