diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index e3404fb8b..b569ecd64 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -28,3 +28,4 @@ add_subdirectory(replit)
 add_subdirectory(mpt)
 add_subdirectory(starcoder)
 add_subdirectory(sam)
+add_subdirectory(plugin)
diff --git a/examples/plugin/CMakeLists.txt b/examples/plugin/CMakeLists.txt
new file mode 100644
index 000000000..aac1c62ad
--- /dev/null
+++ b/examples/plugin/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_library(plugin-model STATIC model.cpp)
+target_link_libraries(plugin-model PUBLIC ggml::ggml)
+
+add_executable(cpu-plugin cpu-plugin.cpp)
+target_link_libraries(cpu-plugin plugin-model)
+
+if (GGML_CUBLAS)
+    add_executable(cuda-plugin cuda-plugin.cpp)
+    target_link_libraries(cuda-plugin plugin-model)
+endif()
diff --git a/examples/plugin/README.md b/examples/plugin/README.md
new file mode 100644
index 000000000..b75db63d4
--- /dev/null
+++ b/examples/plugin/README.md
@@ -0,0 +1,25 @@
+# GGML Plugin
+
+This example showcases the use of GGML as a plugin.
+
+The executables demonstrate how to initialize a backend and run inference with a model whose data (weights, inputs, and outputs) is provided by the host application.
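+
+A minimal sketch of the CPU flow (see `cpu-plugin.cpp` for the complete program; the exact backend header name is an assumption and may differ depending on how ggml is integrated):
+
+```cpp
+#include "model.hpp"
+#include <ggml-backend.h>
+#include <vector>
+
+int main() {
+    ggml_backend_t backend = ggml_backend_cpu_init();
+
+    // the application owns all tensor data; the model only keeps shallow references to it
+    std::vector<float> weights(10, 1.f), input(10, 2.f), output(10);
+
+    model m(backend, weights.size(), GGML_TYPE_F32, weights.data());
+    m.compute(output.data(), input.data()); // output[i] == weights[i] + input[i]
+
+    ggml_backend_free(backend);
+}
+```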
diff --git a/examples/plugin/cpu-plugin.cpp b/examples/plugin/cpu-plugin.cpp
new file mode 100644
index 000000000..548aadecf
--- /dev/null
+++ b/examples/plugin/cpu-plugin.cpp
@@ -0,0 +1,41 @@
+#include "model.hpp"
+
+#include <ggml-backend.h>
+
+#include <vector>
+#include <iostream>
+
+int main() {
+    auto backend = ggml_backend_cpu_init();
+
+    std::vector<float> weights_data;
+    for (int i = 0; i < 10; ++i) {
+        weights_data.push_back(float(i));
+    }
+
+    void* weights = weights_data.data();
+
+    model m(backend, weights_data.size(), GGML_TYPE_F32, weights);
+
+    std::vector<float> input_data;
+    for (size_t i = 0; i < weights_data.size(); ++i) {
+        input_data.push_back(float(i) / 10);
+    }
+
+    std::vector<float> output_data(input_data.size());
+
+    void* input = input_data.data();
+    void* output = output_data.data();
+
+    m.compute(output, input);
+
+    ggml_backend_free(backend);
+
+    std::cout << "[";
+    for (auto o : output_data) {
+        std::cout << o << ", ";
+    }
+    std::cout << "]\n";
+
+    return 0;
+}
diff --git a/examples/plugin/cuda-plugin.cpp b/examples/plugin/cuda-plugin.cpp
new file mode 100644
index 000000000..8f278a57c
--- /dev/null
+++ b/examples/plugin/cuda-plugin.cpp
@@ -0,0 +1,67 @@
+#include "model.hpp"
+
+#include <cuda_runtime.h>
+#include <cublas_v2.h>
+
+#include <ggml-cuda.h>
+#include <vector>
+#include <iostream>
+
+int main() {
+    // init cuda
+    int device_id = 0;
+    cudaSetDevice(device_id);
+    cublasHandle_t cublas_handle = nullptr;
+    cublasCreate(&cublas_handle);
+    cudaStream_t cuda_stream = nullptr;
+    cudaStreamCreateWithFlags(&cuda_stream, cudaStreamNonBlocking);
+
+    // create plugin backend
+    auto backend = ggml_backend_cuda_init_plugin(device_id, cublas_handle, cuda_stream);
+
+    // init weights
+    std::vector<float> weights_data;
+    for (int i = 0; i < 10; ++i) {
+        weights_data.push_back(float(i));
+    }
+
+    void* weights = nullptr;
+    cudaMallocAsync(&weights, data_size(weights_data), cuda_stream);
+    cudaMemcpyAsync(weights, weights_data.data(), data_size(weights_data), cudaMemcpyHostToDevice, cuda_stream);
+
+    // create model with weights
+    model m(backend, weights_data.size(), GGML_TYPE_F32, weights);
+
+    // init input and output data
+    std::vector<float> input_data;
+    for (size_t i = 0; i < weights_data.size(); ++i) {
+        input_data.push_back(float(i) / 10);
+    }
+
+    std::vector<float> output_data(input_data.size());
+
+    void* input = nullptr;
+    cudaMallocAsync(&input, data_size(input_data), cuda_stream);
+    cudaMemcpyAsync(input, input_data.data(), data_size(input_data), cudaMemcpyHostToDevice, cuda_stream);
+
+    void* output = nullptr;
+    cudaMallocAsync(&output, data_size(output_data), cuda_stream);
+
+    // compute with cuda pointers
+    m.compute(output, input);
+
+    // get data back from cuda pointers
+    cudaMemcpyAsync(output_data.data(), output, data_size(output_data), cudaMemcpyDeviceToHost, cuda_stream);
+    cudaStreamSynchronize(cuda_stream);
+
+    ggml_backend_free(backend);
+
+    // print result
+    std::cout << "[";
+    for (auto o : output_data) {
+        std::cout << o << ", ";
+    }
+    std::cout << "]\n";
+
+    return 0;
+}
diff --git a/examples/plugin/model.cpp b/examples/plugin/model.cpp
new file mode 100644
index 000000000..8dde16c40
--- /dev/null
+++ b/examples/plugin/model.cpp
@@ -0,0 +1,67 @@
+#include "model.hpp"
+
+#include <ggml-alloc.h>
+#include <ggml-backend.h>
+
+#include <cassert>
+
+model::model(ggml_backend_t be, int64_t s, ggml_type t, void* weights_data)
+    : backend(be)
+    , size(s)
+    , type(t)
+{
+    assert(weights_data);
+    static constexpr size_t numWeightTensors = sizeof(weights_t) / sizeof(ggml_tensor*);
+    wctx = ggml_init({
+        /*.mem_size =*/ ggml_tensor_overhead() * numWeightTensors,
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc =*/ true,
+    });
+    weights.w = ggml_new_tensor_1d(wctx, type, size);
+    wbuf = ggml_backend_alloc_buffer(backend, 0);
+    auto wallocr = ggml_allocr_new_from_buffer(wbuf);
+    ggml_allocr_set_tensor_external_data(wallocr, weights.w, weights_data, 0);
+    ggml_allocr_free(wallocr);
+
+    cbuf = ggml_backend_alloc_buffer(backend, 0);
+    callocr = ggml_allocr_new_from_buffer(cbuf);
+}
+
+model::~model() {
+    ggml_free(wctx);
+    ggml_backend_buffer_free(wbuf);
+    ggml_allocr_free(callocr);
+    ggml_backend_buffer_free(cbuf);
+}
+
+struct io_tensors {
+    ggml_tensor* input = nullptr;
+    ggml_tensor* output = nullptr;
+};
+
+void model::compute(void* output, void* input) {
+    assert(input);
+    assert(output);
+
+    static constexpr size_t num_io_tensors = sizeof(io_tensors) / sizeof(ggml_tensor*);
+    auto cctx = ggml_init({
+        /*.mem_size =*/ ggml_tensor_overhead() * num_io_tensors + ggml_graph_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc =*/ true,
+    });
+
+    io_tensors io = {};
+    io.input = ggml_new_tensor_1d(cctx, type, size);
+    io.output = ggml_add(cctx, io.input, weights.w);
+
+    ggml_allocr_set_tensor_external_data(callocr, io.input, input, 0);
+    ggml_allocr_set_tensor_external_data(callocr, io.output, output, 0);
+
+    auto graph = ggml_new_graph(cctx);
+    ggml_build_forward_expand(graph, io.output);
+
+    ggml_backend_graph_compute(backend, graph);
+
+    ggml_allocr_reset(callocr);
+    ggml_free(cctx);
+}
diff --git a/examples/plugin/model.hpp b/examples/plugin/model.hpp
new file mode 100644
index 000000000..fee63d9c0
--- /dev/null
+++ b/examples/plugin/model.hpp
@@ -0,0 +1,35 @@
+#pragma once
+#include <ggml.h>
+#include <cstdint>
+
+typedef struct ggml_backend* ggml_backend_t;
+struct ggml_backend_buffer;
+struct ggml_allocr;
+
+struct model {
+    struct weights_t {
+        ggml_tensor* w = nullptr;
+    } weights;
+
+    ggml_backend_t backend = nullptr;
+
+    ggml_context* wctx = nullptr;
+    ggml_backend_buffer* wbuf = nullptr; // weights buffer
+
+    ggml_backend_buffer* cbuf = nullptr; // compute buffer
+    ggml_allocr* callocr = nullptr; // compute allocator
+
+    const int64_t size;
+    const ggml_type type;
+
+    model(ggml_backend_t be, int64_t s, ggml_type t, void* weights_data);
+    ~model();
+
+    void compute(void* output, void* input);
+};
+
+// util
+template <typename Vec>
+size_t data_size(const Vec& vec) {
+    return vec.size() * sizeof(typename Vec::value_type);
+}
diff --git a/include/ggml/ggml-alloc.h b/include/ggml/ggml-alloc.h
index e38758878..9f9c9a8b9 100644
--- a/include/ggml/ggml-alloc.h
+++ b/include/ggml/ggml-alloc.h
@@ -23,6 +23,17 @@ GGML_API void ggml_allocr_alloc (struct ggml_allocr * alloc, struct ggml_tensor * tensor);
 GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
 GGML_API size_t ggml_allocr_max_size (struct ggml_allocr * alloc);
 
+// set tensor data from external pointer (shallow copy)
+// WARNING! It is the responsibility of the user to ensure that the provided pointer:
+// * is compatible with the buffer backend (same address space)
+// * points to memory of the right size and type/quantization as described by the tensor
+// * remains valid while the associated tensor is used
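+//
+// illustrative usage (alloc, tensor, and data are hypothetical names):
+//   // tensor comes from a no_alloc context and has no data allocated yet
+//   ggml_allocr_set_tensor_external_data(alloc, tensor, data, 0);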
+GGML_API void ggml_allocr_set_tensor_external_data(struct ggml_allocr * alloc, struct ggml_tensor * tensor, void * data, size_t data_offset);
+
 GGML_API size_t ggml_allocr_alloc_graph_n(
         struct ggml_allocr * alloc,
         struct ggml_cgraph ** graphs, int n_graphs,
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index b225597ed..3abbfec82 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -249,7 +249,12 @@ if (GGML_PERF)
     set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_PERF)
 endif()
 
-add_library(${TARGET}
+if (GGML_PLUGIN)
+    set(GGML_LIB_TYPE STATIC)
+    set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+endif()
+
+add_library(${TARGET} ${GGML_LIB_TYPE}
     ggml.c
     ggml-alloc.c
     ggml-backend.c
@@ -261,6 +266,8 @@ add_library(${TARGET}
     ${GGML_METAL_SOURCES}
     )
 
+add_library(ggml::ggml ALIAS ggml)
+
 target_include_directories(${TARGET} PUBLIC
     .
     ../include
@@ -274,7 +281,7 @@ else()
     target_link_libraries(${TARGET} PUBLIC m ${GGML_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
 endif()
 
-if (BUILD_SHARED_LIBS)
+if (BUILD_SHARED_LIBS AND NOT GGML_PLUGIN)
     set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
 
     target_link_libraries(${TARGET} PUBLIC
diff --git a/src/ggml-alloc.c b/src/ggml-alloc.c
index 34eba3f83..5da722d1a 100644
--- a/src/ggml-alloc.c
+++ b/src/ggml-alloc.c
@@ -183,6 +183,15 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
     alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
 }
 
+void ggml_allocr_set_tensor_external_data(struct ggml_allocr * alloc, struct ggml_tensor * tensor, void * data, size_t data_offset) {
+    GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
+    GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
+    GGML_ASSERT(data_offset == 0); // not supported yet
+    tensor->data = data;
+    tensor->buffer = alloc->buffer;
+    ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
+}
+
 // this is a very naive implementation, but for our case the number of free blocks should be very small
 static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
     if (ggml_allocr_is_own(alloc, tensor) == false) {
diff --git a/src/ggml-backend.c b/src/ggml-backend.c
index ca8d83daf..45cf7cfa1 100644
--- a/src/ggml-backend.c
+++ b/src/ggml-backend.c
@@ -231,8 +231,12 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
 static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
 
 static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backend, size_t size) {
-    size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
-    void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
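+    // a zero-size buffer is valid: it is used when tensor data is provided externally via ggml_allocr_set_tensor_external_data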
+    void * data = NULL;
+    if (size) {
+        size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
+        data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
+    }
 
     return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size);
 }
@@ -364,7 +367,7 @@ ggml_backend_t ggml_backend_cpu_init(void) {
 
     *cpu_backend = (struct ggml_backend) {
         /* .interface = */ cpu_backend_i,
-        /* .context = */ ctx
+        /* .context = */ ctx,
     };
     return cpu_backend;
 }
diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu
index 5bd83bb5c..507250f9f 100644
--- a/src/ggml-cuda.cu
+++ b/src/ggml-cuda.cu
@@ -463,6 +463,9 @@ inline cudaError_t ggml_cuda_set_device(const int device) {
     return cudaSetDevice(device);
 }
 
+static bool g_cublas_initialized = false;
+static bool g_cublas_initialized_as_plugin = false;
+
 static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
@@ -5632,9 +5635,7 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 
 
 void ggml_init_cublas() {
-    static bool initialized = false;
-
-    if (!initialized) {
+    if (!g_cublas_initialized) {
 #ifdef __HIP_PLATFORM_AMD__
         // Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -5655,9 +5656,9 @@ void ggml_init_cublas() {
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-            g_compute_capabilities[id] = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
+            g_compute_capabilities[id] = 100 * prop.major + 10 * prop.minor + CC_OFFSET_AMD;
 #else
-            g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
+            g_compute_capabilities[id] = 100 * prop.major + 10 * prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         }
         for (int id = 0; id < g_device_count; ++id) {
@@ -5680,7 +5681,7 @@ void ggml_init_cublas() {
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
 
-        initialized = true;
+        g_cublas_initialized = true;
     }
 }
 
@@ -7536,6 +7537,26 @@ static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
 }
 
 static void ggml_backend_cuda_free(ggml_backend_t backend) {
+    for (int id = 0; id < GGML_CUDA_MAX_DEVICES; ++id) {
+        for (int is = 0; is < MAX_STREAMS; ++is) {
+            auto & stream = g_cudaStreams[id][is];
+            if (!stream) break;
+            if (!g_cublas_initialized_as_plugin) {
+                cudaStreamDestroy(stream);
+            }
+            stream = nullptr;
+        }
+
+        auto & cublasHandle = g_cublas_handles[id];
+        if (!cublasHandle) continue;
+        if (!g_cublas_initialized_as_plugin) {
+            cublasDestroy(cublasHandle);
+        }
+        cublasHandle = nullptr;
+    }
+    g_cublas_initialized = false;
+    g_cublas_initialized_as_plugin = false;
+
     ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
     delete cuda_ctx;
     delete backend;
@@ -7640,10 +7661,14 @@ static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {
 };
 
 static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) {
-    ggml_cuda_set_device(g_main_device);
-
     ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda;
-    CUDA_CHECK(cudaMalloc(&ctx->device, size));
+    if (size) {
+        ggml_cuda_set_device(g_main_device);
+        CUDA_CHECK(cudaMalloc(&ctx->device, size));
+    }
+    else {
+        ctx->device = NULL;
+    }
     return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size);
 }
 
@@ -7767,15 +7792,42 @@ static ggml_backend_i cuda_backend_i = {
     /* .supports_op = */ nullptr,
 };
 
+static ggml_backend_t create_cuda_backend(ggml_backend_context_cuda* ctx) {
+    ggml_backend_t cuda_backend = new ggml_backend{
+        /* .interface = */ cuda_backend_i,
+        /* .context = */ ctx,
+    };
+
+    return cuda_backend;
+}
+
 ggml_backend_t ggml_backend_cuda_init() {
     ggml_init_cublas(); // TODO: remove from ggml.c
 
     ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda;
+    return create_cuda_backend(ctx);
+}
 
-    ggml_backend_t cuda_backend = new ggml_backend {
-        /* .interface = */ cuda_backend_i,
-        /* .context = */ ctx
-    };
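+// init the cuda backend as a plugin, reusing a cublas handle and a cuda stream created by the host application;
+// the backend does not take ownership of them and ggml_backend_cuda_free will not destroy them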
+ggml_backend_t ggml_backend_cuda_init_plugin(int main_device, void * cublas_handle, void * cuda_stream) {
+    GGML_ASSERT(g_cublas_initialized == false && "currently only a single cuda backend is supported");
 
-    return cuda_backend;
+    g_device_count = main_device + 1;
+    int id = g_main_device = main_device;
+
+    cudaDeviceProp prop;
+    CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
+    fprintf(stderr, "  Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
+
+    // g_tensor_split[id] = 0;
+    g_compute_capabilities[id] = 100 * prop.major + 10 * prop.minor;
+    g_cublas_handles[id] = (cublasHandle_t)cublas_handle;
+    g_cudaStreams[id][0] = (cudaStream_t)cuda_stream;
+
+    g_cublas_initialized = true;
+    g_cublas_initialized_as_plugin = true;
+
+    ggml_backend_context_cuda* ctx = new ggml_backend_context_cuda;
+    return create_cuda_backend(ctx);
 }
diff --git a/src/ggml-cuda.h b/src/ggml-cuda.h
index 57adc9cf3..e542461a6 100644
--- a/src/ggml-cuda.h
+++ b/src/ggml-cuda.h
@@ -45,6 +45,7 @@ GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 
 // backend API
 GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
+GGML_API ggml_backend_t ggml_backend_cuda_init_plugin(int main_device, void * cublas_handle, void * cuda_stream);
 
 #ifdef  __cplusplus
 }