-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
563 lines (462 loc) · 27.5 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score
from sklearn.model_selection import KFold, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler
import ydata_profiling as yp
import webbrowser
import time
class App:
    """Streamlit dashboard for exploring and classifying the Breast Cancer
    and Airline Passenger Satisfaction datasets."""

    def __init__(self):
        # Raw dataframe, loaded later by get_dataset().
        self.data = None
        # Sidebar selections; filled in by Init_Streamlit_Page().
        self.dataset_name = None
        self.classifier_name = None
        # Hyper-parameter values chosen in the sidebar widgets.
        self.params = dict()
        # Fitted/unfitted sklearn estimator, set by get_classifier().
        self.clf = None
        # Feature matrix and label vector, set by the preprocessing methods.
        self.X, self.y = None, None
        # Headline tuned hyper-parameter, set by parameter_tuning().
        self.best_param = None
        # Build the sidebar and page chrome immediately on construction.
        self.Init_Streamlit_Page()
    def run(self):
        """
        Run the main application flow.

        Steps:
        1. Load the selected dataset (``get_dataset``).
        2. Preprocess and visualize it with the dataset-specific routines.
        3. Grid-search classifier hyper-parameters (``parameter_tuning``).
        4. Render the sidebar parameter widgets (``add_parameter_ui``).
        5. Train and evaluate the model (``models``).
        """
        self.get_dataset()  # Get the dataset
        # Dataset-specific preprocessing + visualization. Note "Passanger"
        # spelling matches the sidebar option string exactly.
        if self.dataset_name == "Breast Cancer":
            self.data_preprocess_breast_cancer()
            self.data_viz_breast_cancer()
        elif self.dataset_name == "Airline Passanger Satisfaction":
            self.data_preprocess_airline_satisfaction()
            self.data_viz_airline_satisfaction()
        self.parameter_tuning(self.classifier_name, self.dataset_name)  # Tune the parameters
        self.add_parameter_ui()  # Add the parameters to the UI
        self.models()  # Train the models
    def Init_Streamlit_Page(self):
        """
        Initialize the Streamlit page.

        Builds the sidebar (dataset and classifier dropdowns), sets the page
        title and shows an expandable description of the selected dataset.
        Called once from ``__init__``; stores the selections in
        ``self.dataset_name`` and ``self.classifier_name``.
        """
        # Dataset selection dropdown. The "Passanger" spelling is load-bearing:
        # every dataset check in this class compares against this exact string.
        self.dataset_name = st.sidebar.selectbox(
            "Select Dataset",
            ("Breast Cancer", "Airline Passanger Satisfaction"))
        # Set title of the page
        st.title(f"{self.dataset_name} Analysis")
        # Display dataset information
        if self.dataset_name == "Breast Cancer":
            # Display breast cancer dataset information
            st.expander("Dataset information"
                        ).write("""
The "Diagnostic Wisconsin Breast Cancer Database" is a publicly available data set from the UCI machine learning repository.
The dataset gives information about tumor features, that are computed from a digitized image of a fine needle aspirate (FNA) of a breast mass.
For each observation there are 10 features, which describe tumor size, density, texture, symmetry, and other characteristics of the cell nuclei present in the image.
The mean, standard error and "worst" mean (mean of the three largest values) of these features were computed for each image, resulting in 30 features.
The categorical target feature indicates the type of the tumor.
The area on the aspirate slides to be analyzed was visually selected for minimal nuclear overlap.
The image for digital analysis was generated by a JVC TK-1070U color video camera mounted above an Olympus microscope and
the image was projected into the camera with a 63 x objective and a 2.5 x ocular.
The image was captured as a 512 x 480 resolution, 8 bit/pixel (Black and White) file.
The aspirated material was expressed onto a silane-coated glass slide, which was placed under a similar slide.
A typical image contains approximately from 10 to 40 nuclei. After computing 10 features for each nucleus, the mean, standart error and extreme value was computed, as it mentioned above.
These features are modeled such that higher values are typically associated with malignancy.
""")
        elif self.dataset_name == "Airline Passanger Satisfaction":
            # Display airline passenger satisfaction dataset information.
            # NOTE(review): the description text is empty — TODO fill in.
            st.expander("Dataset information"
                        ).write("""
""")
        # Display application description
        st.markdown("*This application is a Streamlit dashboard that can be used "
                    "to analyze and predict the satisfaction of an airline passanger.*")
        # Set title of the page
        st.write(f"## {self.dataset_name} Dataset")
        # Classifier selection dropdown; the literal names are matched in
        # parameter_tuning(), add_parameter_ui() and get_classifier().
        self.classifier_name = st.sidebar.selectbox(
            "Select Classifier",
            ("KNN", "SVM", "Naive Bayes(GaussianNB)"))
def get_dataset(self):
"""
Downloads the dataset and assigns it to the 'data' attribute of the class.
Also sets the task to be performed on the dataset based on the dataset name.
Parameters:
None
Returns:
None
"""
try:
# Check the dataset name and assign the appropriate dataset
if self.dataset_name == "Breast Cancer":
self.data = pd.read_csv("data.csv") # Read the breast cancer dataset
task = "To predict whether the cancer is benign or malignant"
elif self.dataset_name == "Airline Passanger Satisfaction":
self.data = pd.read_csv("airline_satisfaction.csv", nrows=50000) # Read the airline passenger satisfaction dataset
task = "To predict whether the customer is satisfied or unsatisfied"
# Display status messages while downloading the dataset
with st.status("Downloading data...", expanded=True) as status:
st.write("Searching for data...")
st.write("Found URL.")
st.write("Downloading data...")
status.update(label=f"{self.dataset_name} Dataset download complete!", state="complete", expanded=False)
# Display the task to be performed
st.write(f"##### ***Task : {task}*** ")
self.data_intro()
except Exception as e:
# Display an error message if an exception occurs
st.write(f"An Error Occurred!: {e}")
@st.cache_data
def generate_report(_self):
report = yp.ProfileReport(_self.data, explorative=True)
report.to_file(f"{_self.dataset_name} Report.html")
    def data_intro(self):
        """
        Show an introductory overview of the loaded dataset.

        Renders: a button that opens the pre-generated profiling report, the
        first 10 rows, the shape, an attribute-information section, the
        target-value counts (table + bar chart), missing-value summary and
        the column dtypes. Does nothing for an unknown dataset name.
        """
        # Static per-dataset description text and target-column names.
        # (The "target_chart_title" entries are currently unused below.)
        dataset_info = {
            "Breast Cancer": {
                "shape_text": "This dataset consists of 569 observations, each described by 33 features",
                "attribute_info": """
Attribute Information:
- ID number
- Diagnosis (M = malignant, B = benign)
- Ten real-valued features are computed for each cell nucleus:
- radius (mean of distances from center to points on the perimeter)
- texture (standard deviation of gray-scale values)
- perimeter
- area
- smoothness (local variation in radius lengths)
- compactness (perimeter^2 / area - 1.0)
- concavity (severity of concave portions of the contour)
- concave points (number of concave portions of the contour)
- symmetry
- fractal dimension ("coastline approximation" - 1)
""",
                "target_column": "diagnosis",
                "target_chart_title": "Diagnosis Counts"
            },
            "Airline Passanger Satisfaction": {
                "shape_text": "This dataset consists of 129880 observations, each described by 25 features",
                "attribute_info": """
Attribute Information:
- Column Name / Description Type
- **Gender:** Gender of the passengers (Female, Male)
- **Customer Type:** The customer type (Loyal customer, disloyal customer)
- **Age:** The actual age of the passengers
- **Type of Travel:** Purpose of the flight of the passengers (Personal Travel, Business Travel)
- **Class:** Travel class in the plane of the passengers (Business, Eco, Eco Plus)
- **Flight distance:** The flight distance of this journey
- **Inflight wifi service:** Satisfaction level of the inflight wifi service (0:Not Applicable;1-5)
- **Departure/Arrival time convenient:** Satisfaction level of Departure/Arrival time convenient
- **Ease of Online booking:** Satisfaction level of online booking
- **Gate location:** Satisfaction level of Gate location
- **Food and drink:** Satisfaction level of Food and drink
- **Online boarding:** Satisfaction level of online boarding
- **Seat comfort:** Satisfaction level of Seat comfort
- **Inflight entertainment:** Satisfaction level of inflight entertainment
- **On-board service:** Satisfaction level of On-board service
- **Leg room service:** Satisfaction level of Leg room service
- **Baggage handling:** Satisfaction level of baggage handling
- **Check-in service:** Satisfaction level of Check-in service
- **Inflight service:** Satisfaction level of inflight service s
- **Cleanliness:** Satisfaction level of Cleanliness
- **Departure Delay in Minutes:** Minutes delayed when departure
- **Arrival Delay in Minutes:** Minutes delayed when Arrival
- **Satisfaction:** Airline satisfaction level(Satisfaction, neutral or dissatisfaction)
""",
                "target_column": "satisfaction",
                "target_chart_title": "charges values on a plot"
            }
        }
        info = dataset_info.get(self.dataset_name)
        if info is None:
            # Unknown dataset name: nothing to show.
            return
        report_button = st.button("Generate Dataset Variable Report")
        if report_button:
            #self.generate_report() for the first time it takes time so report generated before for better performance
            with st.status("Report generating...", expanded=True) as status:
                st.write("Summarize Dataset...")
                st.write("Generate report structure...")
                st.write("Render HTML...")
                status.update(label=f"{self.dataset_name} report generating complete!", state="complete", expanded=False)
            # NOTE(review): opens a relative path — assumes the report file
            # already exists in the working directory; confirm.
            webbrowser.open(f"{self.dataset_name} Report.html")
        st.write("Dataframe first 10 rows: ", self.data.head(10))
        st.write("Shape of Dataset: ", self.data.shape)
        st.write(f" *{info['shape_text']}*")
        st.markdown(info['attribute_info'])
        # Target distribution as a table and as a bar chart.
        st.write("Target Value: ", self.data[info['target_column']].value_counts())
        st.bar_chart(self.data[info['target_column']].value_counts())
        # Missing-value summary: list only columns that actually have NaNs.
        missing_values = self.data.isnull().sum()
        columns_with_missing_values = missing_values[missing_values > 0]
        if not columns_with_missing_values.empty:
            st.write("Columns with Missing Values:", columns_with_missing_values)
        else:
            st.write("No missing values in the dataset")
        st.write("Features dtypes:" , self.data.dtypes)
def data_preprocess_breast_cancer(self):
"""
Data preprocessing for the breast cancer dataset.
This function drops the 'id' and 'Unnamed: 32' columns, encodes the
'diagnosis' column as 1 for malignant and 0 for benign, normalizes the
data, and stores the processed features and labels in self.X and self.y.
"""
# Drop unnecessary columns
st.subheader("Data Preprocessing")
st.write(f"***Drop Unnecessary Columns:*** id, Unnamed: 32")
self.data.drop(["id", "Unnamed: 32"], axis=1, inplace=True)
st.write(self.data.head(5),
"*Missing values handled with dropping Unnamed: 32 column*")
# Encode label
st.write(f"***Encoding Label:*** diagnosis - Malignant - 1, Benign - 0")
self.data['diagnosis'] = self.data['diagnosis'].map({'M': 1, 'B':0})
st.write(self.data["diagnosis"].tail(10))
# Normalize data
self.X = self.data.drop(["diagnosis"], axis=1)
self.y = self.data["diagnosis"].values
self.target_value_corelation_plot(self.dataset_name)
self.X = (self.X - self.X.min()) / (self.X.max() - self.X.min())
# Print normalized data
st.write("***Normalized Data***: ", self.X.head(10))
def data_preprocess_airline_satisfaction(self):
st.subheader("Data Preprocessing")
st.write(f"***Fill Missing Values: Arrival Delay in Minutes with 0***")
# Fill NaN values with 0 in the 'Arrival Delay in Minutes' column
self.data['Arrival Delay in Minutes'].fillna(0, inplace=True)
# Calculate the sum of all NaN values in the column
sum_of_na_values = self.data['Arrival Delay in Minutes'].isna().sum()
# Display the sum of all NaN values in one line
st.write("***Sum of all na values:***", sum_of_na_values)
st.write("***Drop Unnecessary Columns:*** id, Unnamed: 0")
self.data.drop(["id", "Unnamed: 0"], axis=1, inplace=True)
st.write(self.data.shape, " ***Data Shape***")
categorical_variable_column_name=['Gender','Customer Type','Type of Travel','Class','satisfaction']
numerical_variable_column_name=set(self.data.columns)-set(categorical_variable_column_name)
def numerical_variable_description(data):
description=data.describe().T
description=description.loc[list(numerical_variable_column_name),['mean','std','min','25%','50%','75%','max']]
st.write(description)
return description
numerical_variable_description(self.data)
def dummy_variable_encoding(data):
gender_encoding={'Male':1,'Female':0}
customer_type_encoding={'Loyal Customer':1,'disloyal Customer':0}
type_travel_encoding={'Business travel':1,'Personal Travel':0}
satisfaction_encoding={'satisfied':1,'neutral or dissatisfied':0}
data['Gender']=data['Gender'].replace(gender_encoding)
data['Customer Type']=data['Customer Type'].replace(customer_type_encoding)
data['Type of Travel']=data['Type of Travel'].replace(type_travel_encoding)
data['satisfaction']=data['satisfaction'].replace(satisfaction_encoding)
eco_plus_encoding={'Eco Plus':1,'Business':0,'Eco':0}
business_encoding={'Business':1,'Eco Plus':0,'Eco':0}
data['eco_plus']=data['Class'].replace(eco_plus_encoding)
data['business']=data['Class'].replace(business_encoding)
data.drop(labels=['Class'],axis=1,inplace=True)
st.write(f"***Before Encoding:***")
st.write(self.data.head(5))
st.write(f"***Dummy Variable Encoding:***")
st.write("Columns that encoded: ", categorical_variable_column_name)
dummy_variable_encoding(self.data)
st.write(self.data.head(5))
def normalization(data,Scaler):
scaler=Scaler()
data[list(numerical_variable_column_name)]=scaler.fit_transform(data[list(numerical_variable_column_name)])
normalization(self.data,MinMaxScaler)
self.target_value_corelation_plot(self.dataset_name)
self.X = self.data.drop(["satisfaction"], axis=1)
self.y = self.data["satisfaction"].values
    @st.cache_data()
    def target_value_corelation_plot(_self, dataset_name):
        """
        Show the feature-correlation matrix and heatmaps in two tabs.

        ``dataset_name`` is passed explicitly so st.cache_data keys the cache
        on it; the leading underscore on ``_self`` excludes the instance from
        hashing. Assumes ``_self.data`` is fully numeric when called — both
        preprocessing methods call it after label/dummy encoding (``.corr()``
        would fail on object columns otherwise).

        NOTE(review): ``target_value`` is only bound for the two known
        dataset names; any other name would raise NameError below — confirm
        the selectbox is the only caller path.
        """
        tab1, tab2 = st.tabs(["🗃 Matrix", "📈 Plot"])
        with tab1:
            tab1.subheader("Correlation Matrix ")
            # Pick the target column for the current dataset.
            if _self.dataset_name == "Breast Cancer":
                target_value = "diagnosis"
            elif _self.dataset_name == "Airline Passanger Satisfaction":
                target_value = "satisfaction"
            # create correlation matrix
            correlation_matrix = _self.data.corr()
            st.write("Correlation Matrix:")
            st.write(correlation_matrix)
            st.write("##### Correlation Matrix Heatmap with Target Values:")
            st.table(correlation_matrix[target_value])
        with tab2:
            tab2.subheader("Correlation Plot ")
            # Display the correlation matrix heatmap
            st.write("##### Correlation Matrix Heatmap All Features:")
            plt.figure(figsize=(32, 16))
            # Store heatmap object in a variable to easily access it when you want to include more features (such as title).
            # Set the range of values to be displayed on the colormap from -1 to 1, and set the annotation to True to display the correlation values on the heatmap.
            heatmap = sns.heatmap(correlation_matrix, vmin=-1, vmax=1, annot=True, cmap='BrBG')
            # Give a title to the heatmap. Pad defines the distance of the title from the top of the heatmap.
            heatmap.set_title('Correlation Heatmap', fontdict={'fontsize':12}, pad=12);
            st.pyplot(plt)
            # Display the correlation matrix heatmap with target values
            diagnosis_correlation = correlation_matrix[target_value]
            st.write("##### Correlation Matrix Heatmap with Target Values:")
            plt.figure(figsize=(10, 8))
            sns.heatmap(diagnosis_correlation.to_frame(), annot=True, cmap='BrBG', fmt=".2f")
            plt.title('Correlation Matrix Heatmap with Target Values')
            plt.xticks(rotation=45)
            plt.yticks(rotation=0)
            st.pyplot(plt)
@st.cache_data
def data_viz_breast_cancer(_self):
# Display scatter plot
malignant_data = _self.data[_self.data['diagnosis'] == 1]
benign_data = _self.data[_self.data['diagnosis'] == 0]
plt.figure(figsize=(8, 6))
st.write("***Scatter Plot of Radius Mean vs Texture Mean:***")
# Make scatter plot transparent with less opacity
sns.scatterplot(x='radius_mean', y='texture_mean', data=malignant_data, label='Malignant', color='red', alpha=0.5)
sns.scatterplot(x='radius_mean', y='texture_mean', data=benign_data, label='Benign', color='blue', alpha=0.5)
plt.title('Scatter Plot of Radius Mean vs Texture Mean')
plt.xlabel('Radius Mean')
plt.ylabel('Texture Mean')
plt.legend()
st.pyplot(plt)
#postive correlation
st.write("***Scatter Plot of positive correlation:***")
fig,ax=plt.subplots(2,2,figsize=(20,25))
sns.scatterplot(x='perimeter_mean',y='radius_worst',data=_self.data,hue='diagnosis',ax=ax[0][0])
sns.scatterplot(x='area_mean',y='radius_worst',data=_self.data,hue='diagnosis',ax=ax[1][0])
sns.scatterplot(x='texture_mean',y='texture_worst',data=_self.data,hue='diagnosis',ax=ax[0][1])
sns.scatterplot(x='area_worst',y='radius_worst',data=_self.data,hue='diagnosis',ax=ax[1][1])
st.pyplot(fig)
#negative correlation
st.write("***Scatter Plot of negative correlation:***")
fig,ax=plt.subplots(2,2,figsize=(20,25))
sns.scatterplot(x='area_mean',y='fractal_dimension_mean',data=_self.data,hue='diagnosis',ax=ax[0][0])
sns.scatterplot(x='radius_mean',y='smoothness_se',data=_self.data,hue='diagnosis',ax=ax[1][0])
sns.scatterplot(x='smoothness_se',y='perimeter_mean',data=_self.data,hue='diagnosis',ax=ax[0][1])
sns.scatterplot(x='area_mean',y='smoothness_se',data=_self.data,hue='diagnosis',ax=ax[1][1])
st.pyplot(fig)
@st.cache_data
def data_viz_airline_satisfaction(_self):
st.subheader("Data Visualization")
st.write(f"***Scatter Plot of Age vs Flight Distance:***")
plt.figure(figsize=(8, 6))
sns.scatterplot(x='Age', y='Flight Distance', data=_self.data)
st.pyplot(plt)
@st.cache_data
def parameter_tuning(_self, classifier_name, dataset_name):
"""
Function for parameter tuning.
This function performs hyperparameter tuning for different classifiers.
Args:
_self (object): Instance of the class.
classifier_name (str): Name of the classifier.
dataset_name (str): Name of the dataset.
"""
# Start timer for parameter tuning
start = time.time()
# Subheader for model and classifier name
st.subheader("**Model**")
st.write(f"***Classifiers*** = {_self.classifier_name}")
# Get sample size half of data from self.X and self.y
if dataset_name == "Airline Passanger Satisfaction":
sample_x, sample_y = _self.X[:int(len(_self.X)/20)], _self.y[:int(len(_self.y)/20)]
else:
sample_x, sample_y = _self.X[:int(len(_self.X))], _self.y[:int(len(_self.y))]
# Hyperparameter tuning
X_train, X_test, y_train, y_test = train_test_split(sample_x, sample_y, test_size=0.2, random_state=42)
# KNN Classifier parameter tuning
if classifier_name == "KNN":
kf = KFold(n_splits=2, shuffle=True, random_state=42)
parameter = {'n_neighbors': np.arange(1, 10, 1)}
knn = KNeighborsClassifier()
knn_cv = GridSearchCV(knn, param_grid=parameter, cv=kf, verbose=1, scoring='accuracy')
knn_cv.fit(X_train, y_train)
st.write("Best parameter values: ", knn_cv.best_params_)
_self.best_param = knn_cv.best_params_["n_neighbors"]
# SVM Classifier parameter tuning
elif classifier_name == "SVM":
param_grid = {'C': np.arange(1, 10, 1), 'gamma': [1, 0.1, 0.01], 'kernel': ['rbf']}
grid = GridSearchCV(SVC(), param_grid, refit=True, verbose=3, scoring='accuracy')
grid.fit(X_train, y_train)
st.write("Best parameter values: ", grid.best_params_)
_self.best_param = grid.best_params_["C"]
# Naive Bayes(GaussianNB) Classifier parameter tuning
elif classifier_name == "Naive Bayes(GaussianNB)":
params_NB = {'var_smoothing': np.arange(0.01, 2.0)}
gs_NB = GridSearchCV(estimator=GaussianNB(),
param_grid=params_NB,
cv=KFold(n_splits=5),
verbose=1,
scoring='accuracy')
gs_NB.fit(X_train, y_train)
st.write("Best parameter values: ", gs_NB.best_params_)
_self.best_param = gs_NB.best_params_["var_smoothing"]
# End timer for parameter tuning
end = time.time()
# Display time taken for parameter tuning
st.write("Time taken for parameter tuning: ", end - start)
def add_parameter_ui(self):
if self.classifier_name == "SVM":
C = st.sidebar.slider("C", 1, 10,step=1, value= self.best_param)
self.params["C"] = C
gamma = st.sidebar.select_slider("gamma", options =[1,0.1,0.01,0.001], value= self.best_param)
self.params["gamma"] = gamma
kernel = st.sidebar.radio("kernel", ("rbf", "poly", "linear"))
self.params["kernel"] = kernel
elif self.classifier_name == "KNN":
n_neighbors = st.sidebar.slider("n_neighbors", 1, 10, value= self.best_param)
self.params["n_neighbors"] = n_neighbors
elif self.classifier_name == "Naive Bayes(GaussianNB)":
var_smoothing = st.sidebar.slider("var_smoothing", 0.01, 2.0, value= self.best_param)
self.params["var_smoothing"] = var_smoothing
def models(self):
"""
This function trains a model, calculates various evaluation metrics, and generates a confusion matrix.
"""
# Time starter for model function
start = time.time()
# Get the classifier and split the data
self.get_classifier()
X_train, X_test, y_train, y_test = train_test_split(self.X, self.y, test_size=0.2, random_state=42)
# Write the parameters that the user has chosen
st.write("parameters you choose: ", self.params)
# Fit the model and make predictions
self.clf.fit(X_train, y_train)
y_pred = self.clf.predict(X_test)
# Calculate evaluation metrics
acc = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
# Time ender for model function
end = time.time()
# Write evaluation metrics
st.write(f"Accuracy = {acc}")
st.write(f"F1 Score = {f1}")
st.write(f"Precision = {precision}")
st.write(f"Recall = {recall}")
# Write model time taken
st.write("***Model Time Taken:***", end - start)
# Define the categories for confusion matrix
cm = confusion_matrix(y_test, y_pred)
st.write("***Confusion Matrix:***")
st.write(cm)
# Plot confusion matrix
f, ax =plt.subplots(figsize = (5,5))
sns.heatmap(cm,annot = True, linewidths= 0.5, linecolor="red", fmt=".0f", ax=ax)
plt.xlabel("y_pred")
plt.ylabel("y_true")
st.pyplot(f)
def get_classifier(self):
if self.classifier_name == "KNN":
self.clf = KNeighborsClassifier(n_neighbors=self.params["n_neighbors"])
elif self.classifier_name == "SVM":
self.clf = SVC(C=self.params["C"], kernel=self.params["kernel"],gamma=self.params["gamma"])
elif self.classifier_name == "Naive Bayes(GaussianNB)":
self.clf = GaussianNB(var_smoothing=self.params["var_smoothing"])