Skip to content
This repository has been archived by the owner on Apr 7, 2022. It is now read-only.

Add ML Stats #114

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions data_analysis/data_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,10 @@ def stats_generator() -> pd.DataFrame:
def load_stats(**kwargs) -> dict:
"""Loads stats from database or local

Args:
**kwargs: passed to `cached_table_fetch`. See its docstring for more info.
Parameters
----------
**kwargs :
passed to `cached_table_fetch`. See its docstring for more info.

Returns
-------
Expand Down
167 changes: 167 additions & 0 deletions data_analysis/ml_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config import MODEL_PATH, ENCODER_PATH
from database import cached_table_fetch
import matplotlib.pyplot as plt
import pandas as pd
import time
import datetime
from helpers import RtdRay, ttl_lru_cache
from sklearn.dummy import DummyClassifier
import pickle


def majority_baseline(x, y):
    """Return the accuracy of a "most frequent class" dummy classifier.

    The dummy classifier is fitted on ``x`` / ``y`` and evaluated on the
    very same data.

    Parameters
    ----------
    x :
        Features handed to ``fit`` and ``predict``; must support ``len``.
    y :
        Labels; must expose ``to_numpy()`` (e.g. a pandas Series).

    Returns
    -------
    float
        Fraction of correct predictions, rounded to 6 decimal places.
    """
    dummy = DummyClassifier(strategy="most_frequent", random_state=0).fit(x, y)
    hits = dummy.predict(x) == y.to_numpy()
    return round(hits.sum() / len(x), 6)


def model_score(model, x, y):
    """Return whatever ``model.score(x, y)`` reports for the given data.

    Parameters
    ----------
    model :
        Object exposing a ``score(x, y)`` method.
    x :
        Features forwarded to ``model.score``.
    y :
        Labels forwarded to ``model.score``.
    """
    score = model.score(x, y)
    return score


def test_model(model, x_test, y_test, model_name) -> dict:
    """Compare a model's accuracy against the majority-class baseline.

    Parameters
    ----------
    model :
        Fitted model exposing ``predict``.
    x_test :
        Features to predict on.
    y_test :
        True labels for ``x_test``.
    model_name : str
        Name stored under the ``"model"`` key of the result.

    Returns
    -------
    dict
        ``"model"`` (the name) plus ``"baseline"``, ``"accuracy"`` and
        ``"improvement"`` as percentages rounded to 6 decimal places.
    """
    baseline = majority_baseline(x_test, y_test)
    # Named `accuracy` so the module-level `model_score` function is not
    # shadowed inside this function.
    accuracy = (model.predict(x_test) == y_test).sum() / len(y_test)

    stats = {
        "model": model_name,
        "baseline": round(baseline * 100, 6),
        "accuracy": round(accuracy * 100, 6),
        "improvement": round((accuracy - baseline) * 100, 6),
    }
    print(stats)
    return stats


def model_roc(model, x_test, y_test, model_name):
    """Plot the ROC curve of a model and show it with matplotlib.

    Parameters
    ----------
    model :
        Fitted classifier exposing ``predict_proba``.
    x_test :
        Features to predict on.
    y_test :
        True labels for ``x_test``; label ``1`` is treated as positive.
    model_name : str
        Name used in the plot title.
    """
    # Only import what is used: `plot_precision_recall_curve` no longer
    # exists in current scikit-learn releases, so importing it made this
    # function raise ImportError.
    from sklearn.metrics import auc, roc_curve

    # Column 1 of predict_proba is used as the positive-class score
    # (assumes the model's classes are ordered [negative, positive]).
    prediction = model.predict_proba(x_test)[:, 1]
    fpr, tpr, _ = roc_curve(y_test, prediction, pos_label=1)
    roc_auc = auc(fpr, tpr)

    lw = 2
    fig, ax = plt.subplots()
    ax.plot(
        fpr, tpr, color="darkorange", lw=lw, label="ROC curve (area = %0.2f)" % roc_auc
    )
    # Diagonal reference line from (0, 0) to (1, 1).
    ax.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(
        "Receiver operating characteristic model {}".format(model_name))
    ax.legend(loc="lower right")

    plt.show()

def _load_pickle(path):
    """Unpickle the object stored at `path`, closing the file afterwards."""
    with open(path, "rb") as file:
        return pickle.load(file)


def generate_stats(n_models=range(15), date=None) -> pd.DataFrame:
    """
    Generates stats for the machine learning models for a specific day

    The evaluated data spans the 24 hours ending at midnight (00:00) of
    `date`, i.e. the day before `date`.

    Parameters
    ----------
    n_models : iterable of int
        The model numbers for which to compute the stats. For each number
        ``n`` the models ``ar_n`` and ``dp_n`` are evaluated.
    date : datetime, optional
        Reference day; by default (None) the current day at call time.

    Returns
    -------
    pd.DataFrame
        DataFrame containing the generated stats: one row per model, plus
        a ``date`` column holding the start of the evaluated day.
    """
    # Resolve the default at call time. A `datetime.today()` default would
    # be frozen at import time, and with `import datetime` (the module) it
    # raises AttributeError anyway.
    if date is None:
        date = datetime.datetime.today()

    status_encoder = {
        "ar": _load_pickle(ENCODER_PATH.format(encoder="ar_cs")),
        "dp": _load_pickle(ENCODER_PATH.format(encoder="dp_cs")),
    }

    # Get midnight
    midnight = datetime.datetime.combine(date, datetime.time.min)
    day_start = midnight - datetime.timedelta(days=1)
    test = RtdRay.load_for_ml_model(
        min_date=day_start,
        max_date=midnight,
        long_distance_only=False,
        return_status=True,
    ).compute()

    has_ar = ~test["ar_delay"].isna()
    has_dp = ~test["dp_delay"].isna()

    ar_test = test.loc[has_ar, ["ar_delay", "ar_cs"]]
    dp_test = test.loc[has_dp, ["dp_delay", "dp_cs"]]

    # Columns removed from the model inputs: the labels / status columns
    # and a fixed set of obstacle priority columns.
    non_feature_columns = [
        "ar_delay",
        "dp_delay",
        "ar_cs",
        "dp_cs",
        "obstacles_priority_24",
        "obstacles_priority_37",
        "obstacles_priority_63",
        "obstacles_priority_65",
        "obstacles_priority_70",
        "obstacles_priority_80",
    ]
    ar_test_x = test.loc[has_ar].drop(columns=non_feature_columns)
    dp_test_x = test.loc[has_dp].drop(columns=non_feature_columns)

    # Release the full frame before the models are loaded.
    del test

    stats = []
    for model_number in n_models:
        model_name = f"ar_{model_number}"
        # Positive class: arrived with at most `model_number` delay and the
        # status is not the encoded "c" status.
        test_y = (ar_test["ar_delay"] <= model_number) & (
            ar_test["ar_cs"] != status_encoder["ar"]["c"]
        )
        model = _load_pickle(MODEL_PATH.format(model_name))
        stats.append(test_model(model, ar_test_x, test_y, model_name))

        model_name = f"dp_{model_number}"
        # Positive class: departed with at least `model_number` delay and
        # the status is not the encoded "c" status.
        test_y = (dp_test["dp_delay"] >= model_number) & (
            dp_test["dp_cs"] != status_encoder["dp"]["c"]
        )
        model = _load_pickle(MODEL_PATH.format(model_name))
        stats.append(test_model(model, dp_test_x, test_y, model_name))

    stats = pd.DataFrame(stats)
    stats["date"] = day_start

    return stats


@ttl_lru_cache(maxsize=1, seconds_to_live=60 * 60)
def load_stats(**kwargs) -> dict:
    """Loads stats from database or local

    Parameters
    ----------
    **kwargs :
        passed to `cached_table_fetch`. See its docstring for more info.

    Returns
    -------
    dict
        Loaded stats (the first row of the fetched table)
    """
    # The comma after `if_exists="append"` is required: without it the
    # expression parses as `'append' ** kwargs` and raises TypeError.
    stats = cached_table_fetch("ml_model_stats", if_exists="append", **kwargs)

    # NOTE(review): only the first row is returned, although the table is
    # appended to and holds one row per model - confirm this is intended.
    return stats.iloc[0].to_dict()

if __name__ == "__main__":
    # Run as a script: force regeneration of the stats via `generate_stats`
    # and print what `load_stats` returns.
    stats = load_stats(table_generator=generate_stats, generate=True)

    print(stats)
16 changes: 12 additions & 4 deletions database/cached_table_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def cached_table_fetch(
prefer_cache: Optional[bool] = False,
generate: Optional[bool] = False,
table_generator: Optional[Callable[[], pd.DataFrame]] = None,
if_exists: Optional[str] = 'replace',
push: Optional[bool] = True,
**kwargs
) -> pd.DataFrame:
Expand All @@ -32,6 +33,9 @@ def cached_table_fetch(
Whether to use table_generator to generate the DataFrame and not look for cache or database, by default False
table_generator : Callable[[], pd.DataFrame], optional
Callable that generates the data of table tablename, by default None
if_exists : {'fail', 'replace', 'append'}, default 'replace'
What to do if the table exists
See https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_sql.html
push : bool, optional
Whether to push data to the db after calling table_generator, by default True

Expand All @@ -51,7 +55,7 @@ def cached_table_fetch(
raise ValueError('Cannot generate if no table_generator was supplied')
df = table_generator()
if push:
cached_table_push(df, tablename)
cached_table_push(df, tablename, if_exists = if_exists)
return df

if prefer_cache:
Expand Down Expand Up @@ -126,7 +130,7 @@ def pd_to_psql(df, uri, table_name, schema_name=None, if_exists='fail', sep=',')
return True


def cached_table_push(df: pd.DataFrame, tablename: str, fast: bool = True, **kwargs):
def cached_table_push(df: pd.DataFrame, tablename: str, fast: Optional[bool] = True, if_exists: Optional[str] = 'replace', **kwargs):
"""
Save df to local cache file and replace the table in the database.

Expand All @@ -140,12 +144,16 @@ def cached_table_push(df: pd.DataFrame, tablename: str, fast: bool = True, **kwa
Whether to use a faster push method or not, by default True
True: use the fast method, which might not be as accurate
False: use the slow method, which is more accurate
if_exists : {'fail', 'replace', 'append'}, default 'replace'
What to do if the table exists
See https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_sql.html

"""
cache_path = CACHE_PATH + '/' + tablename + '.pkl'
df.to_pickle(cache_path)
# d6stack is way faster than pandas at inserting data to sql.
# It exports the dataframe to a csv and then inserts it to the database.
if fast:
pd_to_psql(df, DB_CONNECT_STRING, tablename, if_exists='replace')
pd_to_psql(df, DB_CONNECT_STRING, tablename, if_exists = if_exists)
else:
df.to_sql(tablename, DB_CONNECT_STRING, if_exists='replace', method='multi', chunksize=10_000, **kwargs)
df.to_sql(tablename, DB_CONNECT_STRING, if_exists = if_exists, method='multi', chunksize=10_000, **kwargs)
3 changes: 3 additions & 0 deletions helpers/RtdRay.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,9 @@ def load_for_ml_model(return_date_id=False, label_encode=True, return_times=Fals
Whether to label encode categorical columns, by default True
return_times : bool, optional
Whether to return planned and changed arrival and departure times, by default False
return_status : bool, optional
Whether to return the columns 'ar_cs', 'dp_cs', which contain the arrival/departure status of the train
for ex 'c' = canceled

Returns
-------
Expand Down
42 changes: 37 additions & 5 deletions k8s/jupyer-notebooks.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,35 @@
# PersistentVolume "jupyter-pv": 10G of node-local storage for Jupyter,
# served from a hostPath directory under the "jupyter" storage class.
apiVersion: v1
kind: PersistentVolume
metadata:
name: jupyter-pv
spec:
capacity:
storage: 10G
storageClassName: jupyter
# NOTE(review): ReadOnlyMany is declared here, yet the Deployment in this
# file mounts the matching claim at /home/jovyan - confirm that a read-only
# access mode is intended for that directory.
accessModes:
- ReadOnlyMany
# Directory on the node's own filesystem that backs this volume.
hostPath:
path: "/mnt/jupyter/"

---

# PersistentVolumeClaim "jupyter-pvc": requests 1G from the "jupyter"
# storage class and is pinned to the "jupyter-pv" volume via volumeName.
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: jupyter-pvc
labels:
type: local
spec:
# Same access mode as declared on jupyter-pv.
accessModes:
- ReadOnlyMany
storageClassName: jupyter
resources:
requests:
storage: 1G
# Bind explicitly to this volume by name.
volumeName: jupyter-pv

---

apiVersion: apps/v1
kind: Deployment
metadata:
Expand All @@ -24,8 +56,8 @@ spec:
volumeMounts:
- name: tz-berlin # set timezone to CEST
mountPath: /etc/localtime
- mountPath: /mnt/config
name: config-pvc-storage
- mountPath: /home/jovyan
name: jupyter-pvc-storage
- mountPath: /mnt/cache
name: cache-pvc-storage
dnsPolicy: ClusterFirst
Expand All @@ -34,12 +66,12 @@ spec:
securityContext: {}
terminationGracePeriodSeconds: 30
volumes:
- name: config-pvc-storage
persistentVolumeClaim:
claimName: config-pvc
- name: cache-pvc-storage
persistentVolumeClaim:
claimName: cache-pvc
- name: jupyter-pvc-storage
persistentVolumeClaim:
claimName: jupyter-pvc
- name: tz-berlin # set timezone to CEST
hostPath:
path: /usr/share/zoneinfo/Europe/Berlin
Expand Down
13 changes: 12 additions & 1 deletion update_butler/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@

print("--Done")

print("--ML Stats")

from data_analysis import ml_stats

# Regenerate the ML model stats. The generator in `ml_stats` is named
# `generate_stats`; `stats_generator` is the name used in `data_stats` and
# does not exist in `ml_stats`.
ml_stats.load_stats(
    table_generator=ml_stats.generate_stats,
    generate=True,
)

print("--Done")

print("--Per Station Data")

import datetime
Expand All @@ -64,6 +75,6 @@

print("Training ML Models...")

# TODO


print("Done")
20 changes: 18 additions & 2 deletions webserver/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from webserver import predictor, streckennetz, per_station_time
from webserver.db_logger import log_activity
from data_analysis import data_stats
from data_analysis import data_stats, ml_stats
from config import CACHE_PATH

bp = Blueprint("api", __name__, url_prefix="/api")
Expand Down Expand Up @@ -219,4 +219,20 @@ def obstacle_plot(date_range):
# even though os.path.isfile('cache/plot_cache/'+ plot_name + '.png') works
return send_file(
f"{CACHE_PATH}/plot_cache/{plot_name}.png", mimetype="image/png"
)
)

@bp.route("/ml_stats")
@log_activity
def stats_ml():
    """
    Retrieves stats for the machine learning models

    Returns
    -------
    dict
        The statistics for the ml models, as returned by
        `ml_stats.load_stats`
    """
    # The result must be returned: a view function that returns None makes
    # Flask raise an error instead of sending a response.
    # NOTE(review): returning a plain dict relies on Flask serialising it
    # as JSON - confirm against the installed Flask version.
    return ml_stats.load_stats()
Loading