This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Illegal memory access on trying to use DeviceReduce::Sum() to count number of non-zeros #726

Closed
alexsamardzic opened this issue Jun 25, 2023 · 2 comments

Comments

@alexsamardzic

Below is a slightly modified version of the example_device_reduce.cu example from the CUB distribution. It demonstrates an illegal memory access when TransformInputIterator is combined with DeviceReduce::Sum() to count the non-zero elements in an array.
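For reference, the counting pattern itself is simple: a transform maps each element to 1 if it is non-zero and 0 otherwise, and a sum reduces those flags. The host-only C++17 sketch below mirrors that pattern with std::transform_reduce. It is not CUB code, and the name count_nonzero is hypothetical. The lambda plays the role of NonZeroOp, and the 64-bit init value fixes the accumulator type.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical host-side analogue of TransformInputIterator + DeviceReduce::Sum:
// map each element to 1 if non-zero (0 otherwise), then sum the flags.
// The accumulator is 64-bit, so the count cannot overflow for large inputs.
template <typename T>
std::int64_t count_nonzero(const std::vector<T>& v)
{
    return std::transform_reduce(
        v.begin(), v.end(),
        std::int64_t{0},                                      // init value fixes the accumulator type
        std::plus<std::int64_t>{},                            // reduction step
        [](const T& a) { return std::int64_t(a != T(0)); });  // transform step (like NonZeroOp)
}
```

With the reproducer's Initialize() input (h_in[i] = i), only element 0 is zero, so the expected count is num_items - 1.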

Reproducible example
/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/******************************************************************************
 * Simple example of DeviceReduce::Sum().
 *
 * Counts the non-zero elements of an array of int keys by summing a
 * TransformInputIterator that maps each key to (key != 0).
 *
 * To compile using the command line:
 *   nvcc -arch=sm_XX example_device_reduce.cu -I../.. -lcudart -O3
 *
 ******************************************************************************/

// Ensure printing of CUDA runtime errors to console
#define CUB_STDERR

#include <stdio.h>

#include <cub/util_allocator.cuh>
#include <cub/device/device_reduce.cuh>
#include <cub/iterator/transform_input_iterator.cuh>

#include "../../test/test_util.h"

using namespace cub;


//---------------------------------------------------------------------
// Globals, constants and typedefs
//---------------------------------------------------------------------

CachingDeviceAllocator  g_allocator(true);  // Caching allocator for device memory


//---------------------------------------------------------------------
// Test generation
//---------------------------------------------------------------------

/**
 * Initialize problem
 */
void Initialize(
    int   *h_in,
    int     num_items)
{
    for (int i = 0; i < num_items; ++i)
        h_in[i] = i;
}

/**
 * Compute solution
 */
void Solve(
    int           *h_in,
    int           &h_reference,
    int             num_items)
{
    for (int i = 0; i < num_items; ++i)
    {
        if (i == 0)
            h_reference = h_in[0] != 0;
        else
            h_reference += h_in[i] != 0;
    }
}

template<typename T>
struct NonZeroOp
{
    __host__ __device__ __forceinline__ bool operator()(const T& a) const {
      return (a!=T(0));
    }
};

//---------------------------------------------------------------------
// Main
//---------------------------------------------------------------------

/**
 * Main
 */
int main(int argc, char** argv)
{
    int num_items = 46000 * 46000;

    // Initialize command line
    CommandLineArgs args(argc, argv);

    // Print usage
    if (args.CheckCmdLineFlag("help"))
    {
        printf("%s "
            "[--device=<device-id>] "
            "\n", argv[0]);
        exit(0);
    }

    // Initialize device
    CubDebugExit(args.DeviceInit());

    printf("cub::DeviceReduce::Sum() %d items (%d-byte elements)\n",
        num_items, (int) sizeof(int));
    fflush(stdout);

    // Allocate host arrays
    int* h_in = new int[num_items];
    int  h_reference;

    // Initialize problem and solution
    Initialize(h_in, num_items);
    Solve(h_in, h_reference, num_items);

    // Allocate problem device arrays
    int *d_in = NULL;
    CubDebugExit(g_allocator.DeviceAllocate((void**)&d_in, sizeof(int) * num_items));

    // Initialize device input
    CubDebugExit(cudaMemcpy(d_in, h_in, sizeof(int) * num_items, cudaMemcpyHostToDevice));

    // Allocate device output array
    int *d_out = NULL;
    CubDebugExit(g_allocator.DeviceAllocate((void**)&d_out, sizeof(int) * 1));

    cub::TransformInputIterator<bool, NonZeroOp<int>, const int*> iter(d_in, NonZeroOp<int>());

    // Request and allocate temporary storage
    void            *d_temp_storage = NULL;
    size_t          temp_storage_bytes = 0;
    CubDebugExit(DeviceReduce::Sum(nullptr, temp_storage_bytes, iter, d_out, num_items));
    CubDebugExit(g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes));
    CubDebugExit(cudaDeviceSynchronize());
    printf("CUDA error (before DeviceReduce::Sum) = %d\n", cudaGetLastError());

    // Run
    CubDebugExit(DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, iter, d_out, num_items));

    CubDebugExit(cudaDeviceSynchronize());
    printf("CUDA error (after DeviceReduce::Sum) = %d\n", cudaGetLastError());

    // Check for correctness (and display results, if specified)
    int compare = CompareDeviceResults(&h_reference, d_out, 1, false, false);
    printf("\t%s", compare ? "FAIL" : "PASS");
    AssertEquals(0, compare);

    // Cleanup
    if (h_in) delete[] h_in;
    if (d_in) CubDebugExit(g_allocator.DeviceFree(d_in));
    if (d_out) CubDebugExit(g_allocator.DeviceFree(d_out));
    if (d_temp_storage) CubDebugExit(g_allocator.DeviceFree(d_temp_storage));

    printf("\n\n");

    return 0;
}

Tested with CUB 1.16.0 on an A100 GPU. The problem does not occur when d_in is used instead of iter in the DeviceReduce::Sum calls (which is what the original example does). It also does not occur when the number of elements in the input array is decreased (note that here the element count, 46000 * 46000, is close to INT_MAX).
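The "close to INT_MAX" observation can be quantified with a few lines of host C++. This sketch only does the arithmetic. 46000 * 46000 = 2,116,000,000, which leaves 31,483,647 below INT_MAX, so any 32-bit intermediate that adds more than that to the item count would overflow. The helper names and the `extra` parameter are hypothetical stand-ins. This thread does not identify which computation in CUB 1.16.0 actually misbehaved.

```cpp
#include <cassert>
#include <climits>
#include <cstdint>

// Item count used in the reproducer: 46000 * 46000.
// Computed in 64 bits so the product itself cannot overflow.
constexpr std::int64_t kNumItems = std::int64_t{46000} * 46000;

// How far the item count sits below INT_MAX.
constexpr std::int64_t headroom_below_int_max()
{
    return std::int64_t{INT_MAX} - kNumItems;
}

// Hypothetical check: would offset + extra still fit in a 32-bit int?
// "extra" stands in for any tile size or stride added to an index;
// it is NOT taken from CUB's source.
constexpr bool fits_in_int(std::int64_t offset, std::int64_t extra)
{
    return offset + extra <= INT_MAX;
}
```

For example, adding 2^20 to the item count still fits in an int, but adding 2^25 does not.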

@gevtushenko
Collaborator

Hello @alexsamardzic and thank you for reporting the issue!

I can reproduce it with CUB 1.16.0 on an A100. The issue does not occur with recent versions of CUB, so it may already have been addressed by #592 or #589. Could you please verify with the latest release of CUB?
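When checking a newer release, it helps to confirm which CUB version a build actually picks up. The CUB_VERSION macro from <cub/version.cuh> encodes it as major * 100000 + minor * 100 + subminor, so 1.16.0 is 101600. The host-only decoder below is a minimal sketch. The struct and function names are hypothetical, and real code would pass CUB_VERSION instead of a literal.

```cpp
#include <cassert>

// Hypothetical helper: split a CUB_VERSION-style integer into its parts.
// Encoding (cub/version.cuh): major * 100000 + minor * 100 + subminor.
struct CubVersion
{
    int major_version;
    int minor_version;
    int subminor_version;
};

constexpr CubVersion decode_cub_version(int v)
{
    return CubVersion{v / 100000, v / 100 % 1000, v % 100};
}

// True if version v is at least maj.min.sub.
constexpr bool cub_at_least(int v, int maj, int min, int sub)
{
    return v >= maj * 100000 + min * 100 + sub;
}
```

Because both helpers are constexpr, they can also be used in a static_assert against CUB_VERSION to reject builds that still pick up an old CUB.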

@alexsamardzic
Author

Thanks, I can confirm that the example above works fine with the latest CUB main.

github-project-automation bot moved this from Todo to Done in CCCL Jun 26, 2023