Skip to content

Commit 12044c2

Browse files
authored
Porting to GHEX 0.2.0 + Fixes for Multiple Devices (#77)
- `ghex_comm` runtime ported to GHEX 0.2.0, using oomph as the transport library. - Fixed an issue that caused `detected error differences between the runs` when using the `ghex_comm` runtime on multiple devices. It was due both to stream synchronization issues and to the choice of the default set-device strategy. - GHEX code no longer depends on the `gridtools::_impl` namespace. - NOTE: the semantics of the `--device-mapping` option have changed: it now accepts only one device id per rank, meaning that different threads on the same rank must use the same device. - AMD GPU fixes are not included and will follow shortly.
1 parent 897e4b0 commit 12044c2

File tree

8 files changed

+71
-66
lines changed

8 files changed

+71
-66
lines changed

.github/workflows/tests.yml

-3
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ jobs:
1919
matrix:
2020
backend: [cpu_ifirst, cpu_kfirst, cuda, hip]
2121
runtime: [single_node, simple_mpi, gcl, ghex_comm]
22-
exclude:
23-
- backend: hip
24-
runtime: ghex_comm
2522
steps:
2623
- uses: actions/checkout@v2
2724
name: checkout

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ endif()
6565
# Helper functions
6666
function(compile_as_cuda)
6767
if(_gtbench_cuda_enabled)
68-
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE CUDA)
68+
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE CUDA COMPILE_OPTIONS "--default-stream=per-thread")
6969
endif()
7070
endfunction()
7171

Dockerfile

+4-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ RUN apt-get update -qq && \
1010
libmpich-dev \
1111
tar \
1212
software-properties-common \
13-
wget && \
13+
wget \
14+
libnuma-dev && \
1415
rm -rf /var/lib/apt/lists/*
1516

1617
ARG CMAKE_VERSION=3.18.4
@@ -34,7 +35,7 @@ COPY . /gtbench
3435
RUN cd /gtbench && \
3536
mkdir -p build && \
3637
cd build && \
37-
if [ -d /opt/rocm ]; then export CXX=/opt/rocm/bin/hipcc; fi && \
38+
if [ -d /opt/rocm ]; then export ROCM_PATH=/opt/rocm; export PATH=${ROCM_PATH}/bin:${PATH}; export CXX=${ROCM_PATH}/bin/hipcc; fi && \
3839
cmake \
3940
-DCMAKE_BUILD_TYPE=Release \
4041
-DGTBENCH_BACKEND=${GTBENCH_BACKEND} \
@@ -43,5 +44,6 @@ RUN cd /gtbench && \
4344
.. && \
4445
make -j $(nproc) install && \
4546
rm -rf /gtbench/build
47+
ENV LD_LIBRARY_PATH=/usr/local/lib64:/usr/local/lib:${LD_LIBRARY_PATH}
4648

4749
CMD ["convergence_tests"]

include/gtbench/runtime/ghex_comm/run.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ struct runtime {
4040
std::array<int, 2> m_cart_dims;
4141
std::array<int, 2> m_thread_cart_dims;
4242
std::vector<int> m_device_mapping;
43+
int m_device;
4344
std::string m_output_filename;
4445
};
4546

@@ -74,7 +75,7 @@ result runtime_solve(runtime &rt, Analytical analytical, Stepper stepper,
7475

7576
std::vector<result> results(rt.m_num_threads);
7677
auto execution_func = [&](int id = 0) {
77-
set_device(rt.m_device_mapping[id]);
78+
set_device(rt.m_device);
7879
auto sub_grid = comm_grid[id];
7980
const auto exact = discrete_analytical::discretize(
8081
analytical, global_resolution, sub_grid.m_local_resolution,
@@ -116,7 +117,7 @@ result runtime_solve(runtime &rt, Analytical analytical, Stepper stepper,
116117
threads.reserve(rt.m_num_threads - 1);
117118
for (int i = 1; i < rt.m_num_threads; ++i)
118119
threads.emplace_back(execution_func, i);
119-
set_device(rt.m_device_mapping[0]);
120+
set_device(rt.m_device);
120121
execution_func(0);
121122

122123
for (auto &thread : threads)

src/runtime/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@ add_library(runtime)
22
target_link_libraries(runtime PUBLIC gtbench_common)
33
target_link_libraries(gtbench PUBLIC runtime)
44

5-
add_subdirectory(${GTBENCH_RUNTIME})
65
add_subdirectory(device)
6+
add_subdirectory(${GTBENCH_RUNTIME})

src/runtime/device/set_device.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,17 @@
99
*/
1010
#include <stdexcept>
1111

12+
#include <gridtools/common/defs.hpp>
1213
#include <gtbench/runtime/device/set_device.hpp>
1314

15+
#ifdef GT_CUDACC
16+
#include <gridtools/common/cuda_runtime.hpp>
17+
#if defined(__HIP__)
18+
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
19+
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
20+
#endif
21+
#endif
22+
1423
namespace gtbench {
1524
namespace runtime {
1625

src/runtime/ghex_comm/CMakeLists.txt

+19-12
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
compile_as_cuda(run.cpp TARGET_DIRECTORY runtime)
2+
target_sources(runtime PRIVATE run.cpp)
3+
14
if(NOT _ghex_already_fetched)
25
find_package(GHEX QUIET)
36
endif()
47
if(NOT GHEX_FOUND)
5-
set(_ghex_repository "https://github.com/GridTools/GHEX.git")
6-
set(_ghex_tag "4d48f3145349064942941f191021b970db3cb36e")
8+
set(_ghex_repository "https://github.com/ghex-org/GHEX.git")
9+
set(_ghex_tag "b9adb65a28303a05d607a34d6a75de4f36305ae0")
710
if(NOT _ghex_already_fetched)
811
message(STATUS "Fetching GHEX ${_ghex_tag} from ${_ghex_repository}")
912
endif()
@@ -13,21 +16,25 @@ if(NOT GHEX_FOUND)
1316
GIT_REPOSITORY ${_ghex_repository}
1417
GIT_TAG ${_ghex_tag}
1518
)
19+
if(GTBENCH_BACKEND STREQUAL "gpu")
20+
set(GHEX_USE_GPU ON CACHE INTERNAL "")
21+
if(_gtbench_cuda_enabled)
22+
set(GHEX_GPU_TYPE "NVIDIA" CACHE INTERNAL "")
23+
endif()
24+
endif()
1625
FetchContent_MakeAvailable(ghex)
1726
set(_ghex_already_fetched ON CACHE INTERNAL "")
1827
endif()
1928

20-
compile_as_cuda(run.cpp TARGET_DIRECTORY runtime)
21-
target_sources(runtime PRIVATE run.cpp)
22-
target_link_libraries(runtime PUBLIC GHEX::ghexlib)
23-
target_compile_options(runtime PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--default-stream per-thread>")
24-
25-
if(GHEX_USE_UCP)
26-
target_compile_definitions(runtime PRIVATE GTBENCH_USE_GHEX_UCP)
27-
if(GHEX_USE_PMIX)
28-
target_compile_definitions(runtime PRIVATE GTBENCH_USE_GHEX_PMIX)
29-
endif()
29+
target_link_libraries(runtime PUBLIC GHEX::lib)
30+
if(GHEX_TRANSPORT_BACKEND STREQUAL "LIBFABRIC")
31+
target_link_libraries(runtime PUBLIC oomph::libfabric)
32+
elseif(GHEX_TRANSPORT_BACKEND STREQUAL "UCX")
33+
target_link_libraries(runtime PUBLIC oomph::ucx)
34+
elseif(GHEX_TRANSPORT_BACKEND STREQUAL "MPI")
35+
target_link_libraries(runtime PUBLIC oomph::mpi)
3036
endif()
37+
3138
if(GHEX_USE_XPMEM)
3239
target_compile_definitions(runtime PRIVATE GHEX_USE_XPMEM)
3340
endif()

src/runtime/ghex_comm/run.cpp

+34-45
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,13 @@
1111
#include <numeric>
1212
#include <regex>
1313

14+
#include <ghex/barrier.hpp>
1415
#include <ghex/bulk_communication_object.hpp>
1516
#include <ghex/glue/gridtools/field.hpp>
1617
#include <ghex/structured/grid.hpp>
1718
#include <ghex/structured/pattern.hpp>
1819
#include <ghex/structured/rma_range_generator.hpp>
1920

20-
#ifdef GTBENCH_USE_GHEX_UCP
21-
#include <ghex/transport_layer/ucx/context.hpp>
22-
using transport = gridtools::ghex::tl::ucx_tag;
23-
#else
24-
#include <ghex/transport_layer/mpi/context.hpp>
25-
using transport = gridtools::ghex::tl::mpi_tag;
26-
#endif
27-
#include <ghex/transport_layer/util/barrier.hpp>
28-
2921
#include <gtbench/runtime/ghex_comm/factorize.hpp>
3022
#include <gtbench/runtime/ghex_comm/run.hpp>
3123

@@ -49,7 +41,7 @@ runtime::runtime(int num_threads, std::array<int, 2> cart_dims,
4941
},
5042
MPI_Finalize),
5143
m_num_threads(num_threads), m_cart_dims(cart_dims),
52-
m_thread_cart_dims(thread_cart_dims), m_device_mapping(num_threads, 0),
44+
m_thread_cart_dims(thread_cart_dims), m_device_mapping(), m_device(0),
5345
m_output_filename(output_filename) {
5446
int size, rank;
5547
MPI_Comm_size(MPI_COMM_WORLD, &size);
@@ -91,22 +83,20 @@ runtime::runtime(int num_threads, std::array<int, 2> cart_dims,
9183
MPI_Comm_rank(shmem_comm, &shmem_rank);
9284
MPI_Comm_free(&shmem_comm);
9385
if (!device_mapping.empty()) {
94-
if (device_mapping.size() != shmem_size * num_threads)
86+
if (device_mapping.size() != shmem_size)
9587
throw std::runtime_error("device mapping has wrong size");
9688
m_device_mapping = device_mapping;
9789
} else {
98-
m_device_mapping.resize(shmem_size * m_num_threads);
90+
m_device_mapping.resize(shmem_size);
9991
std::iota(m_device_mapping.begin(), m_device_mapping.end(), 0);
10092
}
101-
m_device_mapping = std::vector<int>(
102-
m_device_mapping.begin() + shmem_rank * num_threads,
103-
m_device_mapping.begin() + (shmem_rank + 1) * num_threads);
93+
m_device = m_device_mapping[shmem_rank];
10494
#endif
10595
}
10696

10797
using domain_id_t = int;
10898
using dimension_t = std::integral_constant<int, 3>;
109-
using coordinate_t = gt::ghex::coordinate<std::array<int, 3>>;
99+
using coordinate_t = ghex::coordinate<std::array<int, 3>>;
110100

111101
struct local_domain {
112102
using domain_id_type = domain_id_t;
@@ -122,12 +112,11 @@ struct local_domain {
122112
coordinate_t const &last() const { return m_last; }
123113
};
124114

125-
using context_t =
126-
typename gt::ghex::tl::context_factory<transport>::context_type;
115+
using context_t = ghex::context;
127116
using communicator_t = context_t::communicator_type;
128-
using grid_t = gt::ghex::structured::grid::type<local_domain>;
129-
using patterns_t =
130-
gt::ghex::pattern_container<communicator_t, grid_t, domain_id_t>;
117+
using barrier_t = ghex::barrier;
118+
using grid_t = ghex::structured::grid::type<local_domain>;
119+
using patterns_t = ghex::pattern_container<grid_t, domain_id_t>;
131120

132121
struct halo_generator {
133122
using domain_type = local_domain;
@@ -214,7 +203,7 @@ class grid::impl {
214203
domain_vec m_domains;
215204
context_ptr_t m_context;
216205
patterns_ptr_t m_patterns;
217-
gridtools::ghex::tl::barrier_t m_barrier;
206+
barrier_t m_barrier;
218207

219208
public:
220209
impl(vec<std::size_t, 3> const &global_resolution, int num_sub_domains,
@@ -224,7 +213,9 @@ class grid::impl {
224213
(int)global_resolution.y - 1,
225214
(int)global_resolution.z - 1}},
226215
m_global_resolution{global_resolution.x, global_resolution.y},
227-
m_barrier(num_sub_domains) {
216+
m_context{
217+
std::make_unique<context_t>(MPI_COMM_WORLD, (num_sub_domains > 1))},
218+
m_barrier{*m_context, static_cast<std::size_t>(num_sub_domains)} {
228219
MPI_Comm_size(MPI_COMM_WORLD, &m_size);
229220
MPI_Comm_rank(MPI_COMM_WORLD, &m_rank);
230221

@@ -271,20 +262,17 @@ class grid::impl {
271262
(int)global_resolution.z - 1}});
272263
}
273264
}
274-
m_context =
275-
gt::ghex::tl::context_factory<transport>::create(MPI_COMM_WORLD);
276265
m_patterns = std::make_unique<patterns_type>(
277-
gt::ghex::make_pattern<gt::ghex::structured::grid>(*m_context, m_hg,
278-
m_domains));
266+
ghex::make_pattern<ghex::structured::grid>(*m_context, m_hg,
267+
m_domains));
279268
}
280269

281270
impl(impl const &) = delete;
282271
impl &operator=(impl const &) = delete;
283272

284273
sub_grid operator[](unsigned int i) {
285274
const auto &dom = m_domains[i];
286-
auto comm = m_context->get_communicator();
287-
m_barrier(comm);
275+
m_barrier();
288276

289277
vec<std::size_t, 3> local_resolution = {
290278
(std::size_t)(dom.last()[0] - dom.first()[0] + 1),
@@ -295,29 +283,30 @@ class grid::impl {
295283
(std::size_t)dom.first()[2]};
296284

297285
auto b_comm_obj_map = std::make_shared<
298-
std::map<void *, gt::ghex::generic_bulk_communication_object>>();
286+
std::map<void *, ghex::generic_bulk_communication_object>>();
299287

300-
auto halo_exchange = [b_comm_obj_map = std::move(b_comm_obj_map), comm,
301-
domain = dom,
288+
auto halo_exchange = [b_comm_obj_map = std::move(b_comm_obj_map),
289+
&context = m_context, domain = dom,
302290
&patterns = *m_patterns](storage_t &storage) mutable {
303291
#ifdef GTBENCH_BACKEND_GPU
304-
using arch_t = gt::ghex::gpu;
292+
using arch_t = ghex::gpu;
305293
#else
306-
using arch_t = gt::ghex::cpu;
294+
using arch_t = ghex::cpu;
307295
#endif
308-
auto field =
309-
gt::ghex::wrap_gt_field<arch_t>(domain, storage, {halo, halo, 0});
296+
auto field = ghex::wrap_gt_field<arch_t>(
297+
domain, storage,
298+
{halo, halo, 0}); // device_id is initialized to the current device id
299+
// by default in GHEX
310300
auto it = b_comm_obj_map->find(field.data());
311301
if (it == b_comm_obj_map->end()) {
312-
auto sbco = gt::ghex::bulk_communication_object<
313-
gt::ghex::structured::rma_range_generator, patterns_type,
314-
decltype(field)>(comm);
302+
auto sbco = ghex::bulk_communication_object<
303+
ghex::structured::rma_range_generator, patterns_type,
304+
decltype(field)>(*context);
315305
sbco.add_field(patterns(field));
316306
it = b_comm_obj_map
317-
->insert(
318-
std::make_pair((void *)field.data(),
319-
gt::ghex::generic_bulk_communication_object(
320-
std::move(sbco))))
307+
->insert(std::make_pair(
308+
(void *)field.data(),
309+
ghex::generic_bulk_communication_object(std::move(sbco))))
321310
.first;
322311
}
323312
auto &bco = it->second;
@@ -372,10 +361,10 @@ void runtime_register_options(ghex_comm, options &options) {
372361
"TX TY", 2);
373362
#ifdef GT_CUDACC
374363
options("device-mapping",
375-
"node device mapping: device id per sub-domain in the format "
364+
"node device mapping: device id per rank in the format "
376365
"I_0:I_1:...:I_(N-1) "
377366
"where I_i are cuda device ids "
378-
"and N = #ranks-per-node x S",
367+
"and N = #ranks-per-node",
379368
"M");
380369
#endif
381370
}

0 commit comments

Comments
 (0)