Skip to content

Add new MPI data type for RealVect #4416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion Src/Base/AMReX_ParallelDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <AMReX_BLFort.H>
#include <AMReX_ParallelDescriptor.H>
#include <AMReX_Print.H>
#include <AMReX_RealVect.H>
#include <AMReX_TypeTraits.H>
#include <AMReX_Arena.H>

Expand Down Expand Up @@ -44,6 +45,7 @@ namespace
{
int call_mpi_finalize = 0;
int num_startparallel_called = 0;
MPI_Datatype mpi_type_realvect = MPI_DATATYPE_NULL;
MPI_Datatype mpi_type_intvect = MPI_DATATYPE_NULL;
MPI_Datatype mpi_type_indextype = MPI_DATATYPE_NULL;
MPI_Datatype mpi_type_box = MPI_DATATYPE_NULL;
Expand All @@ -54,6 +56,7 @@ namespace
namespace amrex::ParallelDescriptor {

#ifdef AMREX_USE_MPI
template <> MPI_Datatype Mpi_typemap<RealVect>::type();
template <> MPI_Datatype Mpi_typemap<IntVect>::type();
template <> MPI_Datatype Mpi_typemap<IndexType>::type();
template <> MPI_Datatype Mpi_typemap<Box>::type();
Expand Down Expand Up @@ -377,11 +380,12 @@ StartParallel (int* argc, char*** argv, MPI_Comm a_mpi_comm)
}

// Create these types outside OMP parallel region
auto t0 = Mpi_typemap<RealVect>::type(); // NOLINT
auto t1 = Mpi_typemap<IntVect>::type(); // NOLINT
auto t2 = Mpi_typemap<IndexType>::type(); // NOLINT
auto t3 = Mpi_typemap<Box>::type(); // NOLINT
auto t4 = Mpi_typemap<ParallelDescriptor::lull_t>::type(); // NOLINT
amrex::ignore_unused(t1,t2,t3,t4);
amrex::ignore_unused(t0,t1,t2,t3,t4);

// ---- find the maximum value for a tag
int flag(0), *p;
Expand Down Expand Up @@ -411,6 +415,7 @@ EndParallel ()
{
--num_startparallel_called;
if (num_startparallel_called == 0) {
BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_realvect) );
BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_intvect) );
BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_indextype) );
BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_box) );
Expand All @@ -423,6 +428,7 @@ EndParallel ()
BL_MPI_REQUIRE( MPI_Op_free(op) );
*op = MPI_OP_NULL;
}
mpi_type_realvect = MPI_DATATYPE_NULL;
mpi_type_intvect = MPI_DATATYPE_NULL;
mpi_type_indextype = MPI_DATATYPE_NULL;
mpi_type_box = MPI_DATATYPE_NULL;
Expand Down Expand Up @@ -1377,6 +1383,29 @@ BL_FORT_PROC_DECL(BL_PD_ABORT,bl_pd_abort)()
#endif

#if defined(BL_USE_MPI) && !defined(BL_AMRPROF)
template <> MPI_Datatype Mpi_typemap<RealVect>::type()
{
static_assert(std::is_trivially_copyable_v<RealVect>, "RealVect must be trivially copyable");
static_assert(std::is_standard_layout_v<RealVect>, "RealVect must be standard layout");

if ( mpi_type_realvect == MPI_DATATYPE_NULL )
{
MPI_Datatype types[] = { Mpi_typemap<Real>::type() };
int blocklens[] = { AMREX_SPACEDIM };
MPI_Aint disp[] = { 0 };
BL_MPI_REQUIRE( MPI_Type_create_struct(1, blocklens, disp, types, &mpi_type_realvect) );
MPI_Aint lb, extent;
BL_MPI_REQUIRE( MPI_Type_get_extent(mpi_type_realvect, &lb, &extent) );
if (extent != sizeof(RealVect)) {
MPI_Datatype tmp = mpi_type_realvect;
BL_MPI_REQUIRE( MPI_Type_create_resized(tmp, 0, sizeof(RealVect), &mpi_type_realvect) );
BL_MPI_REQUIRE( MPI_Type_free(&tmp) );
}
BL_MPI_REQUIRE( MPI_Type_commit( &mpi_type_realvect ) );
}
return mpi_type_realvect;
}

template <> MPI_Datatype Mpi_typemap<IntVect>::type()
{
static_assert(std::is_trivially_copyable_v<IntVect>, "IntVect must be trivially copyable");
Expand Down
Loading