This repository has been archived by the owner on Nov 18, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsubsampling2.py
113 lines (98 loc) · 4.05 KB
/
subsampling2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright (C) 2017, 2018 University of Vienna
# All rights reserved.
# BSD license.
# Author: Ali Baharev <[email protected]>
import numpy as np
from cffi import FFI
__all__ = ['subsample2']

# C prototype of the routine exported by the compiled shared library. It must
# match the C source (not part of this file) exactly: in cffi's ABI mode the
# signature is taken on trust. The "shape:" notes give each buffer's length.
ffi = FFI()
ffi.cdef('''
void downsample2(const double* x, // shape: n x point_dim
const int* idx_selected, // shape: k
int k,
const int* idx_not_selected, // shape: n-k
unsigned char* selected, // shape: n
double* d_min, // shape: n
int n, // number of points
int point_dim,
int n_pts_to_add);
''')


def _load_library(name):
    """Open the shared library ``name`` with ``ffi.dlopen``.

    The bare name is tried first, exactly as before, so an existing setup
    (e.g. LD_LIBRARY_PATH pointing at a build directory) keeps working.
    A bare name is resolved by the system dynamic loader, which on Linux
    does not look in the current directory or next to this module. If that
    lookup fails, fall back to a file of the same name in this module's
    directory; if that file does not exist either, the original OSError is
    re-raised.
    """
    import os
    try:
        return ffi.dlopen(name)
    except OSError:
        module_dir = os.path.dirname(os.path.abspath(__file__))
        candidate = os.path.join(module_dir, name)
        if not os.path.isfile(candidate):
            raise
        return ffi.dlopen(candidate)


so = _load_library('subsample2.so')
downsample = so.downsample2
def _checked_ptr(arr, dtype, c_type, writable=False):
    """Cast the data buffer of ``arr`` to a cffi pointer of type ``c_type``.

    The C routine receives nothing but the raw address. The array must
    therefore have exactly ``dtype`` and be C-contiguous: with a strided
    view, the C code would read or write the wrong elements. When
    ``writable`` is set, numpy must also allow writes to the buffer.

    The returned pointer does NOT keep ``arr`` alive. The caller must hold a
    reference to ``arr`` for as long as the pointer is in use.
    """
    assert arr.dtype == dtype, arr.dtype
    assert arr.flags['C_CONTIGUOUS'], 'array must be C-contiguous'
    assert not writable or arr.flags['WRITEABLE'], 'array must be writeable'
    return ffi.cast(c_type, arr.ctypes.data)


def const_double_ptr(arr):
    """Return a read-only ``const double*`` pointer to a float64 array."""
    return _checked_ptr(arr, np.float64, 'const double*')


def double_ptr(arr):
    """Return a writable ``double*`` pointer to a float64 array."""
    return _checked_ptr(arr, np.float64, 'double*', writable=True)


def uchar_ptr(arr):
    """Return a writable ``unsigned char*`` pointer to a bool array.

    numpy stores each bool in one byte, so it maps onto ``unsigned char``.
    """
    return _checked_ptr(arr, np.bool_, 'unsigned char*', writable=True)


def const_int_ptr(arr):
    """Return a read-only ``const int*`` pointer to an ``np.intc`` array."""
    return _checked_ptr(arr, np.intc, 'const int*')
#-------------------------------------------------------------------------------
def subsample2(x, n_desired_pts):
    """Select ``n_desired_pts`` of the points in ``x``; return a boolean mask.

    The selection is seeded with the point closest, in L1 distance, to the
    centroid of ``x``. subsample2_ then adds the remaining points.

    Parameters
    ----------
    x : numpy.ndarray, shape (n_pts, point_dim)
        One point per row; every entry must be finite.
    n_desired_pts : int
        How many points to select (must be > 0).

    Returns
    -------
    numpy.ndarray of bool, shape (n_pts,)
        True for each selected point. If ``x`` holds fewer than
        ``n_desired_pts`` points, a warning is printed and every point is
        selected.
    """
    assert x.ndim == 2, x.shape  # one point per row
    n_pts, point_dim = x.shape[0], x.shape[1]
    assert n_pts > 0 and point_dim > 0, (n_pts, point_dim)
    assert np.isfinite(x).all()
    assert 0 < n_desired_pts, n_desired_pts
    if n_pts < n_desired_pts:
        warning = '*** Warning: We have less points than desired! ({} < {}) ***'
        print(warning.format(n_pts, n_desired_pts))
        return np.ones(n_pts, dtype=np.bool_)
    # Best-effort seed: the point nearest (L1 norm) to the centroid.
    centroid = x.mean(axis=0)
    l1_dist = np.linalg.norm(x - centroid, axis=1, ord=1)
    seed_mask = np.zeros(n_pts, dtype=np.bool_)
    seed_mask[np.argmin(l1_dist)] = True
    # subsample2_ fills up the rest of the selection, if any.
    return subsample2_(x, seed_mask, n_desired_pts)
def subsample2_(x, selected, n_desired_pts):
    """Extend an existing selection of points to ``n_desired_pts`` points.

    Similar to subsample2, but some points are already selected. The C
    routine ``downsample2`` chooses the remaining ones.

    Parameters
    ----------
    x : array_like, shape (n_pts, point_dim)
        One point per row; every entry must be finite. Before reaching the
        C code it is converted to a C-contiguous float64 array, so integer
        or float32 input is accepted.
    selected : array_like of bool, shape (n_pts,)
        Mask of the points already selected. It is copied as a bool array,
        so the caller's mask is never modified. At least one point must be
        selected when more points still have to be added.
    n_desired_pts : int
        Total number of points wanted in the result (must be > 0).

    Returns
    -------
    numpy.ndarray of bool, shape (n_pts,)
        Mask of the selected points. Three cases:
        - Fewer than ``n_desired_pts`` points exist: a warning is printed
          and every point is selected.
        - ``selected`` already has at least ``n_desired_pts`` points: a bool
          copy of it is returned unchanged.
    """
    x = np.asarray(x)
    assert x.ndim == 2, x.shape  # x is a 2D array of the points
    n_pts, point_dim = x.shape
    assert n_pts > 0 and point_dim > 0, (n_pts, point_dim)
    assert np.isfinite(x).all()
    # np.array() always copies, so the in-place updates by the C code never
    # reach the caller's mask. A bool dtype is also what uchar_ptr requires;
    # a 0/1 integer mask would otherwise be misread as fancy indices.
    selected = np.array(selected, dtype=np.bool_)
    assert selected.shape == (n_pts,), (x.shape, selected.shape)
    assert 0 < n_desired_pts, n_desired_pts
    if n_desired_pts > n_pts:
        msg = '*** Warning: We have less points than desired! ({} < {}) ***'
        print(msg.format(n_pts, n_desired_pts))
        return np.full(n_pts, np.True_, dtype=np.bool_)
    idx = np.flatnonzero(selected)
    k = len(idx)
    to_add = n_desired_pts - k
    if to_add <= 0:
        return selected
    idx_not_selected = np.flatnonzero(~selected)
    assert len(idx_not_selected) == n_pts - k
    # NOTE(review): the C routine presumably needs at least one already
    # selected point to start from; subsample2 always seeds one.
    assert selected.any()
    # The C side only sees raw pointers. Hand it C-contiguous buffers of
    # exactly the dtypes the pointer helpers check (float64 / C int).
    x = np.ascontiguousarray(x, dtype=np.float64)
    idx = idx.astype(np.intc)
    idx_not_selected = idx_not_selected.astype(np.intc)
    # d_nn: temporary work array for distances to the nearest neighbors
    d_nn = np.full(n_pts, np.nan)
    downsample(const_double_ptr(x),                # shape: (n_pts, point_dim)
               const_int_ptr(idx),                 # shape: (k,)
               k,
               const_int_ptr(idx_not_selected),    # shape: (n_pts-k,)
               uchar_ptr(selected),                # shape: (n_pts,), written by C
               double_ptr(d_nn),                   # shape: (n_pts,)
               n_pts,
               point_dim,
               to_add)
    return selected
#-------------------------------------------------------------------------------
# For performance testing and profiling with perf
def _main():
    """Benchmark entry point: subsample 10k random 25-D points down to 1k."""
    np.random.seed(1)  # fixed seed, so profiling runs are reproducible
    n_points, dim, n_keep = 10000, 25, 1000
    points = np.random.random((n_points, dim))
    subsample2(points, n_keep)


if __name__ == '__main__':
    _main()