Skip to content

Commit

Permalink
Fixes setting of writable flag for views and writing to read-only arr…
Browse files Browse the repository at this point in the history
…ays with `out` keyword (#1527)

* Fixes bugs with `writable` flag setting

`writable` flag was not being set correctly for indexing, real views, imaginary views, tranposes, and where shape is set directly

Also fixes cases where flag could be overridden by functions with `out` kwarg

* Adds a test for writable flag view behavior

* Removes assumption that new array is writable

Now flags are set based on input regardless of whether a new array is writable
per review feedback

* Adds _copy_writable for copying the writable flag between arrays

* Correct typos in _copy_writable

* Fixes clip writing to read-only out arrays when min and max are none
  • Loading branch information
ndgrigorian authored Feb 6, 2024
1 parent e25a32a commit 194dee2
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 11 deletions.
9 changes: 9 additions & 0 deletions dpctl/tensor/_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ def _clip_none(x, val, out, order, _binary_fn):
f"output array must be of usm_ndarray type, got {type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != res_shape:
raise ValueError(
"The shape of input and output arrays are inconsistent. "
Expand Down Expand Up @@ -437,6 +440,9 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
f"{type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != x.shape:
raise ValueError(
"The shape of input and output arrays are "
Expand Down Expand Up @@ -600,6 +606,9 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
f"{type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != res_shape:
raise ValueError(
"The shape of input and output arrays are "
Expand Down
6 changes: 6 additions & 0 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ def __call__(self, x, out=None, order="K"):
f"output array must be of usm_ndarray type, got {type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != x.shape:
raise ValueError(
"The shape of input and output arrays are inconsistent. "
Expand Down Expand Up @@ -601,6 +604,9 @@ def __call__(self, o1, o2, out=None, order="K"):
f"output array must be of usm_ndarray type, got {type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != res_shape:
raise ValueError(
"The shape of input and output arrays are inconsistent. "
Expand Down
3 changes: 3 additions & 0 deletions dpctl/tensor/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,9 @@ def matmul(x1, x2, out=None, dtype=None, order="K"):
f"output array must be of usm_ndarray type, got {type(out)}"
)

if not out.flags.writable:
raise ValueError("provided `out` array is read-only")

if out.shape != res_shape:
raise ValueError(
"The shape of input and output arrays are inconsistent. "
Expand Down
28 changes: 17 additions & 11 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
view.shape = tuple()
return view

cdef int _copy_writable(int lhs_flags, int rhs_flags):
"Copy the WRITABLE flag to lhs_flags from rhs_flags"
return (lhs_flags & ~USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE)

cdef class usm_ndarray:
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
Expand Down Expand Up @@ -546,7 +549,7 @@ cdef class usm_ndarray:
PyMem_Free(self.shape_)
if (self.strides_):
PyMem_Free(self.strides_)
self.flags_ = contig_flag
self.flags_ = (contig_flag | (self.flags_ & USM_ARRAY_WRITABLE))
self.nd_ = new_nd
self.shape_ = shape_ptr
self.strides_ = strides_ptr
Expand Down Expand Up @@ -725,13 +728,13 @@ cdef class usm_ndarray:
buffer=self.base_,
offset=_meta[2]
)
res.flags_ |= (self.flags_ & USM_ARRAY_WRITABLE)
res.array_namespace_ = self.array_namespace_

adv_ind = _meta[3]
adv_ind_start_p = _meta[4]

if adv_ind_start_p < 0:
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
Expand All @@ -749,6 +752,7 @@ cdef class usm_ndarray:
if not matching:
raise IndexError("boolean index did not match indexed array in dimensions")
res = _extract_impl(res, key_, axis=adv_ind_start_p)
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

if any(ind.dtype == dpt_bool for ind in adv_ind):
Expand All @@ -758,10 +762,13 @@ cdef class usm_ndarray:
adv_ind_int.extend(_nonzero_impl(ind))
else:
adv_ind_int.append(ind)
return _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)

return _take_multi_index(res, adv_ind, adv_ind_start_p)
res = _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

res = _take_multi_index(res, adv_ind, adv_ind_start_p)
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

def to_device(self, target, stream=None):
""" to_device(target_device)
Expand Down Expand Up @@ -1040,8 +1047,7 @@ cdef class usm_ndarray:
buffer=self.base_,
offset=_meta[2],
)
# set flags and namespace
Xv.flags_ |= (self.flags_ & USM_ARRAY_WRITABLE)
# set namespace
Xv.array_namespace_ = self.array_namespace_

from ._copy_utils import (
Expand Down Expand Up @@ -1225,7 +1231,7 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
offset=offset_elems,
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
r.flags_ = _copy_writable(r.flags_, ary.flags_)
r.array_namespace_ = ary.array_namespace_
return r

Expand Down Expand Up @@ -1257,7 +1263,7 @@ cdef usm_ndarray _imag_view(usm_ndarray ary):
offset=offset_elems,
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
r.flags_ = _copy_writable(r.flags_, ary.flags_)
r.array_namespace_ = ary.array_namespace_
return r

Expand All @@ -1277,7 +1283,7 @@ cdef usm_ndarray _transpose(usm_ndarray ary):
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
offset=ary.get_offset()
)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
r.flags_ = _copy_writable(r.flags_, ary.flags_)
return r


Expand All @@ -1294,7 +1300,7 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary):
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
offset=ary.get_offset()
)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
r.flags_ = _copy_writable(r.flags_, ary.flags_)
return r


Expand Down
19 changes: 19 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,25 @@ def test_usm_ndarray_flags_bug_gh_1334():
assert r.flags["F"] and r.flags["C"]


def test_usm_ndarray_writable_flag_views():
get_queue_or_skip()
a = dpt.arange(10, dtype="f4")
a.flags["W"] = False

a.shape = (5, 2)
assert not a.flags.writable
assert not a.T.flags.writable
assert not a.mT.flags.writable
assert not a.real.flags.writable
assert not a[0:3].flags.writable

a = dpt.arange(10, dtype="c8")
a.flags["W"] = False

assert not a.real.flags.writable
assert not a.imag.flags.writable


@pytest.mark.parametrize(
"dtype",
[
Expand Down

0 comments on commit 194dee2

Please sign in to comment.