Skip to content

Commit bd22e42

Browse files
Port CUB temporary storage layout test to Catch2 (#1835)
1 parent bbf6533 commit bd22e42

File tree

1 file changed

+22
-39
lines changed

1 file changed

+22
-39
lines changed

cub/test/test_temporary_storage_layout.cu cub/test/catch2_test_temporary_storage_layout.cu

+22-39
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131

3232
#include <memory>
3333

34+
#include "catch2_test_helper.h"
3435
#include "cub/detail/temporary_storage.cuh"
35-
#include "test_util.h"
36+
37+
using values = c2h::enum_type_list<int, 1, 4, 42>;
3638

3739
template <int Items>
3840
std::size_t GetTemporaryStorageSize(std::size_t (&sizes)[Items])
@@ -50,18 +52,17 @@ std::size_t GetActualZero()
5052
return GetTemporaryStorageSize(sizes);
5153
}
5254

53-
template <int StorageSlots>
54-
void TestEmptyStorage()
55+
CUB_TEST("Test empty storage", "[temporary_storage_layout]", values)
5556
{
57+
constexpr auto StorageSlots = c2h::get<0, TestType>::value;
5658
cub::detail::temporary_storage::layout<StorageSlots> temporary_storage;
57-
AssertEquals(temporary_storage.get_size(), GetActualZero());
59+
CHECK(temporary_storage.get_size() == GetActualZero());
5860
}
5961

60-
template <int StorageSlots>
61-
void TestPartiallyFilledStorage()
62+
CUB_TEST("Test partially filled storage", "[temporary_storage_layout]", values)
6263
{
63-
using target_type = std::uint64_t;
64-
64+
constexpr auto StorageSlots = c2h::get<0, TestType>::value;
65+
using target_type = std::uint64_t;
6566
constexpr std::size_t target_elements = 42;
6667
constexpr std::size_t full_slot_elements = target_elements * sizeof(target_type);
6768
constexpr std::size_t empty_slot_elements{};
@@ -88,26 +89,25 @@ void TestPartiallyFilledStorage()
8889

8990
temporary_storage.map_to_buffer(temp_storage.get(), temp_storage_bytes);
9091

91-
AssertEquals(temp_storage_bytes, GetTemporaryStorageSize(sizes));
92+
CHECK(temp_storage_bytes == GetTemporaryStorageSize(sizes));
9293

9394
for (int slot_id = 0; slot_id < StorageSlots; slot_id++)
9495
{
9596
if (slot_id % 2 == 0)
9697
{
97-
AssertTrue(arrays[slot_id]->get() != nullptr);
98+
CHECK(arrays[slot_id]->get() != nullptr);
9899
}
99100
else
100101
{
101-
AssertTrue(arrays[slot_id]->get() == nullptr);
102+
CHECK(arrays[slot_id]->get() == nullptr);
102103
}
103104
}
104105
}
105106

106-
template <int StorageSlots>
107-
void TestGrow()
107+
CUB_TEST("Test grow", "[temporary_storage_layout]", values)
108108
{
109-
using target_type = std::uint64_t;
110-
109+
constexpr auto StorageSlots = c2h::get<0, TestType>::value;
110+
using target_type = std::uint64_t;
111111
constexpr std::size_t target_elements_number = 42;
112112

113113
cub::detail::temporary_storage::layout<StorageSlots> preset_layout;
@@ -129,7 +129,7 @@ void TestGrow()
129129
postset_arrays[slot_id]->grow(target_elements_number);
130130
}
131131

132-
AssertEquals(preset_layout.get_size(), postset_layout.get_size());
132+
CHECK(preset_layout.get_size() == postset_layout.get_size());
133133

134134
const std::size_t tmp_storage_bytes = preset_layout.get_size();
135135
std::unique_ptr<std::uint8_t[]> temp_storage(new std::uint8_t[tmp_storage_bytes]);
@@ -139,15 +139,14 @@ void TestGrow()
139139

140140
for (int slot_id = 0; slot_id < StorageSlots; slot_id++)
141141
{
142-
AssertEquals(postset_arrays[slot_id]->get(), preset_arrays[slot_id]->get());
142+
CHECK(postset_arrays[slot_id]->get() == preset_arrays[slot_id]->get());
143143
}
144144
}
145145

146-
template <int StorageSlots>
147-
void TestDoubleGrow()
146+
CUB_TEST("Test double grow", "[temporary_storage_layout]", values)
148147
{
149-
using target_type = std::uint64_t;
150-
148+
constexpr auto StorageSlots = c2h::get<0, TestType>::value;
149+
using target_type = std::uint64_t;
151150
constexpr std::size_t target_elements_number = 42;
152151

153152
cub::detail::temporary_storage::layout<StorageSlots> preset_layout;
@@ -169,7 +168,7 @@ void TestDoubleGrow()
169168
postset_arrays[slot_id]->grow(2 * target_elements_number);
170169
}
171170

172-
AssertEquals(preset_layout.get_size(), postset_layout.get_size());
171+
CHECK(preset_layout.get_size() == postset_layout.get_size());
173172

174173
const std::size_t tmp_storage_bytes = preset_layout.get_size();
175174
std::unique_ptr<std::uint8_t[]> temp_storage(new std::uint8_t[tmp_storage_bytes]);
@@ -179,22 +178,6 @@ void TestDoubleGrow()
179178

180179
for (int slot_id = 0; slot_id < StorageSlots; slot_id++)
181180
{
182-
AssertEquals(postset_arrays[slot_id]->get(), preset_arrays[slot_id]->get());
181+
CHECK(postset_arrays[slot_id]->get() == preset_arrays[slot_id]->get());
183182
}
184183
}
185-
186-
template <int StorageSlots>
187-
void Test()
188-
{
189-
TestEmptyStorage<StorageSlots>();
190-
TestPartiallyFilledStorage<StorageSlots>();
191-
TestGrow<StorageSlots>();
192-
TestDoubleGrow<StorageSlots>();
193-
}
194-
195-
int main()
196-
{
197-
Test<1>();
198-
Test<4>();
199-
Test<42>();
200-
}

0 commit comments

Comments
 (0)