-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbased_trainer.py
80 lines (63 loc) · 2.46 KB
/
based_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import xgboost as xgb
from joblib import dump
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

from data_format import generate_data
# Global Seaborn defaults for figures drawn from here on: the 'whitegrid'
# style (white background with grid lines) and the 'paper' plotting context
# with all font sizes scaled up 2.2x.
sns.set_style(style="whitegrid")
sns.set_context(context="paper", font_scale=2.2)
# Build the model input table, then load predictors and target.
# NOTE(review): generate_data() comes from the project's data_format module;
# presumably it (re)writes data/based_input_data.csv — confirm.
generate_data()
df = pd.read_csv('data/based_input_data.csv')
# Select the feature columns explicitly. DataFrame.filter() silently skips
# labels that are absent, which would train on fewer features without any
# warning; plain column indexing raises KeyError on a missing column instead.
X = df[['width', 'slope', 'discharge']]
y = df['depth']
# Hold out 20% of rows for validation; the fixed random_state makes the split
# (and therefore the reported test metrics) reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=42, shuffle=True)
# XGBoost hyper-parameters shared by every fit below (each bootstrap replicate
# and the final full-data model). 'eval_metric' is given to the constructor:
# passing it to fit() was deprecated in XGBoost 1.6 and removed in 2.0.
params = {"objective": "reg:squarederror",
          'max_depth': 12,
          'subsample': 0.8,
          'learning_rate': 0.09,
          'n_estimators': 75,
          'reg_lambda': 1.2,
          'eval_metric': 'mae'}
xg_reg = xgb.XGBRegressor(**params)
# The estimator is fitted later (once per bootstrap replicate, then on all
# data); an up-front fit on X_train here would be overwritten unused.
# Bootstrap the training set: refit the model on n_iterations resamples drawn
# with replacement, predicting the test set each time. Row i of `predictions`
# holds replicate i's predictions for every test sample.
n_iterations = 100
# Seeded generator so the resamples (and every metric derived from them) are
# reproducible, like the fixed random_state of the train/test split. Each
# resample() call advances the generator, so the replicates still differ.
bootstrap_rng = np.random.RandomState(42)
predictions = np.zeros((n_iterations, len(y_test)))
for i in range(n_iterations):
    X_resample, y_resample = resample(X_train, y_train, random_state=bootstrap_rng)
    # Refits the shared estimator in place; after the loop xg_reg holds the
    # last replicate's model.
    xg_reg.fit(X_resample, y_resample)
    predictions[i] = xg_reg.predict(X_test)
mean_predictions = predictions.mean(axis=0)
std_predictions = predictions.std(axis=0)
# Normal-approximation 95% band (mean +/- 1.96 standard deviations) of the
# bootstrap predictions for each test sample: [lower, upper].
confidence_interval = [mean_predictions - 1.96 * std_predictions, mean_predictions + 1.96 * std_predictions]
# Model report. Every metric compares the bootstrap-mean predictions with the
# held-out test targets, so these are test-set (validation) scores, not
# training scores.
mae = metrics.mean_absolute_error(y_test, mean_predictions)
rmse = np.sqrt(metrics.mean_squared_error(y_test, mean_predictions))
r2 = metrics.r2_score(y_test, mean_predictions)
print("\nModel Report")
print("MAE (Test): %f" % mae)
print("RMSE (Test): %f" % rmse)
print("R2 (Test): %f" % r2)
# Validation plot: measured vs. bootstrap-mean predicted depth on log-log
# axes, with a dashed 1:1 line marking perfect agreement.
#
# Style, context and font scale are set in ONE call. sns.set()/set_theme()
# reset every theme parameter they are not given, so a separate
# set_context('talk') before them is discarded and the style falls back to
# 'darkgrid'. 'notebook' at 2.2x reproduces the sizing of
# sns.set(font_scale=2.2), and 'whitegrid' matches the style requested at the
# top of the script. Applied before the figure is created so it sees it too.
sns.set_theme(context='notebook', style='whitegrid', font_scale=2.2)
plt.figure(figsize=(10,10))
# Endpoints of the 1:1 line, spanning both measured and predicted values.
p1 = max(mean_predictions.max(), y_test.max())
p2 = min(mean_predictions.min(), y_test.min())
sns.scatterplot(x=y_test, y=mean_predictions, color= '#FFCCBC', edgecolor='k', s=100)
plt.plot([p1, p2], [p1, p2], 'k--', label='1:1 line', lw=3.5)
plt.yscale('log')
plt.xscale('log')
plt.xlabel('Measured Channel Depth (m)')
plt.ylabel('Predicted Channel Depth (m)')
plt.title('BASED Validation | n = {}'.format(len(y_test)))
plt.axis('equal')
plt.legend()
plt.tight_layout()
# savefig() does not create missing directories; make sure img/ exists.
os.makedirs("img", exist_ok=True)
plt.savefig("img/BASED_validation.png", dpi=250)
# Fit the final model on every row (training and test splits together) and
# persist it with joblib. XGBRegressor.fit() returns the estimator itself,
# so the fitted xg_reg is what gets written to based_model.joblib.
dump(xg_reg.fit(X, y), 'based_model.joblib')