Skip to content

Commit

Permalink
Refactor for tree solver, all tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorKraevTransferwise committed May 6, 2024
1 parent 80534f9 commit b5ff932
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 28 deletions.
41 changes: 38 additions & 3 deletions wise_pizza/cluster.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import numpy as np
import pandas as pd
from sklearn.preprocessing import PowerTransformer
Expand All @@ -19,16 +21,26 @@ def guided_kmeans(X: np.ndarray, power_transform: bool = True) -> np.ndarray:

if power_transform:
if len(X[X > 0] > 1):
X[X > 0] = PowerTransformer(standardize=False).fit_transform(X[X > 0].reshape(-1, 1)).reshape(-1)
X[X > 0] = (
PowerTransformer(standardize=False)
.fit_transform(X[X > 0].reshape(-1, 1))
.reshape(-1)
)
if len(X[X < 0] > 1):
X[X < 0] = -PowerTransformer(standardize=False).fit_transform(-X[X < 0].reshape(-1, 1)).reshape(-1)
X[X < 0] = (
-PowerTransformer(standardize=False)
.fit_transform(-X[X < 0].reshape(-1, 1))
.reshape(-1)
)

best_score = -1
best_labels = None
best_n = -1
# If we allow 2 clusters, it almost always just splits positive vs negative - boring!
for n_clusters in range(3, int(len(X) / 2) + 1):
cluster_labels = KMeans(n_clusters=n_clusters, init="k-means++", n_init=10).fit_predict(X)
cluster_labels = KMeans(
n_clusters=n_clusters, init="k-means++", n_init=10
).fit_predict(X)
score = silhouette_score(X, cluster_labels)
# print(n_clusters, score)
if score > best_score:
Expand All @@ -45,3 +57,26 @@ def to_matrix(labels: np.ndarray) -> np.ndarray:
for i in labels.unique():
out[labels == i, i] = 1.0
return out


def make_clusters(dim_df: pd.DataFrame, dims: List[str]):
    """Build named clusters of dimension values with similar averages.

    For every dimension in ``dims`` that has at least 6 unique values, the
    ``totals`` and ``weights`` columns are summed per dimension value, the
    ratio ``totals / weights`` is computed, and that ratio is passed to
    ``guided_kmeans``; the first element of its result is used as the
    per-value cluster label.

    Args:
        dim_df: Frame containing each column named in ``dims`` plus the
            ``"totals"`` and ``"weights"`` columns.
        dims: Names of the dimension columns to consider.

    Returns:
        A dict mapping ``"<dim>_cluster_<k>"`` (``k`` starting at 1 within
        each dimension) to the member values of that cluster joined with
        ``"@@"``. Clusters whose joined string has no ``"@@"`` separator
        (i.e. a single member) are left out.
    """
    named_clusters = {}
    for dim in dims:
        # Too few distinct values: clustering would add nothing useful.
        if len(dim_df[dim].unique()) < 6:
            continue

        # One row per dimension value, with summed totals and weights.
        per_value = dim_df[[dim, "totals", "weights"]].groupby(dim, as_index=False).sum()
        per_value["avg"] = per_value["totals"] / per_value["weights"]

        labels, _ = guided_kmeans(per_value["avg"])
        per_value["cluster"] = labels

        # Collapse each cluster into a single "@@"-separated string.
        # NOTE(review): str.join needs the dimension values to be strings.
        by_cluster = per_value[["cluster", dim]].groupby("cluster")
        joined = by_cluster.agg({dim: lambda x: "@@".join(x)}).values.reshape(-1)

        # Keep only clusters with more than one member.
        multi_member = [members for members in joined if "@@" in members]

        # Short, 1-based cluster names scoped to the dimension.
        named_clusters.update(
            {
                f"{dim}_cluster_{position}": members
                for position, members in enumerate(multi_member, start=1)
            }
        )
    return named_clusters
30 changes: 5 additions & 25 deletions wise_pizza/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from wise_pizza.solve.find_alpha import clean_up_min_max, find_alpha
from wise_pizza.make_matrix import sparse_dummy_matrix
from wise_pizza.cluster import guided_kmeans
from wise_pizza.cluster import make_clusters
from wise_pizza.preselect import HeuristicSelector
from wise_pizza.time import extend_dataframe
from wise_pizza.slicer_facades import SliceFinderPredictFacade
Expand Down Expand Up @@ -192,31 +192,11 @@ def fit(
cluster_values = False

if cluster_values:
self.cluster_names = make_clusters(dim_df, dims)
for dim in dims:
if (
len(dim_df[dim].unique()) >= 6
): # otherwise what's the point in clustering?
grouped_df = (
dim_df[[dim, "totals", "weights"]]
.groupby(dim, as_index=False)
.sum()
)
grouped_df["avg"] = grouped_df["totals"] / grouped_df["weights"]
grouped_df["cluster"], _ = guided_kmeans(grouped_df["avg"])
pre_clusters = (
grouped_df[["cluster", dim]]
.groupby("cluster")
.agg({dim: lambda x: "@@".join(x)})
.values
)
# filter out clusters with only one element
these_clusters = [c for c in pre_clusters.reshape(-1) if "@@" in c]
# create short cluster names
for i, c in enumerate(these_clusters):
self.cluster_names[f"{dim}_cluster_{i+1}"] = c
clusters[dim] = [
c for c in self.cluster_names.keys() if c.startswith(dim)
]
clusters[dim] = [
c for c in self.cluster_names.keys() if c.startswith(dim)
]

dim_df = dim_df[dims] # if time_col is None else dims + ["__time"]]
self.dim_df = dim_df
Expand Down

0 comments on commit b5ff932

Please sign in to comment.