From 76eb1746dda4ca79c87cd8dccaa9311fb267928c Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Fri, 11 Apr 2025 14:53:08 -0700 Subject: [PATCH] Support RealVect in MPI --- Src/Base/AMReX_ParallelDescriptor.cpp | 31 ++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/Src/Base/AMReX_ParallelDescriptor.cpp b/Src/Base/AMReX_ParallelDescriptor.cpp index 88e11c3b32..bbac0719d1 100644 --- a/Src/Base/AMReX_ParallelDescriptor.cpp +++ b/Src/Base/AMReX_ParallelDescriptor.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -44,6 +45,7 @@ namespace { int call_mpi_finalize = 0; int num_startparallel_called = 0; + MPI_Datatype mpi_type_realvect = MPI_DATATYPE_NULL; MPI_Datatype mpi_type_intvect = MPI_DATATYPE_NULL; MPI_Datatype mpi_type_indextype = MPI_DATATYPE_NULL; MPI_Datatype mpi_type_box = MPI_DATATYPE_NULL; @@ -54,6 +56,7 @@ namespace namespace amrex::ParallelDescriptor { #ifdef AMREX_USE_MPI + template <> MPI_Datatype Mpi_typemap::type(); template <> MPI_Datatype Mpi_typemap::type(); template <> MPI_Datatype Mpi_typemap::type(); template <> MPI_Datatype Mpi_typemap::type(); @@ -377,11 +380,12 @@ StartParallel (int* argc, char*** argv, MPI_Comm a_mpi_comm) } // Create these types outside OMP parallel region + auto t0 = Mpi_typemap::type(); // NOLINT auto t1 = Mpi_typemap::type(); // NOLINT auto t2 = Mpi_typemap::type(); // NOLINT auto t3 = Mpi_typemap::type(); // NOLINT auto t4 = Mpi_typemap::type(); // NOLINT - amrex::ignore_unused(t1,t2,t3,t4); + amrex::ignore_unused(t0,t1,t2,t3,t4); // ---- find the maximum value for a tag int flag(0), *p; @@ -411,6 +415,7 @@ EndParallel () { --num_startparallel_called; if (num_startparallel_called == 0) { + BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_realvect) ); BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_intvect) ); BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_indextype) ); BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_box) ); @@ -423,6 +428,7 @@ EndParallel () BL_MPI_REQUIRE( MPI_Op_free(op) ); *op = MPI_OP_NULL; } + mpi_type_realvect = MPI_DATATYPE_NULL; mpi_type_intvect = MPI_DATATYPE_NULL; mpi_type_indextype = MPI_DATATYPE_NULL; mpi_type_box = MPI_DATATYPE_NULL; @@ -1377,6 +1383,29 @@ BL_FORT_PROC_DECL(BL_PD_ABORT,bl_pd_abort)() #endif #if defined(BL_USE_MPI) && !defined(BL_AMRPROF) +template <> MPI_Datatype Mpi_typemap::type() +{ + static_assert(std::is_trivially_copyable_v, "RealVect must be trivially copyable"); + static_assert(std::is_standard_layout_v, "RealVect must be standard layout"); + + if ( mpi_type_realvect == MPI_DATATYPE_NULL ) + { + MPI_Datatype types[] = { Mpi_typemap::type() }; + int blocklens[] = { AMREX_SPACEDIM }; + MPI_Aint disp[] = { 0 }; + BL_MPI_REQUIRE( MPI_Type_create_struct(1, blocklens, disp, types, &mpi_type_realvect) ); + MPI_Aint lb, extent; + BL_MPI_REQUIRE( MPI_Type_get_extent(mpi_type_realvect, &lb, &extent) ); + if (extent != sizeof(RealVect)) { + MPI_Datatype tmp = mpi_type_realvect; + BL_MPI_REQUIRE( MPI_Type_create_resized(tmp, 0, sizeof(RealVect), &mpi_type_realvect) ); + BL_MPI_REQUIRE( MPI_Type_free(&tmp) ); + } + BL_MPI_REQUIRE( MPI_Type_commit( &mpi_type_realvect ) ); + } + return mpi_type_realvect; +} + template <> MPI_Datatype Mpi_typemap::type() { static_assert(std::is_trivially_copyable_v, "IntVect must be trivially copyable");