-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathindexer.py
220 lines (174 loc) · 6.69 KB
/
indexer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""Methods for manipulating indices."""
from typing import List, Optional, Tuple
import torch
from torch import Tensor
@torch.jit.script
def ravel_multi_index(unraveled_coords: Tensor, shape: List[int]) -> Tensor:
"""Convert a tensor of flat indices in R^K into flattened coordinates in R^1.
'Untangle' a set of spatial indices in R^K to a 'flattened'
or 'raveled' set of indices in R^1.
Reference:
https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html
Reference:
https://github.com/francois-rozet/torchist/blob/master/torchist/__init__.py#L18
Args:
unraveled_coords: (N,K) Tensor containing the indices.
shape: (K,) Shape of the target grid.
Returns:
(N,) Tensor of 1D indices.
"""
shape_tensor = torch.as_tensor(
shape + [1],
device=unraveled_coords.device,
dtype=unraveled_coords.dtype,
)
if len(unraveled_coords) > 0:
max_coords, _ = unraveled_coords.max(dim=0)
assert torch.all(max_coords < shape_tensor[:-1])
coefs = shape_tensor[1:].flipud().cumprod(dim=0).flipud()
return torch.mul(unraveled_coords, coefs).sum(dim=-1)
@torch.jit.script
def unravel_index(raveled_indices: Tensor, shape: List[int]) -> Tensor:
    """Convert flat indices into per-dimension coordinates.

    'Tangle' a set of indices in R^1 into an 'aligned' or 'unraveled' set of
    coordinates in R^K, where K = len(shape), using row-major (C-order)
    strides.

    Reference:
        https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html
    Reference:
        https://github.com/francois-rozet/torchist/blob/master/torchist/__init__.py#L37

    Args:
        raveled_indices: (N,) Tensor of indices into the flattened version of
            an array with dimensions ``shape``. A trailing axis of size K is
            appended to whatever shape is passed in.
        shape: (K,) List of dimensions.

    Returns:
        (N,K) Tensor of unraveled coordinates.

    Note:
        Indices are not validated. The final modulo folds out-of-range values
        back into range instead of raising, as NumPy would.
    """
    dims = torch.as_tensor(
        shape, device=raveled_indices.device, dtype=raveled_indices.dtype
    )
    # strides[k] = prod(shape[k+1:]); the innermost stride is 1.
    trailing = torch.flip(torch.cumprod(torch.flip(dims[1:], [0]), dim=0), [0])
    strides = torch.cat((trailing, trailing.new_ones((1,))), dim=0)
    quotients = torch.div(
        raveled_indices.unsqueeze(-1), strides, rounding_mode="trunc"
    )
    return torch.remainder(quotients, dims)
@torch.jit.script
def scatter_nd(
    index: Tensor,
    src: Tensor,
    grid_shape: List[int],
    permutation: Optional[List[int]] = None,
) -> Tensor:
    """Scatter-add rows of values into a dense, zero-initialized grid.

    The last entry of ``grid_shape`` acts as a channel axis. ``index``
    addresses the leading ``len(grid_shape) - 1`` axes. Rows that land on the
    same coordinate are summed (``scatter_add_``), and untouched cells stay
    zero.

    Args:
        index: (N, len(grid_shape) - 1) Tensor of integer coordinates.
        src: (N, C) values to scatter, with C == grid_shape[-1]. A 1-D ``src``
            is promoted to (N, 1); presumably that is only valid when
            grid_shape[-1] == 1 (confirm).
        grid_shape: Size of each dimension of the output grid.
        permutation: Optional axis permutation applied to the output.

    Returns:
        Tensor of shape ``grid_shape`` (permuted if ``permutation`` is given),
        with the dtype and device of ``src``.
    """
    values = src.unsqueeze(1) if src.ndim == 1 else src
    channels = grid_shape[-1]

    # One flat row index per entry, repeated across every channel column.
    flat_rows = ravel_multi_index(index, grid_shape[:-1])
    flat_rows = flat_rows.unsqueeze(1).repeat(1, channels)

    grid = torch.zeros(grid_shape, device=values.device, dtype=values.dtype)
    # ``view`` shares storage with ``grid``, so the in-place add fills it.
    grid.view(-1, channels).scatter_add_(dim=0, index=flat_rows, src=values)

    if permutation is not None:
        grid = grid.permute(permutation)
    return grid
@torch.jit.script
def mgrid(intervals: List[List[int]]) -> Tensor:
    """Build a dense meshgrid stacked along a new leading axis.

    Each ``[start, end]`` pair yields ``torch.arange(start, end)`` (end
    exclusive), and the grids use matrix ("ij") indexing.

    NOTE: A list argument is used instead of variadic args to keep TorchScript
    support.
    TODO: Explore rewrite with variadic args.

    Args:
        intervals: ``[start, end]`` pairs, one per dimension.

    Returns:
        Tensor of shape (len(intervals), n_0, ..., n_{K-1}), where n_k is the
        number of values in ``arange(start_k, end_k)``.
    """
    axes: List[Tensor] = []
    for start, end in intervals:
        axes.append(torch.arange(start, end))
    return torch.stack(torch.meshgrid(axes, indexing="ij"))
@torch.jit.script
def ogrid(intervals: List[List[int]]) -> Tensor:
    """Return a sparse multi-dimensional 'meshgrid'.

    Generate the Cartesian product of the integer intervals as an explicit
    list of coordinates. Rows are in row-major order ("ij" indexing, with the
    last dimension varying fastest).

    NOTE: ``numpy.ogrid`` returns open (broadcastable) grids. This function
    instead returns every coordinate of the product.

    Reference:
        https://numpy.org/doc/stable/reference/generated/numpy.ogrid.html

    Args:
        intervals: ``[start, end]`` pairs (end exclusive), one per dimension.

    Returns:
        (M, K) Tensor with K = len(intervals) and M the product of the
        interval lengths.
    """
    num_dims = len(intervals)
    axes = [torch.arange(lo, hi) for lo, hi in intervals]
    # Shape (n_0, ..., n_{K-1}, K): the K coordinate grids on a new last axis.
    coords = torch.stack(torch.meshgrid(axes, indexing="ij"), dim=-1)
    return coords.view(-1, num_dims)
@torch.jit.script
def ogrid_symmetric(intervals: List[int]) -> Tensor:
    """Compute a _symmetric_ sparse multi-dimensional 'meshgrid'.

    Unlike `ogrid`, this function does not take start and stop positions for
    the indices. Instead, each width ``w`` becomes the half-open range
    ``[-(w // 2), w - w // 2)``, so coordinates are centered about the origin.
    For even ``w`` there is one more negative value than positive ones; for
    example, ``w = 4`` gives -2, -1, 0, 1.

    Reference:
        https://numpy.org/doc/stable/reference/generated/numpy.ogrid.html

    Args:
        intervals: Width of the neighborhood along each dimension.

    Returns:
        The sparse, _symmetric_ representation of the meshgrid, as produced by
        ``ogrid`` (one row per coordinate).
    """
    bounds: List[List[int]] = []
    for width in intervals:
        half = width // 2
        bounds.append([-half, width - half])
    return ogrid(bounds)
@torch.jit.script
def ogrid_sparse_neighborhoods(
    offsets: Tensor, intervals: List[int]
) -> Tensor:
    """Compute a sparse representation of multiple meshgrids.

    A symmetric neighborhood (see ``ogrid_symmetric``) is placed around every
    offset, and all shifted neighborhoods are concatenated.

    Args:
        offsets: (N,K) Tensor representing the "center" of each
            sparse meshgrid.
        intervals: (K,) List of symmetric neighborhood widths to consider.

    Returns:
        (N*M, K) Tensor containing all of the sparse neighborhoods, where M is
        the number of points in one neighborhood. The M points of offset 0
        come first, then those of offset 1, and so on.

    Raises:
        ValueError: If the per-offset dimension doesn't match the length
            of the intervals.
    """
    if offsets.shape[-1] != len(intervals):
        raise ValueError(
            "The per-offset dimension and the length "
            "of the sparse intervals _must_ match."
        )
    neighborhood = ogrid_symmetric(intervals)
    # (N, 1, K) + (1, M, K) -> (N, M, K): one full neighborhood per offset.
    shifted = offsets.unsqueeze(-2) + neighborhood.unsqueeze(0)
    return shifted.flatten(0, 1)
def unique_indices(indices: Tensor, dim: int = 0) -> Tensor:
    """Return the position of the first occurrence of each unique slice.

    ``torch.unique`` is applied along ``dim``. For every unique slice, the
    smallest position along ``dim`` at which it occurs in ``indices`` is
    selected, and those positions are returned in ascending order.

    Args:
        indices: (N,K) Coordinate inputs.
        dim: Dimension to compute the unique operation over.

    Returns:
        1-D tensor of positions along ``dim``, one per unique slice, sorted
        ascending.
    """
    out: Tuple[Tensor, Tensor] = torch.unique(
        indices, return_inverse=True, dim=dim
    )
    _, inverse = out
    # With ``dim`` given, ``inverse`` is 1-D (one label per slice along
    # ``dim``), so everything below works on its dimension 0. The previous
    # version used ``dim`` here, which failed for ``dim=1``.
    # A stable sort groups equal labels while keeping their original positions
    # in increasing order, so the first element of each group is the earliest
    # occurrence. This replaces ``scatter_`` with duplicate indices, whose
    # result PyTorch documents as nondeterministic.
    sorted_labels, order = torch.sort(inverse, stable=True)
    is_first = torch.ones_like(sorted_labels, dtype=torch.bool)
    is_first[1:] = sorted_labels[1:] != sorted_labels[:-1]
    first_positions = order[is_first]
    return first_positions.sort().values