31
31
32
32
#include < memory>
33
33
34
+ #include " catch2_test_helper.h"
34
35
#include " cub/detail/temporary_storage.cuh"
35
- #include " test_util.h"
36
+
37
+ using values = c2h::enum_type_list<int , 1 , 4 , 42 >;
36
38
37
39
template <int Items>
38
40
std::size_t GetTemporaryStorageSize (std::size_t (&sizes)[Items])
@@ -50,18 +52,17 @@ std::size_t GetActualZero()
50
52
return GetTemporaryStorageSize (sizes);
51
53
}
52
54
53
- template <int StorageSlots>
54
- void TestEmptyStorage ()
55
+ CUB_TEST (" Test empty storage" , " [temporary_storage_layout]" , values)
55
56
{
57
+ constexpr auto StorageSlots = c2h::get<0 , TestType>::value;
56
58
cub::detail::temporary_storage::layout<StorageSlots> temporary_storage;
57
- AssertEquals (temporary_storage.get_size (), GetActualZero ());
59
+ CHECK (temporary_storage.get_size () == GetActualZero ());
58
60
}
59
61
60
- template <int StorageSlots>
61
- void TestPartiallyFilledStorage ()
62
+ CUB_TEST (" Test partially filled storage" , " [temporary_storage_layout]" , values)
62
63
{
63
- using target_type = std:: uint64_t ;
64
-
64
+ constexpr auto StorageSlots = c2h::get< 0 , TestType>::value ;
65
+ using target_type = std:: uint64_t ;
65
66
constexpr std::size_t target_elements = 42 ;
66
67
constexpr std::size_t full_slot_elements = target_elements * sizeof (target_type);
67
68
constexpr std::size_t empty_slot_elements{};
@@ -88,26 +89,25 @@ void TestPartiallyFilledStorage()
88
89
89
90
temporary_storage.map_to_buffer (temp_storage.get (), temp_storage_bytes);
90
91
91
- AssertEquals (temp_storage_bytes, GetTemporaryStorageSize (sizes));
92
+ CHECK (temp_storage_bytes == GetTemporaryStorageSize (sizes));
92
93
93
94
for (int slot_id = 0 ; slot_id < StorageSlots; slot_id++)
94
95
{
95
96
if (slot_id % 2 == 0 )
96
97
{
97
- AssertTrue (arrays[slot_id]->get () != nullptr );
98
+ CHECK (arrays[slot_id]->get () != nullptr );
98
99
}
99
100
else
100
101
{
101
- AssertTrue (arrays[slot_id]->get () == nullptr );
102
+ CHECK (arrays[slot_id]->get () == nullptr );
102
103
}
103
104
}
104
105
}
105
106
106
- template <int StorageSlots>
107
- void TestGrow ()
107
+ CUB_TEST (" Test grow" , " [temporary_storage_layout]" , values)
108
108
{
109
- using target_type = std:: uint64_t ;
110
-
109
+ constexpr auto StorageSlots = c2h::get< 0 , TestType>::value ;
110
+ using target_type = std:: uint64_t ;
111
111
constexpr std::size_t target_elements_number = 42 ;
112
112
113
113
cub::detail::temporary_storage::layout<StorageSlots> preset_layout;
@@ -129,7 +129,7 @@ void TestGrow()
129
129
postset_arrays[slot_id]->grow (target_elements_number);
130
130
}
131
131
132
- AssertEquals (preset_layout.get_size (), postset_layout.get_size ());
132
+ CHECK (preset_layout.get_size () == postset_layout.get_size ());
133
133
134
134
const std::size_t tmp_storage_bytes = preset_layout.get_size ();
135
135
std::unique_ptr<std::uint8_t []> temp_storage (new std::uint8_t [tmp_storage_bytes]);
@@ -139,15 +139,14 @@ void TestGrow()
139
139
140
140
for (int slot_id = 0 ; slot_id < StorageSlots; slot_id++)
141
141
{
142
- AssertEquals (postset_arrays[slot_id]->get (), preset_arrays[slot_id]->get ());
142
+ CHECK (postset_arrays[slot_id]->get () == preset_arrays[slot_id]->get ());
143
143
}
144
144
}
145
145
146
- template <int StorageSlots>
147
- void TestDoubleGrow ()
146
+ CUB_TEST (" Test double grow" , " [temporary_storage_layout]" , values)
148
147
{
149
- using target_type = std:: uint64_t ;
150
-
148
+ constexpr auto StorageSlots = c2h::get< 0 , TestType>::value ;
149
+ using target_type = std:: uint64_t ;
151
150
constexpr std::size_t target_elements_number = 42 ;
152
151
153
152
cub::detail::temporary_storage::layout<StorageSlots> preset_layout;
@@ -169,7 +168,7 @@ void TestDoubleGrow()
169
168
postset_arrays[slot_id]->grow (2 * target_elements_number);
170
169
}
171
170
172
- AssertEquals (preset_layout.get_size (), postset_layout.get_size ());
171
+ CHECK (preset_layout.get_size () == postset_layout.get_size ());
173
172
174
173
const std::size_t tmp_storage_bytes = preset_layout.get_size ();
175
174
std::unique_ptr<std::uint8_t []> temp_storage (new std::uint8_t [tmp_storage_bytes]);
@@ -179,22 +178,6 @@ void TestDoubleGrow()
179
178
180
179
for (int slot_id = 0 ; slot_id < StorageSlots; slot_id++)
181
180
{
182
- AssertEquals (postset_arrays[slot_id]->get (), preset_arrays[slot_id]->get ());
181
+ CHECK (postset_arrays[slot_id]->get () == preset_arrays[slot_id]->get ());
183
182
}
184
183
}
185
-
186
- template <int StorageSlots>
187
- void Test ()
188
- {
189
- TestEmptyStorage<StorageSlots>();
190
- TestPartiallyFilledStorage<StorageSlots>();
191
- TestGrow<StorageSlots>();
192
- TestDoubleGrow<StorageSlots>();
193
- }
194
-
195
- int main ()
196
- {
197
- Test<1 >();
198
- Test<4 >();
199
- Test<42 >();
200
- }
0 commit comments