Commit

plugin example

iboB committed Oct 12, 2023
1 parent 5224fff commit 315f304

Showing 7 changed files with 230 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -28,3 +28,4 @@ add_subdirectory(replit)
add_subdirectory(mpt)
add_subdirectory(starcoder)
add_subdirectory(sam)
add_subdirectory(plugin)
10 changes: 10 additions & 0 deletions examples/plugin/CMakeLists.txt
@@ -0,0 +1,10 @@
add_library(plugin-model STATIC model.cpp)
target_link_libraries(plugin-model PUBLIC ggml::ggml)

add_executable(cpu-plugin cpu-plugin.cpp)
target_link_libraries(cpu-plugin plugin-model)

if (GGML_CUBLAS)
    add_executable(cuda-plugin cuda-plugin.cpp)
    target_link_libraries(cuda-plugin plugin-model)
endif()
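
Note that the `cuda-plugin` target is only configured when ggml itself is built with `GGML_CUBLAS` enabled (e.g. `-DGGML_CUBLAS=ON`).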
5 changes: 5 additions & 0 deletions examples/plugin/README.md
@@ -0,0 +1,5 @@
# GGML Plugin

This example showcases the use of GGML as a plugin.

The executables demonstrate how to initialize a backend and run inference with a model whose weights, inputs, and outputs live in memory owned by the host application rather than allocated by GGML.
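
In a nutshell, the host constructs the `model` class from `model.hpp` around a backend handle and externally owned weight memory, then calls `compute` with externally owned input/output pointers. A minimal sketch, assuming the CPU backend (see `cpu-plugin.cpp` for the complete program):

```cpp
#include "model.hpp"
#include <ggml.h>
#include <ggml-backend.h>
#include <vector>

int main() {
    ggml_backend_t backend = ggml_backend_cpu_init();

    // weights live in application-owned host memory, not in a GGML buffer
    std::vector<float> weights(10, 1.f);
    model m(backend, weights.size(), GGML_TYPE_F32, weights.data());

    // inputs and outputs are application-owned too; the graph adds weights to input
    std::vector<float> input(10, 2.f), output(10);
    m.compute(output.data(), input.data()); // output[i] == 3.f

    ggml_backend_free(backend);
}
```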
42 changes: 42 additions & 0 deletions examples/plugin/cpu-plugin.cpp
@@ -0,0 +1,42 @@
#include "model.hpp"

#include <ggml.h>
#include <ggml-backend.h>

#include <vector>
#include <iostream>

int main() {
    // initialize the CPU backend
    auto backend = ggml_backend_cpu_init();

    // init weights in host memory
    std::vector<float> weights_data;
    for (int i = 0; i < 10; ++i) {
        weights_data.push_back(float(i));
    }

    void* weights = weights_data.data();

    // create model with external weights
    model m(backend, weights_data.size(), GGML_TYPE_F32, weights);

    // init input and output data
    std::vector<float> input_data;
    for (size_t i = 0; i < weights_data.size(); ++i) {
        input_data.push_back(float(i) / 10);
    }

    std::vector<float> output_data(input_data.size());

    void* input = input_data.data();
    void* output = output_data.data();

    // compute with host pointers
    m.compute(output, input);

    ggml_backend_free(backend);

    // print result
    std::cout << "[";
    for (auto o : output_data) {
        std::cout << o << ", ";
    }
    std::cout << "]\n";

    return 0;
}
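
Since the model graph adds the input to the weights (see `model.cpp`), the expected output here is `output[i] = i + i/10`, i.e. the program should print `[0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, ]`.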
67 changes: 67 additions & 0 deletions examples/plugin/cuda-plugin.cpp
@@ -0,0 +1,67 @@
#include "model.hpp"

#include <vector>
#include <iostream>

#include <ggml-cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main() {
    // init cuda
    int device_id = 0;
    cudaSetDevice(device_id);
    cublasHandle_t cublas_handle = nullptr;
    cublasCreate(&cublas_handle);
    cudaStream_t cuda_stream = nullptr;
    cudaStreamCreateWithFlags(&cuda_stream, cudaStreamNonBlocking);

    // create plugin backend
    auto backend = ggml_backend_cuda_init_plugin(device_id, cublas_handle, cuda_stream);

    // init weights
    std::vector<float> weights_data;
    for (int i = 0; i < 10; ++i) {
        weights_data.push_back(float(i));
    }

    void* weights = nullptr;
    cudaMallocAsync(&weights, data_size(weights_data), cuda_stream);
    cudaMemcpyAsync(weights, weights_data.data(), data_size(weights_data), cudaMemcpyHostToDevice, cuda_stream);

    // create model with weights
    model m(backend, weights_data.size(), GGML_TYPE_F32, weights);

    // init input and output data
    std::vector<float> input_data;
    for (size_t i = 0; i < weights_data.size(); ++i) {
        input_data.push_back(float(i) / 10);
    }

    std::vector<float> output_data(input_data.size());

    void* input = nullptr;
    cudaMallocAsync(&input, data_size(input_data), cuda_stream);
    cudaMemcpyAsync(input, input_data.data(), data_size(input_data), cudaMemcpyHostToDevice, cuda_stream);

    void* output = nullptr;
    cudaMallocAsync(&output, data_size(output_data), cuda_stream);

    // compute with cuda pointers
    m.compute(output, input);

    // get data back from cuda pointers
    cudaMemcpyAsync(output_data.data(), output, data_size(output_data), cudaMemcpyDeviceToHost, cuda_stream);
    cudaStreamSynchronize(cuda_stream);

    ggml_backend_free(backend);

    // print result
    std::cout << "[";
    for (auto o : output_data) {
        std::cout << o << ", ";
    }
    std::cout << "]\n";

    return 0;
}
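
The difference from the CPU variant is ownership: the application initializes CUDA itself, owns the cuBLAS handle and stream, and hands them to the backend via `ggml_backend_cuda_init_plugin`, and every pointer given to the model (`weights`, `input`, `output`) is a device pointer allocated with `cudaMallocAsync` on that same stream.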
68 changes: 68 additions & 0 deletions examples/plugin/model.cpp
@@ -0,0 +1,68 @@
#include "model.hpp"

#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>

#include <cassert>

model::model(ggml_backend_t be, int64_t s, ggml_type t, void* weights_data)
    : backend(be)
    , size(s)
    , type(t)
{
    assert(weights_data);
    static constexpr size_t numWeightTensors = sizeof(weights_t) / sizeof(ggml_tensor*);
    wctx = ggml_init({
        /*.mem_size =*/ ggml_tensor_overhead() * numWeightTensors,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc =*/ true,
    });
    weights.w = ggml_new_tensor_1d(wctx, type, size);
    // zero-size buffer: the weight data itself lives in externally owned memory
    wbuf = ggml_backend_alloc_buffer(backend, 0);
    auto wallocr = ggml_allocr_new_from_buffer(wbuf);
    ggml_allocr_set_tensor_external_data(wallocr, weights.w, weights_data, 0);
    ggml_allocr_free(wallocr);

    // same for the compute buffer: inputs and outputs are attached per call
    cbuf = ggml_backend_alloc_buffer(backend, 0);
    callocr = ggml_allocr_new_from_buffer(cbuf);
}

model::~model() {
    ggml_free(wctx);
    ggml_backend_buffer_free(wbuf);
    ggml_allocr_free(callocr);
    ggml_backend_buffer_free(cbuf);
}

struct io_tensors {
    ggml_tensor* input = nullptr;
    ggml_tensor* output = nullptr;
};

void model::compute(void* output, void* input) {
    assert(input);
    assert(output);

    static constexpr size_t num_io_tensors = sizeof(io_tensors) / sizeof(ggml_tensor*);
    auto cctx = ggml_init({
        /*.mem_size =*/ ggml_tensor_overhead() * num_io_tensors + ggml_graph_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc =*/ true,
    });

    io_tensors io = {};
    io.input = ggml_new_tensor_1d(cctx, type, size);
    io.output = ggml_add(cctx, io.input, weights.w);

    // attach the caller's buffers to the graph's input and output tensors
    ggml_allocr_set_tensor_external_data(callocr, io.input, input, 0);
    ggml_allocr_set_tensor_external_data(callocr, io.output, output, 0);

    auto graph = ggml_new_graph(cctx);
    ggml_build_forward_expand(graph, io.output);

    ggml_backend_graph_compute(backend, graph);

    ggml_allocr_reset(callocr);
    ggml_free(cctx);
}
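
The pattern throughout is the same: contexts are created with `no_alloc`, backend buffers are allocated with size 0, and the actual memory is attached to each tensor with `ggml_allocr_set_tensor_external_data`, so GGML computes directly in buffers the host application owns and never allocates tensor data itself.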
37 changes: 37 additions & 0 deletions examples/plugin/model.hpp
@@ -0,0 +1,37 @@
#pragma once
#include <cstdint>

struct ggml_tensor;
typedef struct ggml_backend* ggml_backend_t;
struct ggml_context;
enum ggml_type;
struct ggml_backend_buffer;
struct ggml_allocr;

struct model {
    struct weights_t {
        ggml_tensor* w = nullptr;
    } weights;

    ggml_backend_t backend = nullptr;

    ggml_context* wctx = nullptr;
    ggml_backend_buffer* wbuf = nullptr; // weights buffer

    ggml_backend_buffer* cbuf = nullptr; // compute buffer
    ggml_allocr* callocr = nullptr; // compute allocator

    const int64_t size;
    const ggml_type type;

    model(ggml_backend_t be, int64_t s, ggml_type t, void* weights_data);
    ~model();

    void compute(void* output, void* input);
};

// util
template <typename Vec>
size_t data_size(const Vec& vec) {
    return vec.size() * sizeof(typename Vec::value_type);
}
