11
11
#include < numeric>
12
12
#include < regex>
13
13
14
+ #include < ghex/barrier.hpp>
14
15
#include < ghex/bulk_communication_object.hpp>
15
16
#include < ghex/glue/gridtools/field.hpp>
16
17
#include < ghex/structured/grid.hpp>
17
18
#include < ghex/structured/pattern.hpp>
18
19
#include < ghex/structured/rma_range_generator.hpp>
19
20
20
- #ifdef GTBENCH_USE_GHEX_UCP
21
- #include < ghex/transport_layer/ucx/context.hpp>
22
- using transport = gridtools::ghex::tl::ucx_tag;
23
- #else
24
- #include < ghex/transport_layer/mpi/context.hpp>
25
- using transport = gridtools::ghex::tl::mpi_tag;
26
- #endif
27
- #include < ghex/transport_layer/util/barrier.hpp>
28
-
29
21
#include < gtbench/runtime/ghex_comm/factorize.hpp>
30
22
#include < gtbench/runtime/ghex_comm/run.hpp>
31
23
@@ -49,7 +41,7 @@ runtime::runtime(int num_threads, std::array<int, 2> cart_dims,
49
41
},
50
42
MPI_Finalize),
51
43
m_num_threads (num_threads), m_cart_dims(cart_dims),
52
- m_thread_cart_dims (thread_cart_dims), m_device_mapping(num_threads, 0 ),
44
+ m_thread_cart_dims (thread_cart_dims), m_device_mapping(), m_device( 0 ),
53
45
m_output_filename (output_filename) {
54
46
int size, rank;
55
47
MPI_Comm_size (MPI_COMM_WORLD, &size);
@@ -91,22 +83,20 @@ runtime::runtime(int num_threads, std::array<int, 2> cart_dims,
91
83
MPI_Comm_rank (shmem_comm, &shmem_rank);
92
84
MPI_Comm_free (&shmem_comm);
93
85
if (!device_mapping.empty ()) {
94
- if (device_mapping.size () != shmem_size * num_threads )
86
+ if (device_mapping.size () != shmem_size)
95
87
throw std::runtime_error (" device mapping has wrong size" );
96
88
m_device_mapping = device_mapping;
97
89
} else {
98
- m_device_mapping.resize (shmem_size * m_num_threads );
90
+ m_device_mapping.resize (shmem_size);
99
91
std::iota (m_device_mapping.begin (), m_device_mapping.end (), 0 );
100
92
}
101
- m_device_mapping = std::vector<int >(
102
- m_device_mapping.begin () + shmem_rank * num_threads,
103
- m_device_mapping.begin () + (shmem_rank + 1 ) * num_threads);
93
+ m_device = m_device_mapping[shmem_rank];
104
94
#endif
105
95
}
106
96
107
97
using domain_id_t = int ;
108
98
using dimension_t = std::integral_constant<int , 3 >;
109
- using coordinate_t = gt:: ghex::coordinate<std::array<int , 3 >>;
99
+ using coordinate_t = ghex::coordinate<std::array<int , 3 >>;
110
100
111
101
struct local_domain {
112
102
using domain_id_type = domain_id_t ;
@@ -122,12 +112,11 @@ struct local_domain {
122
112
coordinate_t const &last () const { return m_last; }
123
113
};
124
114
125
- using context_t =
126
- typename gt::ghex::tl::context_factory<transport>::context_type;
115
+ using context_t = ghex::context;
127
116
using communicator_t = context_t ::communicator_type;
128
- using grid_t = gt:: ghex::structured::grid::type<local_domain> ;
129
- using patterns_t =
130
- gt:: ghex::pattern_container<communicator_t , grid_t , domain_id_t >;
117
+ using barrier_t = ghex::barrier ;
118
+ using grid_t = ghex::structured::grid::type<local_domain>;
119
+ using patterns_t = ghex::pattern_container<grid_t , domain_id_t >;
131
120
132
121
struct halo_generator {
133
122
using domain_type = local_domain;
@@ -214,7 +203,7 @@ class grid::impl {
214
203
domain_vec m_domains;
215
204
context_ptr_t m_context;
216
205
patterns_ptr_t m_patterns;
217
- gridtools::ghex::tl:: barrier_t m_barrier;
206
+ barrier_t m_barrier;
218
207
219
208
public:
220
209
impl (vec<std::size_t , 3 > const &global_resolution, int num_sub_domains,
@@ -224,7 +213,9 @@ class grid::impl {
224
213
(int )global_resolution.y - 1 ,
225
214
(int )global_resolution.z - 1 }},
226
215
m_global_resolution{global_resolution.x , global_resolution.y },
227
- m_barrier (num_sub_domains) {
216
+ m_context{
217
+ std::make_unique<context_t >(MPI_COMM_WORLD, (num_sub_domains > 1 ))},
218
+ m_barrier{*m_context, static_cast <std::size_t >(num_sub_domains)} {
228
219
MPI_Comm_size (MPI_COMM_WORLD, &m_size);
229
220
MPI_Comm_rank (MPI_COMM_WORLD, &m_rank);
230
221
@@ -271,20 +262,17 @@ class grid::impl {
271
262
(int )global_resolution.z - 1 }});
272
263
}
273
264
}
274
- m_context =
275
- gt::ghex::tl::context_factory<transport>::create (MPI_COMM_WORLD);
276
265
m_patterns = std::make_unique<patterns_type>(
277
- gt:: ghex::make_pattern<gt:: ghex::structured::grid>(*m_context, m_hg,
278
- m_domains));
266
+ ghex::make_pattern<ghex::structured::grid>(*m_context, m_hg,
267
+ m_domains));
279
268
}
280
269
281
270
impl (impl const &) = delete ;
282
271
impl &operator =(impl const &) = delete ;
283
272
284
273
sub_grid operator [](unsigned int i) {
285
274
const auto &dom = m_domains[i];
286
- auto comm = m_context->get_communicator ();
287
- m_barrier (comm);
275
+ m_barrier ();
288
276
289
277
vec<std::size_t , 3 > local_resolution = {
290
278
(std::size_t )(dom.last ()[0 ] - dom.first ()[0 ] + 1 ),
@@ -295,29 +283,30 @@ class grid::impl {
295
283
(std::size_t )dom.first ()[2 ]};
296
284
297
285
auto b_comm_obj_map = std::make_shared<
298
- std::map<void *, gt:: ghex::generic_bulk_communication_object>>();
286
+ std::map<void *, ghex::generic_bulk_communication_object>>();
299
287
300
- auto halo_exchange = [b_comm_obj_map = std::move (b_comm_obj_map), comm,
301
- domain = dom,
288
+ auto halo_exchange = [b_comm_obj_map = std::move (b_comm_obj_map),
289
+ &context = m_context, domain = dom,
302
290
&patterns = *m_patterns](storage_t &storage) mutable {
303
291
#ifdef GTBENCH_BACKEND_GPU
304
- using arch_t = gt:: ghex::gpu;
292
+ using arch_t = ghex::gpu;
305
293
#else
306
- using arch_t = gt:: ghex::cpu;
294
+ using arch_t = ghex::cpu;
307
295
#endif
308
- auto field =
309
- gt::ghex::wrap_gt_field<arch_t >(domain, storage, {halo, halo, 0 });
296
+ auto field = ghex::wrap_gt_field<arch_t >(
297
+ domain, storage,
298
+ {halo, halo, 0 }); // device_id is initialized to the current device id
299
+ // by default in GHEX
310
300
auto it = b_comm_obj_map->find (field.data ());
311
301
if (it == b_comm_obj_map->end ()) {
312
- auto sbco = gt:: ghex::bulk_communication_object<
313
- gt:: ghex::structured::rma_range_generator, patterns_type,
314
- decltype (field)>(comm );
302
+ auto sbco = ghex::bulk_communication_object<
303
+ ghex::structured::rma_range_generator, patterns_type,
304
+ decltype (field)>(*context );
315
305
sbco.add_field (patterns (field));
316
306
it = b_comm_obj_map
317
- ->insert (
318
- std::make_pair ((void *)field.data (),
319
- gt::ghex::generic_bulk_communication_object (
320
- std::move (sbco))))
307
+ ->insert (std::make_pair (
308
+ (void *)field.data (),
309
+ ghex::generic_bulk_communication_object (std::move (sbco))))
321
310
.first ;
322
311
}
323
312
auto &bco = it->second ;
@@ -372,10 +361,10 @@ void runtime_register_options(ghex_comm, options &options) {
372
361
" TX TY" , 2 );
373
362
#ifdef GT_CUDACC
374
363
options (" device-mapping" ,
375
- " node device mapping: device id per sub-domain in the format "
364
+ " node device mapping: device id per rank in the format "
376
365
" I_0:I_1:...:I_(N-1) "
377
366
" where I_i are cuda device ids "
378
- " and N = #ranks-per-node x S " ,
367
+ " and N = #ranks-per-node" ,
379
368
" M" );
380
369
#endif
381
370
}
0 commit comments