Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the construction of SimpleArray from slicing ndarray #438

Merged
merged 1 commit into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions cpp/modmesh/buffer/SimpleArray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

#include <limits>
#include <stdexcept>
#include <functional>
#include <numeric>

#if defined(_MSC_VER)
#include <BaseTsd.h>
Expand Down Expand Up @@ -307,6 +309,65 @@ class SimpleArray
}
}

explicit SimpleArray(small_vector<size_t> const & shape,
small_vector<size_t> const & stride,
std::shared_ptr<buffer_type> const & buffer,
bool is_c_contiguous = true,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use an enum here? @ThreeMonth03

Copy link
Collaborator Author

@ThreeMonth03 ThreeMonth03 Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized I could use an enum class inside the class SimpleArray after I asked the question.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ThreeMonth03 Maybe you can consider opening another PR to fix this.

bool is_f_contiguous = false)
: SimpleArray(buffer)
{
if (buffer)
{
if (shape.size() != stride.size())
{
throw std::runtime_error("SimpleArray: shape and stride size mismatch");
}

if (is_c_contiguous)
{
if (stride[stride.size() - 1] != 1)
{
throw std::runtime_error("SimpleArray: C contiguous stride must end with 1");
}
for (size_t it = 0; it < shape.size() - 1; ++it)
{
if (stride[it] != shape[it + 1] * stride[it + 1])
{
throw std::runtime_error("SimpleArray: C contiguous stride must match shape");
}
}
}
if (is_f_contiguous)
{
if (stride[0] != 1)
{
throw std::runtime_error("SimpleArray: Fortran contiguous stride must start with 1");
}
for (size_t it = 0; it < shape.size() - 1; ++it)
{
if (stride[it + 1] != shape[it] * stride[it])
{
throw std::runtime_error("SimpleArray: Fortran contiguous stride must match shape");
}
}
}
Comment on lines +326 to +353
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check the C contiguous and F contiguous array.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation is good.

We should refactor it to use a distinct helper (class or function) in a later PR (may use a separate issue to track).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way I have a questions:

  1. How could we pass and receive the bool is_c_contiguous and bool is_f_contiguous elegantly?
    I have considered passing the numpy flags directly, but the constructor may not only deal with the numpy array. Is it a good way to design the customed flags by enum class types to deal with every possible input?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the class template SimpleArray needs to know f/c continuity. The sanity check may happen in the Python wrapper which has the information from Numpy. Standalone helper (static or free function) allows you to check in C++ and pybind11 wrapper. It could make sense to create a stride object to organize the logics.

It is ok to add more metadata in SimpleArray, including f/c continuity, but it is a separate topic.


const size_t nbytes = ITEMSIZE *
std::accumulate(shape.begin(),
shape.end(),
static_cast<size_t>(1),
std::multiplies<size_t>());
if (nbytes != buffer->nbytes())
{
throw std::runtime_error(Formatter() << "SimpleArray: shape byte count " << nbytes
<< " differs from buffer " << buffer->nbytes());
}

m_shape = shape;
m_stride = stride;
}
}

SimpleArray(std::initializer_list<T> init)
: SimpleArray(init.size())
{
Expand Down
10 changes: 9 additions & 1 deletion cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,24 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray
{
throw std::runtime_error("dtype mismatch");
}

modmesh::detail::shape_type shape;
modmesh::detail::shape_type stride;
constexpr size_t itemsize = wrapped_type::itemsize();
for (ssize_t i = 0; i < arr_in.ndim(); ++i)
{
shape.push_back(arr_in.shape(i));
stride.push_back(arr_in.strides(i) / itemsize);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The unit of the py::array::strides() is bytes, and the unit of the SimpleArray::strides() is the elements.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is clear enough. It is correct to use no comment.

}

const bool is_c_contiguous = (arr_in.flags() & py::array::c_style) == py::array::c_style;
const bool is_f_contiguous = (arr_in.flags() & py::array::f_style) == py::array::f_style;

std::shared_ptr<ConcreteBuffer> const buffer = ConcreteBuffer::construct(
arr_in.nbytes(),
arr_in.mutable_data(),
std::make_unique<ConcreteBufferNdarrayRemover>(arr_in));
return wrapped_type(shape, buffer);
return wrapped_type(shape, stride, buffer, is_c_contiguous, is_f_contiguous);
}),
py::arg("array"))
.def_buffer(&property_helper::get_buffer_info)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,27 @@ def test_SimpleArray_from_ndarray_content(self):
sarr.ndarray.fill(100)
self.assertTrue((ndarr == 100).all())

def test_SimpleArray_from_ndarray_slice(self):
ndarr = np.arange(1000, dtype='float64').reshape((10, 10, 10))
parr = ndarr[1:7:3, 6:2:-1, 3:9]
sarr = modmesh.SimpleArrayFloat64(array=ndarr[1:7:3, 6:2:-1, 3:9])
Comment on lines +455 to +456
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I modify the step to make the testcase complicated.


for i in range(2):
for j in range(4):
for k in range(6):
self.assertEqual(parr[i, j, k], sarr[i, j, k])

def test_SimpleArray_from_ndarray_transpose(self):
ndarr = np.arange(350, dtype='float64').reshape((5, 7, 10))
# The following array is F contiguous.
parr = ndarr[2:4].T
sarr = modmesh.SimpleArrayFloat64(array=ndarr[2:4].T)

for i in range(10):
for j in range(7):
for k in range(2):
self.assertEqual(parr[i, j, k], sarr[i, j, k])

def test_SimpleArray_broadcast_ellipsis_shape(self):
sarr = modmesh.SimpleArrayFloat64((2, 3, 4))
ndarr = np.arange(2 * 3 * 4, dtype='float64').reshape((2, 3, 4))
Expand Down
Loading