Skip to content

Commit 9810858

Browse files
committed
Code reformat using black
1 parent 70277e7 commit 9810858

12 files changed

+395
-179
lines changed

examples/decision_boundary.py

+56-29
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,22 @@
2323
from sklearn.naive_bayes import GaussianNB
2424
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
2525

26-
h = .02 # step size in the mesh
27-
28-
names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", "Gaussian Process", "Neural Net", "Naive Bayes", "QDA",
29-
"Decision Tree", "Random Forest", "AdaBoost", "SCM-Conjunction", "SCM-Disjunction"]
26+
h = 0.02 # step size in the mesh
27+
28+
names = [
29+
"Nearest Neighbors",
30+
"Linear SVM",
31+
"RBF SVM",
32+
"Gaussian Process",
33+
"Neural Net",
34+
"Naive Bayes",
35+
"QDA",
36+
"Decision Tree",
37+
"Random Forest",
38+
"AdaBoost",
39+
"SCM-Conjunction",
40+
"SCM-Disjunction",
41+
]
3042

3143
classifiers = [
3244
KNeighborsClassifier(3),
@@ -40,17 +52,21 @@
4052
RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
4153
AdaBoostClassifier(),
4254
SetCoveringMachineClassifier(max_rules=4, model_type="conjunction", p=2.0),
43-
SetCoveringMachineClassifier(max_rules=4, model_type="disjunction", p=1.0)]
55+
SetCoveringMachineClassifier(max_rules=4, model_type="disjunction", p=1.0),
56+
]
4457

45-
X, y = make_classification(n_features=2, n_redundant=0, n_informative=2,
46-
random_state=1, n_clusters_per_class=1)
58+
X, y = make_classification(
59+
n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1
60+
)
4761
rng = np.random.RandomState(2)
4862
X += 2 * rng.uniform(size=X.shape)
4963
linearly_separable = (X, y)
5064

51-
datasets = [make_moons(noise=0.3, random_state=0),
52-
make_circles(noise=0.2, factor=0.5, random_state=1),
53-
linearly_separable]
65+
datasets = [
66+
make_moons(noise=0.3, random_state=0),
67+
make_circles(noise=0.2, factor=0.5, random_state=1),
68+
linearly_separable,
69+
]
5470

5571
figure = plt.figure(figsize=(27, 11))
5672
i = 1
@@ -59,21 +75,21 @@
5975
# preprocess dataset, split into training and test part
6076
X, y = ds
6177
X = StandardScaler().fit_transform(X)
62-
X_train, X_test, y_train, y_test = \
63-
train_test_split(X, y, test_size=.4, random_state=42)
78+
X_train, X_test, y_train, y_test = train_test_split(
79+
X, y, test_size=0.4, random_state=42
80+
)
6481

65-
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
66-
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
67-
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
68-
np.arange(y_min, y_max, h))
82+
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
83+
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
84+
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
6985

7086
# just plot the dataset first
7187
cm = plt.cm.RdBu
72-
cm_bright = ListedColormap(['#FF0000', '#0000FF'])
73-
#cm = plt.cm.PiYG
74-
#cm_bright = ListedColormap(['#FF0000', '#00FF00'])
75-
#cm = plt.cm.bwr
76-
#cm_bright = ListedColormap(['#0000FF', '#FF0000'])
88+
cm_bright = ListedColormap(["#FF0000", "#0000FF"])
89+
# cm = plt.cm.PiYG
90+
# cm_bright = ListedColormap(['#FF0000', '#00FF00'])
91+
# cm = plt.cm.bwr
92+
# cm_bright = ListedColormap(['#0000FF', '#FF0000'])
7793
ax = plt.subplot(len(datasets), len(classifiers) + 1, i)
7894
if ds_cnt == 0:
7995
ax.set_title("Input data")
@@ -120,25 +136,36 @@
120136

121137
# Put the result into a color plot
122138
Z = Z.reshape(xx.shape)
123-
ax.contourf(xx, yy, Z, cmap=cm, alpha=.8)
139+
ax.contourf(xx, yy, Z, cmap=cm, alpha=0.8)
124140

125141
# Plot also the training points
126142
ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright)
127143
# and testing points
128-
ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright,
129-
alpha=0.6)
144+
ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.6)
130145

131146
ax.set_xlim(xx.min(), xx.max())
132147
ax.set_ylim(yy.min(), yy.max())
133148
ax.set_xticks(())
134149
ax.set_yticks(())
135150
if ds_cnt == 0:
136151
ax.set_title(name.title())
137-
ax.text(xx.min() + 0.2, yy.min() + 0.2, 'Acc.: {0:.2f}'.format(score).lstrip('0'), size=15,
138-
horizontalalignment='left', bbox=dict(facecolor='white', edgecolor='black', alpha=0.8))
139-
ax.text(xx.min() + 0.2, yy.min() + 0.8, "Rules: {0!s}".format(n_rules) if n_rules is not None else "",
140-
size=15, horizontalalignment='left', bbox=dict(facecolor='white', edgecolor='black', alpha=0.8))
152+
ax.text(
153+
xx.min() + 0.2,
154+
yy.min() + 0.2,
155+
"Acc.: {0:.2f}".format(score).lstrip("0"),
156+
size=15,
157+
horizontalalignment="left",
158+
bbox=dict(facecolor="white", edgecolor="black", alpha=0.8),
159+
)
160+
ax.text(
161+
xx.min() + 0.2,
162+
yy.min() + 0.8,
163+
"Rules: {0!s}".format(n_rules) if n_rules is not None else "",
164+
size=15,
165+
horizontalalignment="left",
166+
bbox=dict(facecolor="white", edgecolor="black", alpha=0.8),
167+
)
141168
i += 1
142169

143170
plt.tight_layout()
144-
plt.savefig("decision_boundary.pdf", bbox_inches="tight")
171+
plt.savefig("decision_boundary.pdf", bbox_inches="tight")

examples/sklearn_compatibility.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,29 @@
1010
n_examples = 200
1111
n_features = 1000
1212

13-
X,y = make_classification(n_samples=n_examples, n_features=n_features, n_classes=2,
14-
random_state=np.random.RandomState(42))
13+
X, y = make_classification(
14+
n_samples=n_examples,
15+
n_features=n_features,
16+
n_classes=2,
17+
random_state=np.random.RandomState(42),
18+
)
1519

1620
params = {
17-
"p" : [0.5,1.,2.],
18-
"max_rules" : [1,2,3,4,5],
19-
"model_type" : ["conjunction","disjunction"]
21+
"p": [0.5, 1.0, 2.0],
22+
"max_rules": [1, 2, 3, 4, 5],
23+
"model_type": ["conjunction", "disjunction"],
2024
}
2125
clf = SetCoveringMachineClassifier(random_state=np.random.RandomState(42))
2226

2327
print("Fitting in GirdSearchCV...")
2428

2529
grid = GridSearchCV(estimator=clf, param_grid=params, cv=3, n_jobs=-1, verbose=True)
26-
grid.fit(X,y)
30+
grid.fit(X, y)
2731

2832
print("GridSearch passed!")
2933
print("Fitting in pipeline with StandardScaler...")
3034

31-
clf = Pipeline([("scaler",StandardScaler()),("scm",SetCoveringMachineClassifier())])
32-
clf.fit(X,y)
35+
clf = Pipeline([("scaler", StandardScaler()), ("scm", SetCoveringMachineClassifier())])
36+
clf.fit(X, y)
3337

3438
print("Done without error.")

examples/tiebreaker.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,22 @@
1111
n_examples = 200
1212
n_features = 1000
1313

14-
X,y = make_classification(n_samples=n_examples, n_features=n_features, n_classes=2,
15-
random_state=np.random.RandomState(42))
14+
X, y = make_classification(
15+
n_samples=n_examples,
16+
n_features=n_features,
17+
n_classes=2,
18+
random_state=np.random.RandomState(42),
19+
)
20+
1621

1722
def my_tiebreaker(model_type, feature_idx, thresholds, kind):
18-
print("Hello from the tiebreaker! Got {0:d} equivalent rules for this {1!s} model.".format(len(feature_idx), model_type))
23+
print(
24+
"Hello from the tiebreaker! Got {0:d} equivalent rules for this {1!s} model.".format(
25+
len(feature_idx), model_type
26+
)
27+
)
1928
return 0
2029

30+
2131
clf = SetCoveringMachineClassifier()
2232
clf.fit(X, y, tiebreaker=my_tiebreaker)

examples/training_time.py

+35-13
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,21 @@ def increase_n_features():
1515
n_bench_points = 5
1616
n_examples = 1000
1717
n_features = 100000
18-
18+
1919
avg_times = np.zeros(n_bench_points)
2020
nfs = [int(n_features * p) for p in np.linspace(0.01, 1.0, n_bench_points)]
2121
for _ in range(n_repeats):
2222
times = []
2323
for nf in nfs:
24-
X, y = make_classification(n_samples=n_examples, n_features=nf, n_classes=2,
25-
random_state=np.random.RandomState(42))
26-
clf = SetCoveringMachineClassifier(model_type="conjunction", p=1.0, max_rules=100)
24+
X, y = make_classification(
25+
n_samples=n_examples,
26+
n_features=nf,
27+
n_classes=2,
28+
random_state=np.random.RandomState(42),
29+
)
30+
clf = SetCoveringMachineClassifier(
31+
model_type="conjunction", p=1.0, max_rules=100
32+
)
2733
t = time()
2834
clf.fit(X, y)
2935
times.append(time() - t)
@@ -34,7 +40,11 @@ def increase_n_features():
3440
plt.plot(nfs, avg_times)
3541
plt.xlabel("n features")
3642
plt.ylabel("time (seconds)")
37-
plt.title("Training time for {0:d} <= n <= {1:d} features ({2:d} examples)".format(min(nfs), max(nfs), n_examples))
43+
plt.title(
44+
"Training time for {0:d} <= n <= {1:d} features ({2:d} examples)".format(
45+
min(nfs), max(nfs), n_examples
46+
)
47+
)
3848
plt.savefig("n_features.png", bbox_inches="tight")
3949

4050

@@ -43,15 +53,21 @@ def increase_n_examples():
4353
n_bench_points = 5
4454
n_examples = 10000
4555
n_features = 1000
46-
56+
4757
avg_times = np.zeros(n_bench_points)
4858
n_exs = [int(n_examples * p) for p in np.linspace(0.01, 1.0, n_bench_points)]
4959
for _ in range(n_repeats):
5060
times = []
5161
for n_ex in n_exs:
52-
X, y = make_classification(n_samples=n_ex, n_features=n_features, n_classes=2,
53-
random_state=np.random.RandomState(42))
54-
clf = SetCoveringMachineClassifier(model_type="conjunction", p=1.0, max_rules=100)
62+
X, y = make_classification(
63+
n_samples=n_ex,
64+
n_features=n_features,
65+
n_classes=2,
66+
random_state=np.random.RandomState(42),
67+
)
68+
clf = SetCoveringMachineClassifier(
69+
model_type="conjunction", p=1.0, max_rules=100
70+
)
5571
t = time()
5672
clf.fit(X, y)
5773
times.append(time() - t)
@@ -62,12 +78,18 @@ def increase_n_examples():
6278
plt.plot(n_exs, avg_times)
6379
plt.xlabel("n examples")
6480
plt.ylabel("time (seconds)")
65-
plt.title("Training time for {0:d} <= n <= {1:d} examples ({2:d} features)".format(min(n_exs), max(n_exs), n_features))
81+
plt.title(
82+
"Training time for {0:d} <= n <= {1:d} examples ({2:d} features)".format(
83+
min(n_exs), max(n_exs), n_features
84+
)
85+
)
6686
plt.savefig("n_examples.png", bbox_inches="tight")
6787

6888

69-
if __name__ == '__main__':
70-
logging.basicConfig(level=logging.DEBUG,
71-
format="%(asctime)s.%(msecs)d %(levelname)s %(module)s - %(funcName)s: %(message)s")
89+
if __name__ == "__main__":
90+
logging.basicConfig(
91+
level=logging.DEBUG,
92+
format="%(asctime)s.%(msecs)d %(levelname)s %(module)s - %(funcName)s: %(message)s",
93+
)
7294
increase_n_examples()
7395
increase_n_features()

pyscm/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
along with this program. If not, see <http://www.gnu.org/licenses/>.
1717
1818
"""
19-
from .scm import SetCoveringMachineClassifier
19+
from .scm import SetCoveringMachineClassifier

pyscm/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __len__(self):
6262
def __str__(self):
6363
return self._to_string()
6464

65+
6566
class ConjunctionModel(BaseModel):
6667
def predict(self, X):
6768
predictions = np.ones(X.shape[0], np.bool)

pyscm/rules.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class BaseRule(object):
2626
A rule mixin class
2727
2828
"""
29+
2930
def __init__(self):
3031
super(BaseRule, self).__init__()
3132

@@ -79,6 +80,7 @@ class DecisionStump(BaseRule):
7980
The case in which the rule returns 1, either "greater" or "less_equal".
8081
8182
"""
83+
8284
def __init__(self, feature_idx, threshold, kind="greater"):
8385
self.feature_idx = feature_idx
8486
self.threshold = threshold
@@ -116,8 +118,13 @@ def inverse(self):
116118
A rule that is the inverse of self.
117119
118120
"""
119-
return DecisionStump(feature_idx=self.feature_idx, threshold=self.threshold,
120-
kind="greater" if self.kind == "less_equal" else "less_equal")
121+
return DecisionStump(
122+
feature_idx=self.feature_idx,
123+
threshold=self.threshold,
124+
kind="greater" if self.kind == "less_equal" else "less_equal",
125+
)
121126

122127
def __str__(self):
123-
return "X[{0:d}] {1!s} {2:.3f}".format(self.feature_idx, ">" if self.kind == "greater" else "<=", self.threshold)
128+
return "X[{0:d}] {1!s} {2:.3f}".format(
129+
self.feature_idx, ">" if self.kind == "greater" else "<=", self.threshold
130+
)

0 commit comments

Comments
 (0)