Skip to content

Commit

Permalink
kmeans
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Nov 25, 2024
1 parent 812742d commit 6091a11
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
27 changes: 27 additions & 0 deletions simba/data_processors/cuda/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from cupyx.scipy.spatial.distance import cdist
except:
import numpy as cp
try:
from cuml.cluster import KMeans
except:
from sklearn.cluster import KMeans

from simba.utils.checks import check_int, check_valid_array, check_valid_tuple
from simba.utils.enums import Formats
Expand Down Expand Up @@ -627,3 +631,26 @@ def davis_bouldin(x: np.ndarray,
max_ratio = max(max_ratio, ratio)
db_index += max_ratio
return db_index / n_labels


def kmeans_cuml(data: np.ndarray,
k: int = 2,
max_iter: int = 300,
output_type: Optional[str] = None,
sample_n: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
"""CRAP, SLOWER THAN SCIKIT"""

check_valid_array(data=data, source=f'{kmeans_cuml.__name__} data', accepted_dtypes=Formats.NUMERIC_DTYPES.value)
check_int(name=f'{kmeans_cuml.__name__} k', value=k, min_value=1)
check_int(name=f'{kmeans_cuml.__name__} max_iter', value=max_iter, min_value=1)
kmeans = KMeans(n_clusters=k, max_iter=max_iter)
if sample_n is not None:
check_int(name=f'{kmeans_cuml.__name__} sample', value=sample_n, min_value=1)
sample = min(sample_n, data.shape[0])
data_idx = np.random.choice(np.arange(data.shape[0]), sample)
mdl = kmeans.fit(data[data_idx])
else:
mdl = kmeans.fit(data)

return (mdl.cluster_centers_, mdl.predict(data))

4 changes: 2 additions & 2 deletions simba/data_processors/freezing_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,5 +130,5 @@ def run(self):
stdout_success(msg=f'Results saved in {self.save_dir} directory.')

#
# FreezingDetector(data_dir=r'D:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location',
# config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini")
# FreezingDetector(data_dir=r'C:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location',
# config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini")

0 comments on commit 6091a11

Please sign in to comment.