Skip to content

Commit a8ded71

Browse files
committed
ENH: Implemented __getitem__ logic
1 parent c81d2e2 commit a8ded71

File tree

3 files changed

+154
-8
lines changed

3 files changed

+154
-8
lines changed

sparse/mlir_backend/_constructors.py

+6
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,12 @@ def __del__(self):
348348
for field in self._obj.get__fields_():
349349
free_memref(field)
350350

351+
def __getitem__(self, key) -> "Tensor":
    """Index this tensor with ``key`` and return the result as a new Tensor.

    Delegates all key normalisation and slicing to ``_ops.getitem``.
    """
    # imported lazily to avoid cyclic dependency
    from ._ops import getitem

    return getitem(self, key)
351357
@_hold_self_ref_in_ret
352358
def to_scipy_sparse(self) -> sps.sparray | np.ndarray:
353359
return self._obj.to_sps(self.shape)

sparse/mlir_backend/_ops.py

+111-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ctypes
2+
from types import EllipsisType
23

34
import mlir.execution_engine
45
import mlir.passmanager
@@ -85,12 +86,39 @@ def get_reshape_module(
8586
def reshape(a, shape):
8687
return tensor.reshape(out_tensor_type, a, shape)
8788

88-
reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
89-
if DEBUG:
90-
(CWD / "reshape_module.mlir").write_text(str(module))
91-
pm.run(module.operation)
92-
if DEBUG:
93-
(CWD / "reshape_module_opt.mlir").write_text(str(module))
89+
reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
90+
if DEBUG:
91+
(CWD / "reshape_module.mlir").write_text(str(module))
92+
pm.run(module.operation)
93+
if DEBUG:
94+
(CWD / "reshape_module_opt.mlir").write_text(str(module))
95+
96+
return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
97+
98+
99+
@fn_cache
def get_slice_module(
    in_tensor_type: ir.RankedTensorType,
    out_tensor_type: ir.RankedTensorType,
    offsets: tuple[int, ...],
    sizes: tuple[int, ...],
    strides: tuple[int, ...],
) -> "mlir.execution_engine.ExecutionEngine":
    # NOTE(review): return annotation corrected — this function returns a
    # JIT ExecutionEngine, not the ir.Module it builds internally.
    """Compile an MLIR module whose single ``getitem`` function extracts a
    static slice (``tensor.extract_slice`` with the given offsets/sizes/strides)
    from a tensor of ``in_tensor_type`` into ``out_tensor_type``.

    ``@fn_cache`` presumably memoises the compiled engine per argument tuple —
    confirm against the decorator's definition.
    """
    with ir.Location.unknown(ctx):
        module = ir.Module.create()

        with ir.InsertionPoint(module.body):

            @func.FuncOp.from_py_func(in_tensor_type)
            def getitem(a):
                # Empty dynamic operand lists: all offsets/sizes/strides are static.
                return tensor.extract_slice(out_tensor_type, a, [], [], [], offsets, sizes, strides)

            # Required so the function is invokable through the C ABI wrapper
            # (ExecutionEngine.invoke).
            getitem.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        if DEBUG:
            # Dump the pre-optimisation IR for debugging.
            (CWD / "getitem_module.mlir").write_text(str(module))
        pm.run(module.operation)
        if DEBUG:
            # Dump the IR after the pass pipeline ran.
            (CWD / "getitem_module_opt.mlir").write_text(str(module))

        return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])

@@ -135,3 +163,80 @@ def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:
135163
)
136164

137165
return Tensor(ret_obj, shape=out_tensor_type.shape)
166+
167+
168+
def _add_missing_dims(key: tuple, ndim: int) -> tuple:
169+
if len(key) < ndim and Ellipsis not in key:
170+
return key + (...,)
171+
return key
172+
173+
174+
def _expand_ellipsis(key: tuple, ndim: int) -> tuple:
175+
if Ellipsis in key:
176+
if len([e for e in key if e is Ellipsis]) > 1:
177+
raise Exception(f"Ellipsis should be used once: {key}")
178+
to_expand = ndim - len(key) + 1
179+
if to_expand <= 0:
180+
raise Exception(f"Invalid use of Ellipsis in {key}")
181+
idx = key.index(Ellipsis)
182+
return key[:idx] + tuple(slice(None) for _ in range(to_expand)) + key[idx + 1 :]
183+
return key
184+
185+
186+
def _decompose_slices(
187+
key: tuple,
188+
shape: tuple[int, ...],
189+
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
190+
offsets = []
191+
sizes = []
192+
strides = []
193+
194+
for key_elem, size in zip(key, shape, strict=False):
195+
if isinstance(key_elem, slice):
196+
offset = key_elem.start if key_elem.start is not None else 0
197+
size = key_elem.stop - offset if key_elem.stop is not None else size - offset
198+
stride = key_elem.step if key_elem.step is not None else 1
199+
elif isinstance(key_elem, int):
200+
offset = key_elem
201+
size = key_elem + 1
202+
stride = 1
203+
offsets.append(offset)
204+
sizes.append(size)
205+
strides.append(stride)
206+
207+
return tuple(offsets), tuple(sizes), tuple(strides)
208+
209+
210+
def _get_new_shape(sizes, strides) -> tuple[int, ...]:
211+
return tuple(size // stride for size, stride in zip(sizes, strides, strict=False))
212+
213+
214+
def getitem(
    x: Tensor,
    key: int | slice | EllipsisType | tuple[int | slice | EllipsisType, ...],
) -> Tensor:
    """Extract the static slice of ``x`` described by ``key`` and return it
    as a new :class:`Tensor`.

    The key is normalised (tuple form, Ellipsis padded and expanded), turned
    into static offsets/sizes/strides, and executed through a compiled MLIR
    ``extract_slice`` module.
    """
    # Normalise to tuple form; reject None (lazy/newaxis) entries up front.
    if not isinstance(key, tuple):
        key = (key,)
    if None in key:
        raise Exception(f"Lazy indexing isn't supported: {key}")

    # Canonicalise: pad short keys with Ellipsis, then expand it to full slices.
    key = _expand_ellipsis(_add_missing_dims(key, x.ndim), x.ndim)
    offsets, sizes, strides = _decompose_slices(key, x.shape)

    out_tensor_type = x._obj.get_tensor_definition(_get_new_shape(sizes, strides))
    engine = get_slice_module(
        x._obj.get_tensor_definition(x.shape),
        out_tensor_type,
        offsets,
        sizes,
        strides,
    )

    # Output storage is allocated by the compiled function and written through
    # a double pointer.
    out_obj = x._format_class()
    engine.invoke("getitem", ctypes.pointer(ctypes.pointer(out_obj)), *x._obj.to_module_arg())

    return Tensor(out_obj, shape=out_tensor_type.shape)

sparse/mlir_backend/tests/test_simple.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,7 @@ def test_reshape(rng, dtype):
217217
arr = sps.random_array(
218218
shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
219219
)
220-
if format == "coo":
221-
arr.sum_duplicates()
220+
arr.sum_duplicates()
222221

223222
tensor = sparse.asarray(arr)
224223

@@ -264,3 +263,39 @@ def test_reshape(rng, dtype):
264263
# DENSE
265264
# NOTE: dense reshape is probably broken in MLIR
266265
# dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
266+
267+
268+
@pytest.mark.skip(reason="https://discourse.llvm.org/t/illegal-operation-when-slicing-csr-csc-coo-tensor/81404")
@parametrize_dtypes
@pytest.mark.parametrize(
    "index",
    [
        0,
        (2,),
        (2, 3),
        (..., slice(0, 4, 2)),
        (1, slice(1, None, 1)),
        # TODO: For below cases we need an update to ownership mechanism.
        # `tensor[:, :]` returns the same memref that was passed.
        # The mechanism sees the result as MLIR-allocated and frees
        # it, while it still can be owned by SciPy/NumPy causing a
        # segfault when it frees SciPy/NumPy managed memory.
        # ...,
        # slice(None),
        # (slice(None), slice(None)),
    ],
)
def test_indexing_2d(rng, dtype, index):
    """Check Tensor.__getitem__ against dense NumPy indexing for 2-D inputs
    in every supported sparse format."""
    shape = (20, 30)
    density = 0.5

    for fmt in ("csr", "csc", "coo"):
        arr = sps.random_array(shape, density=density, format=fmt, dtype=dtype, random_state=rng)
        # Canonicalise so the round trip through the MLIR backend is well-defined.
        arr.sum_duplicates()

        tensor = sparse.asarray(arr)

        actual = tensor[index].to_scipy_sparse()
        expected = arr.todense()[index]
        np.testing.assert_array_equal(actual.todense(), expected)

0 commit comments

Comments
 (0)