Skip to content

Commit

Permalink
Fix: Fix the construction of SimpleArray from slicing ndarray
Browse files Browse the repository at this point in the history
  • Loading branch information
ThreeMonth03 committed Nov 29, 2024
1 parent ad32dd9 commit 8470ec1
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 1 deletion.
61 changes: 61 additions & 0 deletions cpp/modmesh/buffer/SimpleArray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

#include <limits>
#include <stdexcept>
#include <functional>
#include <numeric>

#if defined(_MSC_VER)
#include <BaseTsd.h>
Expand Down Expand Up @@ -307,6 +309,65 @@ class SimpleArray
}
}

explicit SimpleArray(small_vector<size_t> const & shape,
small_vector<size_t> const & stride,
std::shared_ptr<buffer_type> const & buffer,
bool is_c_contiguous = true,
bool is_f_contiguous = false)
: SimpleArray(buffer)
{
if (buffer)
{
if (shape.size() != stride.size())
{
throw std::runtime_error("SimpleArray: shape and stride size mismatch");
}

if (is_c_contiguous)
{
if (stride[stride.size() - 1] != 1)
{
throw std::runtime_error("SimpleArray: C contiguous stride must end with 1");
}
for (size_t it = 0; it < shape.size() - 1; ++it)
{
if (stride[it] != shape[it + 1] * stride[it + 1])
{
throw std::runtime_error("SimpleArray: C contiguous stride must match shape");
}
}
}
if (is_f_contiguous)
{
if (stride[0] != 1)
{
throw std::runtime_error("SimpleArray: Fortran contiguous stride must start with 1");
}
for (size_t it = 0; it < shape.size() - 1; ++it)
{
if (stride[it + 1] != shape[it] * stride[it])
{
throw std::runtime_error("SimpleArray: Fortran contiguous stride must match shape");
}
}
}

const size_t nbytes = ITEMSIZE *
std::accumulate(shape.begin(),
shape.end(),
static_cast<size_t>(1),
std::multiplies<size_t>());
if (nbytes != buffer->nbytes())
{
throw std::runtime_error(Formatter() << "SimpleArray: shape byte count " << nbytes
<< " differs from buffer " << buffer->nbytes());
}

m_shape = shape;
m_stride = stride;
}
}

SimpleArray(std::initializer_list<T> init)
: SimpleArray(init.size())
{
Expand Down
10 changes: 9 additions & 1 deletion cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,24 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray
{
throw std::runtime_error("dtype mismatch");
}

modmesh::detail::shape_type shape;
modmesh::detail::shape_type stride;
constexpr size_t itemsize = wrapped_type::itemsize();
for (ssize_t i = 0; i < arr_in.ndim(); ++i)
{
shape.push_back(arr_in.shape(i));
stride.push_back(arr_in.strides(i) / itemsize);
}

const bool is_c_contiguous = (arr_in.flags() & py::array::c_style) == py::array::c_style;
const bool is_f_contiguous = (arr_in.flags() & py::array::f_style) == py::array::f_style;

std::shared_ptr<ConcreteBuffer> const buffer = ConcreteBuffer::construct(
arr_in.nbytes(),
arr_in.mutable_data(),
std::make_unique<ConcreteBufferNdarrayRemover>(arr_in));
return wrapped_type(shape, buffer);
return wrapped_type(shape, stride, buffer, is_c_contiguous, is_f_contiguous);
}),
py::arg("array"))
.def_buffer(&property_helper::get_buffer_info)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,27 @@ def test_SimpleArray_from_ndarray_content(self):
sarr.ndarray.fill(100)
self.assertTrue((ndarr == 100).all())

def test_SimpleArray_from_ndarray_slice(self):
ndarr = np.arange(1000, dtype='float64').reshape((10, 10, 10))
parr = ndarr[1:7:3, 6:2:-1, 3:9]
sarr = modmesh.SimpleArrayFloat64(array=ndarr[1:7:3, 6:2:-1, 3:9])

for i in range(2):
for j in range(4):
for k in range(6):
self.assertEqual(parr[i, j, k], sarr[i, j, k])

def test_SimpleArray_from_ndarray_transpose(self):
ndarr = np.arange(350, dtype='float64').reshape((5, 7, 10))
# The following array is F contiguous.
parr = ndarr[2:4].T
sarr = modmesh.SimpleArrayFloat64(array=ndarr[2:4].T)

for i in range(10):
for j in range(7):
for k in range(2):
self.assertEqual(parr[i, j, k], sarr[i, j, k])

def test_SimpleArray_broadcast_ellipsis_shape(self):
sarr = modmesh.SimpleArrayFloat64((2, 3, 4))
ndarr = np.arange(2 * 3 * 4, dtype='float64').reshape((2, 3, 4))
Expand Down

0 comments on commit 8470ec1

Please sign in to comment.