Skip to content

Commit

Permalink
Replace nvrtc/jitlink steps with a sugary state-machine to ease kerne…
Browse files Browse the repository at this point in the history
…l development
  • Loading branch information
wmaxey committed Oct 7, 2024
1 parent ed5dfca commit 6f61753
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 168 deletions.
98 changes: 29 additions & 69 deletions c/src/for.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
#include <cub/util_device.cuh>

#include <format>
#include <iostream>
#include <type_traits>

#include <cccl/c/for.h>
#include <cccl/c/types.h>
#include <for/for_op_helper.h>
#include <nvJitLink.h>
#include <nvrtc.h>
#include <nvrtc/command_list.h>
#include <util/errors.h>
#include <util/types.h>

Expand Down Expand Up @@ -79,14 +77,13 @@ extern "C" CCCL_C_API CUresult cccl_device_for_build(
{
CUresult error = CUDA_SUCCESS;

if (d_data.type == cccl_iterator_kind_t::iterator)
{
throw std::runtime_error(std::string("Iterators are unsupported in for_each currently"));
}

try
{
nvrtcProgram prog{};
if (d_data.type == cccl_iterator_kind_t::iterator)
{
throw std::runtime_error(std::string("Iterators are unsupported in for_each currently"));
}

const char* name = "test";

const int cc = cc_major * 10 + cc_minor;
Expand All @@ -96,86 +93,49 @@ extern "C" CCCL_C_API CUresult cccl_device_for_build(
const std::string for_kernel_name = get_device_for_kernel_name();
const std::string device_for_kernel = get_for_kernel(op, d_data);

check(nvrtcCreateProgram(&prog, device_for_kernel.c_str(), name, 0, nullptr, nullptr));

check(nvrtcAddNameExpression(prog, for_kernel_name.c_str()));

const std::string arch = std::format("-arch=sm_{0}{1}", cc_major, cc_minor);

constexpr int num_args = 7;
const char* args[] = {arch.c_str(), cub_path, thrust_path, libcudacxx_path, ctk_path, "-rdc=true", "-dlto"};
constexpr size_t num_args = 7;
const char* args[num_args] = {arch.c_str(), cub_path, thrust_path, libcudacxx_path, ctk_path, "-rdc=true", "-dlto"};

std::size_t log_size{};
nvrtcResult compile_result = nvrtcCompileProgram(prog, num_args, args);
constexpr size_t num_lto_args = 2;
const char* lopts[num_lto_args] = {"-lto", arch.c_str()};

check(nvrtcGetProgramLogSize(prog, &log_size));
std::unique_ptr<char[]> log{new char[log_size]};
check(nvrtcGetProgramLog(prog, log.get()));
std::string lowered_name;

if (log_size > 1)
{
std::cerr << log.get() << std::endl;
}
auto cl =
make_nvrtc_command_list()
.add_program(nvrtc_translation_unit{device_for_kernel, name})
.add_expression({for_kernel_name})
.compile_program({args, num_args})
.get_name({for_kernel_name, lowered_name})
.cleanup_program()
.add_link({(void*) op.ltoir, (size_t) op.ltoir_size});

std::string for_kernel_lowered_name;
{
const char* for_kernel_lowered_name_temp;
check(nvrtcGetLoweredName(prog, for_kernel_name.c_str(), &for_kernel_lowered_name_temp));
for_kernel_lowered_name = for_kernel_lowered_name_temp;
}

check(compile_result);

std::size_t ltoir_size{};
check(nvrtcGetLTOIRSize(prog, &ltoir_size));
std::unique_ptr<char[]> ltoir{new char[ltoir_size]};
check(nvrtcGetLTOIR(prog, ltoir.get()));
check(nvrtcDestroyProgram(&prog));
nvrtc_cubin result{};

nvJitLinkHandle handle;
const char* lopts[] = {"-lto", arch.c_str()};

check(nvJitLinkCreate(&handle, 2, lopts));
check(nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, ltoir.get(), ltoir_size, name));
check(nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, op.ltoir, op.ltoir_size, name));
if (cccl_iterator_kind_t::iterator == d_data.type)
{
check(nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, d_data.advance.ltoir, d_data.advance.ltoir_size, name));
check(
nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, d_data.dereference.ltoir, d_data.dereference.ltoir_size, name));
result = cl.add_link({(void*) d_data.advance.ltoir, (size_t) d_data.advance.ltoir_size})
.add_link({(void*) d_data.dereference.ltoir, (size_t) d_data.dereference.ltoir_size})
.finalize_program(num_lto_args, lopts);
}

auto jitlink_error = nvJitLinkComplete(handle);

check(nvJitLinkGetErrorLogSize(handle, &log_size));
std::unique_ptr<char[]> jitlinklog{new char[log_size]};
check(nvJitLinkGetErrorLog(handle, jitlinklog.get()));

if (log_size > 1)
else
{
std::cerr << jitlinklog.get() << std::endl;
result = cl.finalize_program(num_lto_args, lopts);
}

check(jitlink_error);

std::size_t cubin_size{};
check(nvJitLinkGetLinkedCubinSize(handle, &cubin_size));
std::unique_ptr<char[]> cubin{new char[cubin_size]};
check(nvJitLinkGetLinkedCubin(handle, cubin.get()));
check(nvJitLinkDestroy(&handle));

cuLibraryLoadData(&build->library, cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0);
check(cuLibraryGetKernel(&build->static_kernel, build->library, for_kernel_lowered_name.c_str()));
cuLibraryLoadData(&build->library, result.cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0);
check(cuLibraryGetKernel(&build->static_kernel, build->library, lowered_name.c_str()));

build->cc = cc;
build->cubin = cubin.release();
build->cubin_size = cubin_size;
build->cubin = result.cubin.release();
build->cubin_size = result.size;
}
catch (...)
{
error = CUDA_ERROR_UNKNOWN;
}

return error;
}

Expand Down
15 changes: 1 addition & 14 deletions c/src/for/for_op_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
#pragma once

#include <cstdlib>
#include <memory>
#include <string>
#include <variant>

#include <cccl/c/types.h>
#include <util/memory.h>

// For each kernel accepts a user operator that contains both iterator and user operator state
// This declaration is used as blueprint for aligned_storage, but is only *valid* in the generated NVRTC program.
Expand All @@ -26,19 +26,6 @@ struct for_each_default
void* user_op; // A pointer for user data
};

struct unique_free_void
{
inline void operator()(void* p)
{
if (p)
{
free(p);
}
}
};

using unique_void = std::unique_ptr<void, unique_free_void>;

struct for_each_kernel_state
{
std::variant<for_each_default, unique_void> for_each_arg;
Expand Down
Loading

0 comments on commit 6f61753

Please sign in to comment.