diff --git a/cub/benchmarks/bench/find_if/base.cu b/cub/benchmarks/bench/find_if/base.cu new file mode 100644 index 00000000000..6c51a754b56 --- /dev/null +++ b/cub/benchmarks/bench/find_if/base.cu @@ -0,0 +1,89 @@ +/****************************************************************************** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#include + +#include +#include + +#include + +template +struct equals // unary predicate: true when the argument equals `val` +{ + T val; + + __device__ __host__ bool operator()(T i) const + { + return i == val; + } +}; + +template +void find_if(nvbench::state& state, nvbench::type_list) +{ + T val = 1; + // set up input: the first `mismatch_point` items are 0 and the rest are `val`, which is the value searched for + const auto elements = state.get_int64("Elements"); + const auto common_prefix = state.get_float64("MismatchAt"); + const auto mismatch_point = static_cast<nvbench::int64_t>(elements * common_prefix); // the product is floating point; truncate explicitly before using it as an iterator offset + + thrust::device_vector dinput(elements, 0); + thrust::fill(dinput.begin() + mismatch_point, dinput.end(), val); + thrust::device_vector d_result(1); + // first call: d_temp_storage is null, so only temp_storage_bytes is queried + + void* d_temp_storage = nullptr; + size_t temp_storage_bytes{}; + + cub::DeviceFind::FindIf( + d_temp_storage, + temp_storage_bytes, + thrust::raw_pointer_cast(dinput.data()), + thrust::raw_pointer_cast(d_result.data()), + equals{val}, + dinput.size(), + 0); + + thrust::device_vector temp_storage(temp_storage_bytes); + d_temp_storage = thrust::raw_pointer_cast(temp_storage.data()); + + state.exec(nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { + cub::DeviceFind::FindIf( + d_temp_storage, + temp_storage_bytes, + thrust::raw_pointer_cast(dinput.data()), + thrust::raw_pointer_cast(d_result.data()), + equals{val}, + dinput.size(), + launch.get_stream()); + }); +} + +NVBENCH_BENCH_TYPES(find_if, NVBENCH_TYPE_AXES(fundamental_types)) + .add_int64_power_of_two_axis("Elements", nvbench::range(16, 28, 4)) + .add_float64_axis("MismatchAt", std::vector{1.0, 0.5, 0.0}); diff --git a/cub/cub/agent/agent_find.cuh b/cub/cub/agent/agent_find.cuh new file mode 100644 index 00000000000..3298ad94cdb --- /dev/null +++ b/cub/cub/agent/agent_find.cuh @@ -0,0 +1,306 @@ +/****************************************************************************** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * @file cub::AgentFind implements a stateful abstraction of CUDA thread + * blocks for participating in device-wide search. 
+ */ + +#pragma once +#include + +#include +#include +#include + +CUB_NAMESPACE_BEGIN + +/****************************************************************************** + * Tuning policy types + ******************************************************************************/ + +/** + * Parameterizable tuning policy type for AgentFind + * @tparam NOMINAL_BLOCK_THREADS_4B Threads per thread block + * @tparam NOMINAL_ITEMS_PER_THREAD_4B Items per thread (per tile of input) + * @tparam _VECTOR_LOAD_LENGTH Number of items per vectorized load + * @tparam _LOAD_MODIFIER Cache load modifier for reading input elements + */ +template > +struct AgentFindPolicy : ScalingType +{ + /// Number of items per vectorized load + static constexpr int VECTOR_LOAD_LENGTH = _VECTOR_LOAD_LENGTH; + + /// Cache load modifier for reading input elements + static constexpr CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; +}; + +template // @giannis OutputiteratorT not needed +struct AgentFind +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// The input value type + using InputT = cub::detail::value_t; + + /// Vector type of InputT for data movement + using VectorT = typename CubVector::Type; + + /// Input iterator wrapper type (for applying cache modifier) + // Wrap the native input pointer with CacheModifiedInputIterator + // or directly use the supplied input iterator type + using WrappedInputIteratorT = + ::cuda::std::_If<::cuda::std::is_pointer::value, + CacheModifiedInputIterator, + InputIteratorT>; + + /// Constants + static constexpr int BLOCK_THREADS = AgentFindPolicy::BLOCK_THREADS; + static constexpr int ITEMS_PER_THREAD = AgentFindPolicy::ITEMS_PER_THREAD; + static constexpr int TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD; + static constexpr int VECTOR_LOAD_LENGTH = CUB_MIN(ITEMS_PER_THREAD, AgentFindPolicy::VECTOR_LOAD_LENGTH); + + // Can vectorize 
according to the policy if the input iterator is a native + // pointer to a primitive type + static constexpr bool ATTEMPT_VECTORIZATION = + (VECTOR_LOAD_LENGTH > 1) && (ITEMS_PER_THREAD % VECTOR_LOAD_LENGTH == 0) + && (::cuda::std::is_pointer::value) && Traits::PRIMITIVE; + + static constexpr CacheLoadModifier LOAD_MODIFIER = AgentFindPolicy::LOAD_MODIFIER; + + /// Shared memory type required by this thread block + using _TempStorage = OffsetT; + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> + {}; + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + _TempStorage& sresult; ///< Reference to temp_storage + InputIteratorT d_in; ///< Input data to find + // OutputIteratorT d_out; + // OffsetT num_items; + // OffsetT* value_temp_storage; + // WrappedInputIteratorT d_wrapped_in; ///< Wrapped input data to find + ScanOpT scan_op; ///< Unary predicate applied to every input item + + //--------------------------------------------------------------------- + // Utility + //--------------------------------------------------------------------- + + template + static _CCCL_DEVICE _CCCL_FORCEINLINE bool + IsAlignedAndFullTile(T* d_in, OffsetT tile_offset, int tile_size, OffsetT num_items, Int2Type /*CAN_VECTORIZE*/) + { + /// Create an AgentFindIf and extract these two as type member in the encapsulating struct + using InputT = T; + using VectorT = typename CubVector::Type; + // vectorized loads need a full tile and a base pointer aligned to sizeof(VectorT) + const bool full_tile = (tile_offset + tile_size) <= num_items; + const bool is_aligned = reinterpret_cast<::cuda::std::uintptr_t>(d_in) % uintptr_t{sizeof(VectorT)} == 0; + return full_tile && is_aligned; + } + + template + static _CCCL_DEVICE _CCCL_FORCEINLINE bool IsAlignedAndFullTile( + Iterator /*d_in*/, + OffsetT /*tile_offset*/, + int /*tile_size*/, + std::size_t /*num_items*/, + Int2Type /*CAN_VECTORIZE*/) + { + return false; + } + + 
//--------------------------------------------------------------------- + // Constructor + //--------------------------------------------------------------------- + + /** + * @brief Constructor + * @param sresult Reference to temp_storage + * @param d_in Input data to search + * @param scan_op Unary predicate used for the search + */ + _CCCL_DEVICE _CCCL_FORCEINLINE AgentFind(TempStorage& sresult, InputIteratorT d_in, ScanOpT scan_op) + : sresult(sresult.Alias()) + , d_in(d_in) + , scan_op(scan_op) + {} + + //--------------------------------------------------------------------- + // Tile consumption + //--------------------------------------------------------------------- + + template + __device__ void + ConsumeTile(OffsetT tile_offset, Pred pred, OffsetT* result, OffsetT num_items, Int2Type /*CAN_VECTORIZE*/) + { + using InputT = cub::detail::value_t; + using VectorT = typename CubVector::Type; + + __shared__ OffsetT block_result; // smallest matching index seen by this block; num_items means "none" + + if (threadIdx.x == 0) + { + block_result = num_items; + } + + __syncthreads(); + + constexpr int NUMBER_OF_VECTORS = ITEMS_PER_THREAD / VECTOR_LOAD_LENGTH; + //// vectorized loads begin + const InputT* d_in_unqualified = d_in + tile_offset + (threadIdx.x * VECTOR_LOAD_LENGTH); + + cub::CacheModifiedInputIterator d_vec_in( + reinterpret_cast(d_in_unqualified)); + + InputT input_items[ITEMS_PER_THREAD]; + VectorT* vec_items = reinterpret_cast(input_items); + +#pragma unroll + for (int i = 0; i < NUMBER_OF_VECTORS; ++i) + { + vec_items[i] = d_vec_in[BLOCK_THREADS * i]; + } + //// vectorized loads end + + bool found = false; + for (int i = 0; i < ITEMS_PER_THREAD; ++i) + { + OffsetT nth_vector_of_thread = i / VECTOR_LOAD_LENGTH; + OffsetT element_in_vector = i % VECTOR_LOAD_LENGTH; + OffsetT vector_of_tile = nth_vector_of_thread * BLOCK_THREADS + threadIdx.x; + + OffsetT index = tile_offset + vector_of_tile * VECTOR_LOAD_LENGTH + element_in_vector; + + if (index < num_items) + { + if (pred(input_items[i])) + { + found = true; + atomicMin(&block_result, 
index); + break; // every thread goes over multiple elements per thread + // for every tile. If a thread finds a local minimum it doesn't + // need to proceed further (inner early exit). + } + } + } + + if (syncthreads_or(found)) + { + if (threadIdx.x == 0) + { + if (block_result < num_items) + { + atomicMin(result, block_result); + } + } + } + } + + template + __device__ void + ConsumeTile(OffsetT tile_offset, Pred pred, OffsetT* result, OffsetT num_items, Int2Type /*CAN_VECTORIZE*/) + { + __shared__ OffsetT block_result; // OffsetT, like the vectorized overload, so num_items is not truncated + + if (threadIdx.x == 0) + { + block_result = num_items; + } + + __syncthreads(); + + bool found = false; + for (int i = 0; i < ITEMS_PER_THREAD; ++i) + { + OffsetT index = tile_offset + threadIdx.x + i * blockDim.x; // same type as block_result and num_items + + if (index < num_items) + { + if (pred(*(d_in + index))) + { + found = true; + atomicMin(&block_result, index); + break; + } + } + } + if (syncthreads_or(found)) + { + if (threadIdx.x == 0) + { + if (block_result < num_items) + { + atomicMin(result, block_result); + } + } + } + } + + __device__ void Process(OffsetT* value_temp_storage, OffsetT num_items) // visits this block's tiles, TILE_ITEMS * gridDim.x apart + { + for (OffsetT tile_offset = static_cast<OffsetT>(blockIdx.x) * TILE_ITEMS; tile_offset < num_items; tile_offset += static_cast<OffsetT>(TILE_ITEMS) * gridDim.x) + { + // Only one thread reads atomically and propagates it to + // the other threads of the block through shared memory + if (threadIdx.x == 0) + { + sresult = atomicAdd(value_temp_storage, 0); + } + __syncthreads(); + + // early exit: the smallest match found so far lies before this tile and tile_offset only grows + if (sresult < tile_offset) + { + return; + } + + IsAlignedAndFullTile(d_in, tile_offset, TILE_ITEMS, num_items, Int2Type()) + ? 
ConsumeTile(tile_offset, scan_op, value_temp_storage, num_items, Int2Type()) + : ConsumeTile(tile_offset, scan_op, value_temp_storage, num_items, Int2Type()); + } + } +}; + +CUB_NAMESPACE_END diff --git a/cub/cub/device/device_find_if.cuh b/cub/cub/device/device_find_if.cuh new file mode 100644 index 00000000000..ee30f5c4392 --- /dev/null +++ b/cub/cub/device/device_find_if.cuh @@ -0,0 +1,84 @@ +/****************************************************************************** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +//! @file +//! cub::DeviceFind provides device-wide, parallel operations for searching a sequence of data +//! items residing within device-accessible memory. + +#pragma once + +#include + +#include "cub/util_type.cuh" +#include "cuda/std/__cccl/dialect.h" +#include "cuda/std/__memory/pointer_traits.h" +#include "cuda/std/__utility/declval.h" +#include "device_launch_parameters.h" + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include +#include +#include +#include +#include +#include +#include + +#include + +CUB_NAMESPACE_BEGIN + +struct DeviceFind +{ + template + CUB_RUNTIME_FUNCTION static cudaError_t FindIf( + void* d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + NumItemsT num_items, + cudaStream_t stream = 0) + { + // CUB_DETAIL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DeviceFind::FindIf"); + + // Integer type for global offsets, chosen by detail::choose_offset_t (NOTE(review): confirm whether it is signed) + using OffsetT = detail::choose_offset_t; + + return DispatchFind::Dispatch( + d_temp_storage, temp_storage_bytes, d_in, d_out, static_cast(num_items), scan_op, stream); + } +}; + +CUB_NAMESPACE_END diff --git 
a/cub/cub/device/dispatch/dispatch_find.cuh b/cub/cub/device/dispatch/dispatch_find.cuh new file mode 100644 index 00000000000..52bf5f6efe8 --- /dev/null +++ b/cub/cub/device/dispatch/dispatch_find.cuh @@ -0,0 +1,331 @@ +/****************************************************************************** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * @file cub::DeviceFind provides device-wide, parallel operations for + * computing search across a sequence of data items residing within + * device-accessible memory. + */ + +#pragma once + +#include +#include + +#include +#include +#include + +#include + +#include "cub/util_type.cuh" + +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include + +CUB_NAMESPACE_BEGIN + +template +__global__ void write_final_result_in_output_iterator_already(ValueType* d_temp_storage, OutputIteratorT d_out) // copies the intermediate result to d_out +{ + *d_out = *d_temp_storage; +} + +template +__global__ void cuda_mem_set_async_dtemp_storage(ValueType* d_temp_storage, NumItemsT num_items) // seeds the intermediate result with num_items ("not found") +{ + *d_temp_storage = num_items; +} + +/****************************************************************************** + * Kernel entry points + *****************************************************************************/ + +/** Kernel entry point: each block builds an AgentFind and calls Process(value_temp_storage, num_items). d_out is not used here. */ +template +__launch_bounds__(int(ChainedPolicyT::ActivePolicy::FindPolicy::BLOCK_THREADS)) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceFindKernel( + InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items, OffsetT* value_temp_storage, ScanOpT scan_op) +{ + using AgentFindT = + AgentFind; + + __shared__ typename AgentFindT::TempStorage sresult; + // __shared__ int temp_storage; + // Process tiles + AgentFindT agent(sresult, d_in, scan_op); // Seems like sresult can be defined and initialized in agent_find.cuh + // directly without having to pass it here as an argument. 
+ + agent.Process(value_temp_storage, num_items); +} + +template +struct DeviceFindPolicy +{ + //--------------------------------------------------------------------------- + // Architecture-specific tuning policies + //--------------------------------------------------------------------------- + + /// SM30 + struct Policy300 : ChainedPolicy<300, Policy300, Policy300> + { + static constexpr int threads_per_block = 128; + static constexpr int items_per_thread = 16; + static constexpr int items_per_vec_load = 4; + + // FindPolicy (GTX670: 154.0 @ 48M 4B items) + using FindPolicy = + AgentFindPolicy, items_per_vec_load, LOAD_LDG>; + + // // SingleTilePolicy + // using SingleTilePolicy = FindPolicy; + }; + + using MaxPolicy = Policy300; +}; + +template > +struct DispatchFind : SelectedPolicy +{ + //--------------------------------------------------------------------------- + // Problem state + //--------------------------------------------------------------------------- + + /// Device-accessible allocation of temporary storage. When `nullptr`, the + /// required allocation size is written to `temp_storage_bytes` and no work + /// is done. + void* d_temp_storage; + + /// Reference to size in bytes of `d_temp_storage` allocation + size_t& temp_storage_bytes; + + /// Pointer to the input sequence of data items + InputIteratorT d_in; + + /// Pointer to the output aggregate + OutputIteratorT d_out; + + /// Total number of input items (i.e., length of `d_in`) + OffsetT num_items; + + /// Unary search functor + ScanOpT scan_op; + + /// CUDA stream to launch kernels within. Default is stream0. 
+ cudaStream_t stream; + + int ptx_version; + + //--------------------------------------------------------------------------- + // Constructor + //--------------------------------------------------------------------------- + + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE DispatchFind( + void* d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + OffsetT num_items, + ScanOpT scan_op, + cudaStream_t stream, + int ptx_version) + : d_temp_storage(d_temp_storage) + , temp_storage_bytes(temp_storage_bytes) + , d_in(d_in) + , d_out(d_out) + , num_items(num_items) + , scan_op(scan_op) + , stream(stream) + , ptx_version(ptx_version) + {} + + //--------------------------------------------------------------------------- + // Normal problem size invocation + //--------------------------------------------------------------------------- + + //--------------------------------------------------------------------------- + // Chained policy invocation + //--------------------------------------------------------------------------- + + /// Invocation: sizes the grid, aliases temp storage, then runs init -> find -> write-back on `stream` + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke(FindKernel find_kernel) + { + using Policy = typename ActivePolicyT::FindPolicy; + + cudaError error = cudaSuccess; + do + { + // Number of input tiles + int tile_size = Policy::BLOCK_THREADS * Policy::ITEMS_PER_THREAD; + int num_tiles = static_cast(::cuda::ceil_div(num_items, tile_size)); + + // Get device ordinal + int device_ordinal; + error = CubDebug(cudaGetDevice(&device_ordinal)); + if (cudaSuccess != error) + { + break; + } + + // Get SM count + int sm_count; + error = CubDebug(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal)); + if (cudaSuccess != error) + { + break; + } + + int find_if_sm_occupancy; + error = CubDebug(cub::MaxSmOccupancy(find_if_sm_occupancy, find_kernel, Policy::BLOCK_THREADS)); + if (cudaSuccess != error) + { + break; + } + + int findif_device_occupancy = 
find_if_sm_occupancy * sm_count; + int max_blocks = findif_device_occupancy; // no * CUB_SUBSCRIPTION_FACTOR(0) because max_blocks gets too big + int findif_grid_size = CUB_MAX(1, CUB_MIN(num_tiles, max_blocks)); // at least one block: a zero-sized grid (num_items == 0) is an invalid launch configuration + + // Temporary storage allocation requirements + void* allocations[1] = {}; + size_t allocation_sizes[1] = {sizeof(OffsetT)}; // the slot is read and written through OffsetT* + // Alias the temporary allocations from the single storage blob (or + // compute the necessary size of the blob) + error = CubDebug(AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes)); + if (cudaSuccess != error) + { + break; + } + + OffsetT* value_temp_storage = static_cast(allocations[0]); + + if (d_temp_storage == nullptr) + { + // Return if the caller is simply requesting the size of the storage + // allocation + return cudaSuccess; + } + + // use d_temp_storage as the intermediate device result + // to read and write from. Then store the final result in the output iterator. + + cuda_mem_set_async_dtemp_storage<<<1, 1, 0, stream>>>(value_temp_storage, num_items); // same stream as the find kernel so the three launches are ordered + + // Invoke FindIfKernel + THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( + findif_grid_size, ActivePolicyT::FindPolicy::BLOCK_THREADS, 0, stream) + .doit(find_kernel, d_in, d_out, num_items, value_temp_storage, scan_op); + + write_final_result_in_output_iterator_already<<<1, 1, 0, stream>>>(value_temp_storage, d_out); + + // Check for failure to launch + error = CubDebug(cudaPeekAtLastError()); + if (cudaSuccess != error) + { + break; + } + + // Sync the stream if specified to flush runtime errors + error = CubDebug(detail::DebugSyncStream(stream)); + if (cudaSuccess != error) + { + break; + } + + } while (0); + return error; + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke() + { + using MaxPolicyT = typename SelectedPolicy::MaxPolicy; + return Invoke( + DeviceFindKernel); // include the surrounding two + // init and write back kernels + // here. 
+ } + + //--------------------------------------------------------------------------- + // Dispatch entrypoints + //--------------------------------------------------------------------------- + + /** + * @brief Queries the PTX version, builds a DispatchFind functor and invokes it through the policy chain. + * Returns the first CUDA error encountered, or cudaSuccess. + */ + + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Dispatch( + void* d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + OffsetT num_items, + ScanOpT scan_op, + cudaStream_t stream) + { + using MaxPolicyT = typename DispatchFind::MaxPolicy; + + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version = 0; + error = CubDebug(PtxVersion(ptx_version)); + if (cudaSuccess != error) + { + break; + } + // Create dispatch functor + DispatchFind dispatch(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, scan_op, stream, ptx_version); + + // Dispatch to chained policy + error = CubDebug(MaxPolicyT::Invoke(ptx_version, dispatch)); // NOTE(review): presumably the policy chain calls back into + // the argument-less dispatch.Invoke() above - confirm in ChainedPolicy + if (cudaSuccess != error) + { + break; + } + } while (0); + + return error; + } +}; + +CUB_NAMESPACE_END diff --git a/cub/test/catch2_test_device_find_if.cu b/cub/test/catch2_test_device_find_if.cu new file mode 100644 index 00000000000..524f7c04c41 --- /dev/null +++ b/cub/test/catch2_test_device_find_if.cu @@ -0,0 +1,244 @@ +/****************************************************************************** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#include "insert_nested_NVTX_range_guard.h" +// above header needs to be included first + +#include +#include +#include +#include + +#include +#include + +#include "c2h/custom_type.cuh" +#include "catch2_test_device_reduce.cuh" +#include "catch2_test_launch_helper.h" +#include "thrust/detail/raw_pointer_cast.h" +#include +#include + +// %PARAM% TEST_LAUNCH lid 0:1 + +// DECLARE_LAUNCH_WRAPPER(cub::DeviceFind::FindIf, device_findif); + +// List of types to test +using custom_t = + c2h::custom_type_t; + +using full_type_list = c2h::type_list, type_pair>; +// clang-format on + +enum class gen_data_t : int +{ + /// Uniform random data generation + GEN_TYPE_RANDOM, + /// Constant value as input data + GEN_TYPE_CONST +}; + +template +void compute_find_if_reference(InputIt first, InputIt last, OutputIt& result, BinaryOp op) // reference: distance from `first` to the first match found by thrust::find_if +{ + auto pos = thrust::find_if(first, last, op); + result = pos - first; +} + +template +struct equals +{ + T val; + + __device__ __host__ bool operator()(T i) + { + return i == val; + } +}; + +C2H_TEST("Device find_if works", "[device]", full_type_list) +{ + using params = params_t; + using input_t = typename params::item_t; + using output_t = typename params::output_t; + using offset_t = output_t; + + constexpr offset_t min_items = 1; + constexpr offset_t max_items = std::numeric_limits::max() / 5; // make test run faster + + // Generate the input sizes to test for + const offset_t num_items = GENERATE_COPY( + take(3, random(min_items, max_items)), + values({ + min_items, + max_items, + })); + + // Input data generation to test + const gen_data_t data_gen_mode = GENERATE_COPY(gen_data_t::GEN_TYPE_RANDOM, gen_data_t::GEN_TYPE_CONST); + + // Generate input data + c2h::device_vector in_items(num_items); + if (data_gen_mode == gen_data_t::GEN_TYPE_RANDOM) + { + c2h::gen(C2H_SEED(2), in_items); + } + else + { + input_t default_constant{}; 
+ init_default_constant(default_constant); + thrust::fill(c2h::device_policy, in_items.begin(), in_items.end(), default_constant); + } + auto d_in_it = thrust::raw_pointer_cast(in_items.data()); + + using op_t = equals; + input_t val_to_find = GENERATE_COPY(take(1, random(min_items, max_items))); + + SECTION("Generic find if case") + { + // Prepare verification data + c2h::host_vector host_items(in_items); + c2h::host_vector expected_result(1); + compute_find_if_reference(host_items.begin(), host_items.end(), expected_result[0], op_t{val_to_find}); + + void* d_temp_storage = nullptr; + size_t temp_storage_bytes{}; + + // Run test + c2h::device_vector out_result(1); + output_t* d_out_it = thrust::raw_pointer_cast(out_result.data()); + + cub::DeviceFind::FindIf( + d_temp_storage, temp_storage_bytes, unwrap_it(d_in_it), unwrap_it(d_out_it), op_t{val_to_find}, num_items); + + thrust::device_vector temp_storage(temp_storage_bytes); + d_temp_storage = thrust::raw_pointer_cast(temp_storage.data()); + + cub::DeviceFind::FindIf( + d_temp_storage, temp_storage_bytes, unwrap_it(d_in_it), unwrap_it(d_out_it), op_t{val_to_find}, num_items); + + // Verify result + REQUIRE(expected_result == out_result); + } + + SECTION("find_if works with non raw pointers - .begin() iterator") + { + // Prepare verification data + c2h::host_vector host_items(in_items); + c2h::host_vector expected_result(1); + compute_find_if_reference(host_items.begin(), host_items.end(), expected_result[0], op_t{val_to_find}); + + void* d_temp_storage = nullptr; + size_t temp_storage_bytes{}; + + // Run test + c2h::device_vector out_result(1); + + cub::DeviceFind::FindIf( + d_temp_storage, temp_storage_bytes, in_items.begin(), out_result.begin(), op_t{val_to_find}, num_items); + + thrust::device_vector temp_storage(temp_storage_bytes); + d_temp_storage = thrust::raw_pointer_cast(temp_storage.data()); + + cub::DeviceFind::FindIf( + d_temp_storage, temp_storage_bytes, in_items.begin(), out_result.begin(), 
op_t{val_to_find}, num_items); + + // Verify result + REQUIRE(expected_result == out_result); + } + + SECTION("find_if works for unaligned input") + { + for (int offset = 1; offset < 4; ++offset) + { + if (num_items > static_cast<offset_t>(offset)) // compare instead of subtracting: the difference would wrap around if offset_t is unsigned + { + // Prepare verification data + c2h::host_vector host_items(in_items); + c2h::host_vector expected_result(1); + compute_find_if_reference(host_items.begin() + offset, host_items.end(), expected_result[0], op_t{val_to_find}); + + void* d_temp_storage = nullptr; + size_t temp_storage_bytes{}; + + // Run test + c2h::device_vector out_result(1); + auto d_out_it = thrust::raw_pointer_cast(out_result.data()); + + cub::DeviceFind::FindIf( + d_temp_storage, + temp_storage_bytes, + unwrap_it(d_in_it + offset), + unwrap_it(d_out_it), + op_t{val_to_find}, + num_items - offset); + + thrust::device_vector temp_storage(temp_storage_bytes); + d_temp_storage = thrust::raw_pointer_cast(temp_storage.data()); + + cub::DeviceFind::FindIf( + d_temp_storage, + temp_storage_bytes, + unwrap_it(d_in_it + offset), + unwrap_it(d_out_it), + op_t{val_to_find}, + num_items - offset); + + // Verify result + REQUIRE(expected_result == out_result); + } + } + } + + SECTION("find_if works with non primitive iterator") + { + // Prepare verification data + auto it = thrust::make_counting_iterator(0); // non-primitive iterator + c2h::host_vector expected_result(1); + compute_find_if_reference(it, it + num_items, expected_result[0], op_t{val_to_find}); + + void* d_temp_storage = nullptr; + size_t temp_storage_bytes{}; + + // Run test + c2h::device_vector out_result(1); + auto d_out_it = thrust::raw_pointer_cast(out_result.data()); + + cub::DeviceFind::FindIf(d_temp_storage, temp_storage_bytes, it, unwrap_it(d_out_it), op_t{val_to_find}, num_items); + + thrust::device_vector temp_storage(temp_storage_bytes); + d_temp_storage = thrust::raw_pointer_cast(temp_storage.data()); + + cub::DeviceFind::FindIf(d_temp_storage, temp_storage_bytes, it, unwrap_it(d_out_it), 
op_t{val_to_find}, num_items); + + // Verify result + REQUIRE(expected_result == out_result); + } +} diff --git a/thrust/benchmarks/bench/count_if/basic.cu b/thrust/benchmarks/bench/count_if/basic.cu new file mode 100644 index 00000000000..981e7c7610e --- /dev/null +++ b/thrust/benchmarks/bench/count_if/basic.cu @@ -0,0 +1,68 @@ +/****************************************************************************** + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ *
+ ******************************************************************************/
+
+#include <thrust/count.h>
+#include <thrust/device_vector.h>
+
+#include "nvbench_helper.cuh"
+
+// Unary predicate: true when its argument compares equal to `val`.
+template <typename T>
+struct equals
+{
+  T val;
+
+  __device__ __host__ bool operator()(T i) const
+  {
+    return i == val;
+  }
+};
+
+// Benchmarks thrust::count_if; only elements from the "MismatchAt" fraction onward equal the counted value.
+template <typename T>
+void count_if(nvbench::state& state, nvbench::type_list<T>)
+{
+  const T val         = 1;
+  const auto elements = static_cast<std::size_t>(state.get_int64("Elements"));
+  // Truncate explicitly: the iterator offset below must be integral, not a double.
+  const auto mismatch_point = static_cast<std::ptrdiff_t>(elements * state.get_float64("MismatchAt"));
+
+  thrust::device_vector<T> dinput(elements, 0);
+  thrust::fill(dinput.begin() + mismatch_point, dinput.end(), val);
+
+  caching_allocator_t alloc;
+  thrust::count_if(policy(alloc), dinput.begin(), dinput.end(), equals<T>{val}); // untimed warm-up call
+
+  state.exec(nvbench::exec_tag::no_batch | nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
+    thrust::count_if(policy(alloc, launch), dinput.begin(), dinput.end(), equals<T>{val});
+  });
+}
+
+NVBENCH_BENCH_TYPES(count_if, NVBENCH_TYPE_AXES(fundamental_types))
+  .set_name("thrust::count_if")
+  .add_int64_power_of_two_axis("Elements", nvbench::range(16, 28, 4))
+  .add_float64_axis("MismatchAt", std::vector<nvbench::float64_t>{1.0, 0.5, 0.0});
diff --git a/thrust/benchmarks/bench/find_if/basic.cu b/thrust/benchmarks/bench/find_if/basic.cu
new file mode 100644
index 00000000000..362619f29e0
--- /dev/null
+++ b/thrust/benchmarks/bench/find_if/basic.cu
@@ -0,0 +1,68 @@
+/******************************************************************************
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the
+ * names of its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ ******************************************************************************/
+
+#include <thrust/device_vector.h>
+#include <thrust/find.h>
+
+#include "nvbench_helper.cuh"
+
+// Unary predicate: true when its argument compares equal to `val`.
+template <typename T>
+struct equals
+{
+  T val;
+
+  __device__ __host__ bool operator()(T i) const
+  {
+    return i == val;
+  }
+};
+
+// Benchmarks thrust::find_if; the first match sits at the "MismatchAt" fraction of the input (none at 1.0).
+template <typename T>
+void find_if(nvbench::state& state, nvbench::type_list<T>)
+{
+  const T val         = 1;
+  const auto elements = static_cast<std::size_t>(state.get_int64("Elements"));
+  // Truncate explicitly: the iterator offset below must be integral, not a double.
+  const auto mismatch_point = static_cast<std::ptrdiff_t>(elements * state.get_float64("MismatchAt"));
+
+  thrust::device_vector<T> dinput(elements, 0);
+  thrust::fill(dinput.begin() + mismatch_point, dinput.end(), val);
+
+  caching_allocator_t alloc;
+  thrust::find_if(policy(alloc), dinput.begin(), dinput.end(), equals<T>{val}); // untimed warm-up call
+
+  state.exec(nvbench::exec_tag::no_batch | nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
+    thrust::find_if(policy(alloc, launch), dinput.begin(), dinput.end(), equals<T>{val});
+  });
+}
+
+NVBENCH_BENCH_TYPES(find_if, NVBENCH_TYPE_AXES(fundamental_types))
+  .set_name("thrust::find_if")
+  .add_int64_power_of_two_axis("Elements", nvbench::range(16, 28, 4))
+  .add_float64_axis("MismatchAt", std::vector<nvbench::float64_t>{1.0, 0.5, 0.0});