From 8470ec1ac0ea0b8162e15d56aed49739cd2a337e Mon Sep 17 00:00:00 2001 From: ThreeMonth03 Date: Fri, 29 Nov 2024 01:47:08 +0800 Subject: [PATCH] Fix: Fix the construction of SimpleArray from slicing ndarray --- cpp/modmesh/buffer/SimpleArray.hpp | 61 +++++++++++++++++++ cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp | 10 ++- tests/test_buffer.py | 21 +++++++ 3 files changed, 91 insertions(+), 1 deletion(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 62e62e91..eeb04380 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -32,6 +32,8 @@ #include #include +#include +#include #if defined(_MSC_VER) #include @@ -307,6 +309,65 @@ class SimpleArray } } + explicit SimpleArray(small_vector const & shape, + small_vector const & stride, + std::shared_ptr const & buffer, + bool is_c_contiguous = true, + bool is_f_contiguous = false) + : SimpleArray(buffer) + { + if (buffer) + { + if (shape.size() != stride.size()) + { + throw std::runtime_error("SimpleArray: shape and stride size mismatch"); + } + + if (is_c_contiguous) + { + if (stride[stride.size() - 1] != 1) + { + throw std::runtime_error("SimpleArray: C contiguous stride must end with 1"); + } + for (size_t it = 0; it < shape.size() - 1; ++it) + { + if (stride[it] != shape[it + 1] * stride[it + 1]) + { + throw std::runtime_error("SimpleArray: C contiguous stride must match shape"); + } + } + } + if (is_f_contiguous) + { + if (stride[0] != 1) + { + throw std::runtime_error("SimpleArray: Fortran contiguous stride must start with 1"); + } + for (size_t it = 0; it < shape.size() - 1; ++it) + { + if (stride[it + 1] != shape[it] * stride[it]) + { + throw std::runtime_error("SimpleArray: Fortran contiguous stride must match shape"); + } + } + } + + const size_t nbytes = ITEMSIZE * + std::accumulate(shape.begin(), + shape.end(), + static_cast(1), + std::multiplies()); + if (nbytes != buffer->nbytes()) + { + throw std::runtime_error(Formatter() << "SimpleArray: shape byte count " << nbytes + << " differs from buffer " << buffer->nbytes()); + } + + m_shape = shape; + m_stride = stride; + } + } + SimpleArray(std::initializer_list init) : SimpleArray(init.size()) { diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index ac6ce06e..29919846 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -74,16 +74,24 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray { throw std::runtime_error("dtype mismatch"); } + modmesh::detail::shape_type shape; + modmesh::detail::shape_type stride; + constexpr size_t itemsize = wrapped_type::itemsize(); for (ssize_t i = 0; i < arr_in.ndim(); ++i) { shape.push_back(arr_in.shape(i)); + stride.push_back(arr_in.strides(i) / itemsize); } + + const bool is_c_contiguous = (arr_in.flags() & py::array::c_style) == py::array::c_style; + const bool is_f_contiguous = (arr_in.flags() & py::array::f_style) == py::array::f_style; + std::shared_ptr const buffer = ConcreteBuffer::construct( arr_in.nbytes(), arr_in.mutable_data(), std::make_unique(arr_in)); - return wrapped_type(shape, buffer); + return wrapped_type(shape, stride, buffer, is_c_contiguous, is_f_contiguous); }), py::arg("array")) .def_buffer(&property_helper::get_buffer_info) diff --git a/tests/test_buffer.py b/tests/test_buffer.py index b7497f7c..e632ac70 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -450,6 +450,27 @@ def test_SimpleArray_from_ndarray_content(self): sarr.ndarray.fill(100) self.assertTrue((ndarr == 100).all()) + def test_SimpleArray_from_ndarray_slice(self): + ndarr = np.arange(1000, dtype='float64').reshape((10, 10, 10)) + parr = ndarr[1:7:3, 6:2:-1, 3:9] + sarr = modmesh.SimpleArrayFloat64(array=ndarr[1:7:3, 6:2:-1, 3:9]) + + for i in range(2): + for j in range(4): + for k in range(6): + self.assertEqual(parr[i, j, k], sarr[i, j, k]) + + def test_SimpleArray_from_ndarray_transpose(self): + ndarr = np.arange(350, dtype='float64').reshape((5, 7, 10)) + # The following array is F contiguous. + parr = ndarr[2:4].T + sarr = modmesh.SimpleArrayFloat64(array=ndarr[2:4].T) + + for i in range(10): + for j in range(7): + for k in range(2): + self.assertEqual(parr[i, j, k], sarr[i, j, k]) + def test_SimpleArray_broadcast_ellipsis_shape(self): sarr = modmesh.SimpleArrayFloat64((2, 3, 4)) ndarr = np.arange(2 * 3 * 4, dtype='float64').reshape((2, 3, 4))