-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
62 lines (47 loc) · 2.41 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import pandas as pd
import matplotlib.pyplot as plt
import acquire
import prepare
import math
from sklearn import metrics
# Fetch the raw data with acquire.acquire_data() and pass it straight through
# prepare.prepare_data(); only the prepared frame is kept as the module-level df.
df = prepare.prepare_data(acquire.acquire_data())
def split_store_data(df, train_prop=.7):
    """Split ``df`` by row position into a leading train set and trailing test set.

    Rows are not shuffled: the first ``int(len(df) * train_prop)`` rows become
    the train set and the remaining rows the test set. (Presumably ``df`` is
    time-ordered so this is a chronological split -- TODO confirm with the
    prepare step.)

    Args:
        df: DataFrame (or Series) to split.
        train_prop: Fraction of rows, in [0, 1], assigned to the train set.

    Returns:
        ``(train, test)`` tuple of positional slices of ``df``.

    Raises:
        ValueError: If ``train_prop`` is outside [0, 1].
    """
    if not 0 <= train_prop <= 1:
        raise ValueError(f'train_prop must be between 0 and 1, got {train_prop!r}')
    train_size = int(len(df) * train_prop)
    # .iloc makes the slicing explicitly positional regardless of index type.
    return df.iloc[:train_size], df.iloc[train_size:]
# Column(s) the models below are evaluated on.
target_vars = ['steps']

# Module-level split; evaluate/plot_and_eval/append_eval_df capture these as
# their default train/test arguments when they are defined.
train, test = split_store_data(df)

# NOTE(review): yhat is seeded from the test actuals themselves, so evaluate()
# reports an error of 0 until yhat is replaced with real model predictions.
yhat = pd.DataFrame(test[target_vars])
def evaluate(target_var, train = train, test = test, output=True, predictions=None):
    """Compute MSE and RMSE of predictions for one target column.

    The error is computed with ``sklearn.metrics.mean_squared_error`` between
    ``test[target_var]`` and ``predictions[target_var]``; RMSE is its square root.

    Args:
        target_var: Column label present in both ``test`` and the predictions.
        train: Not used by the computation; accepted so callers can forward the
            same train/test arguments they pass to the other helpers.
        test: DataFrame of actual values. Defaults to the module-level ``test``
            captured when this function was defined.
        output: If true, also print the MSE and RMSE.
        predictions: DataFrame of predicted values. When ``None`` (the default)
            the module-level ``yhat`` is looked up at call time, so reassigning
            ``yhat`` affects later calls.

    Returns:
        ``(mse, rmse)`` tuple, whether or not ``output`` is set.
    """
    # Resolve the global at call time (not definition time) so callers can
    # swap in new predictions by rebinding yhat.
    preds = yhat if predictions is None else predictions
    mse = metrics.mean_squared_error(test[target_var], preds[target_var])
    rmse = math.sqrt(mse)
    if output:
        print('MSE: {}'.format(mse))
        print('RMSE: {}'.format(rmse))
    return mse, rmse
def plot_and_eval(target_vars, train = train, test = test, metric_fmt = '{:.2f}', linewidth = 4):
    """Plot train, test and predicted series on one figure and print their errors.

    Args:
        target_vars: A column label or a list of column labels; anything that is
            not a list is treated as a single label.
        train: DataFrame plotted as the 'Train' series (module-level ``train``
            captured at definition time by default).
        test: DataFrame plotted as the 'Test' series and used as the actuals
            for ``evaluate``.
        metric_fmt: ``str.format`` pattern applied to each printed metric.
        linewidth: Line width of the prediction curves.

    Side effects:
        Creates a 16x8 matplotlib figure, plots ``yhat[var]`` (the module-level
        ``yhat``, read at call time) for each target, prints one
        ``'<var> -- MSE: ... RMSE: ...'`` line per target, shows a legend and
        calls ``plt.show()``.
    """
    if not isinstance(target_vars, list):
        target_vars = [target_vars]
    plt.figure(figsize=(16, 8))
    plt.plot(train[target_vars], label='Train', linewidth=1)
    plt.plot(test[target_vars], label='Test', linewidth=1)
    for var in target_vars:
        mse, rmse = evaluate(target_var=var, train=train, test=test, output=False)
        plt.plot(yhat[var], label=f'Predicted {var}', linewidth=linewidth)
        # Format each metric directly so braces in `var` cannot corrupt the
        # format string.
        print(f'{var} -- MSE: {metric_fmt.format(mse)} RMSE: {metric_fmt.format(rmse)}')
    # The series above are labelled; without this call the labels never appear.
    plt.legend()
    plt.show()
# Empty results table with one row per (model_type, target_var, metric) value;
# append_eval_df() returns new frames built on top of it.
eval_df = pd.DataFrame(columns='model_type target_var metric value'.split())
def append_eval_df(model_type, target_vars, train = train, test = test):
    """Return ``eval_df`` extended with MSE and RMSE rows for a model.

    For each target column ``evaluate`` is called once and two rows are built:
    ``[model_type, var, 'mse', mse]`` and ``[model_type, var, 'rmse', rmse]``.

    The module-level ``eval_df`` is read but not modified; keep the result by
    reassigning it (e.g. ``eval_df = append_eval_df(...)``).

    Args:
        model_type: Label stored in the ``model_type`` column.
        target_vars: Iterable of column labels to score. A single string is
            treated as one column rather than iterated character by character.
        train: Forwarded to ``evaluate`` (module-level ``train`` by default).
        test: Forwarded to ``evaluate`` (module-level ``test`` by default).

    Returns:
        A new DataFrame containing the rows of ``eval_df`` followed by the new
        metric rows (a copy of ``eval_df`` if ``target_vars`` is empty).
    """
    # A bare string would otherwise be iterated one character at a time.
    if isinstance(target_vars, str):
        target_vars = [target_vars]
    rows = []
    for var in target_vars:
        # One evaluate() call per column yields both metrics.
        mse, rmse = evaluate(target_var=var, train=train, test=test, output=False)
        rows.append([model_type, var, 'mse', mse])
        rows.append([model_type, var, 'rmse', rmse])
    new_rows = pd.DataFrame(rows, columns=['model_type', 'target_var', 'metric', 'value'])
    # pd.concat([]) raises, so an empty target list just returns a copy.
    if new_rows.empty:
        return eval_df.copy()
    # Concatenating with an empty frame triggers pandas' dtype-inference
    # deprecation; the result would be the new rows anyway.
    if eval_df.empty:
        return new_rows
    # DataFrame.append was removed in pandas 2.0; pd.concat replaces it.
    return pd.concat([eval_df, new_rows], ignore_index=True)