From 4da16e4d8e9828c57bfd669bcf11b2e00a8f4274 Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 12 Oct 2023 22:29:23 +0200 Subject: [PATCH 1/3] ggml C++ bindings wip --- examples/CMakeLists.txt | 1 + examples/ggml-cpp/CMakeLists.txt | 21 + examples/ggml-cpp/ggml-cpp.h | 538 +++++++++++++++++++ examples/ggml-cpp/gpt-2-cpp.cpp | 889 +++++++++++++++++++++++++++++++ include/ggml/ggml-alloc.h | 2 + src/ggml-alloc.c | 5 +- src/ggml-backend.c | 9 + 7 files changed, 1464 insertions(+), 1 deletion(-) create mode 100644 examples/ggml-cpp/CMakeLists.txt create mode 100644 examples/ggml-cpp/ggml-cpp.h create mode 100644 examples/ggml-cpp/gpt-2-cpp.cpp diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e3404fb8b..498c6730e 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -28,3 +28,4 @@ add_subdirectory(replit) add_subdirectory(mpt) add_subdirectory(starcoder) add_subdirectory(sam) +add_subdirectory(ggml-cpp) diff --git a/examples/ggml-cpp/CMakeLists.txt b/examples/ggml-cpp/CMakeLists.txt new file mode 100644 index 000000000..e0342dae5 --- /dev/null +++ b/examples/ggml-cpp/CMakeLists.txt @@ -0,0 +1,21 @@ +# +# gpt-2-cpp + +set(TEST_TARGET gpt-2-cpp) +add_executable(${TEST_TARGET} gpt-2-cpp.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml) + +# +# For GPU offloading + +if (GGML_CUBLAS) + add_compile_definitions(GGML_USE_CUBLAS) +endif() + +if (GGML_CLBLAST) + add_compile_definitions(GGML_USE_CLBLAST) +endif() + +if (GGML_METAL) + add_compile_definitions(GGML_USE_METAL) +endif() diff --git a/examples/ggml-cpp/ggml-cpp.h b/examples/ggml-cpp/ggml-cpp.h new file mode 100644 index 000000000..ade8e6dee --- /dev/null +++ b/examples/ggml-cpp/ggml-cpp.h @@ -0,0 +1,538 @@ +#include "ggml/ggml.h" +#include "ggml/ggml-alloc.h" +#include "ggml/ggml-backend.h" + +#include +#include +#include + +namespace ggml { + struct context { + context() : ctx(nullptr) {} + context(size_t mem_size, void * mem_buffer, bool no_alloc) { + ggml_init_params params = { + /*.mem_size = */ mem_size, + /*.mem_buffer = */ mem_buffer, + /*.no_alloc = */ no_alloc + }; + ctx = ggml_init(params); + if (ctx == nullptr) { + throw std::runtime_error("failed to initialize ggml"); + } + } + context(const context & ctx) = delete; + context(context && ctx) { + this->ctx = ctx.ctx; + ctx.ctx = nullptr; + } + ~context() { + ggml_free(ctx); + } + context & operator=(const context & rhs) = delete; + context & operator=(context && rhs) { + if (this != &rhs) { + this->ctx = rhs.ctx; + rhs.ctx = nullptr; + } + return *this; + } + + + operator bool() const { + return ctx != nullptr; + } + + ggml_context * get() { + GGML_ASSERT(ctx != nullptr && "context not initialized"); + return ctx; + } + + private: + ggml_context * ctx; + }; + + // the global context stack allows using tensors without explicitly passing the context + // tensors must be created within a context_guard + struct ctx_stack { + std::stack stack; + }; + + inline ctx_stack & get_ctx_stack() { + static ctx_stack s; + return s; + } + + inline ggml_context * ctx() { + ggml_context * g_ctx = get_ctx_stack().stack.empty() ? 
nullptr : get_ctx_stack().stack.top(); + GGML_ASSERT(g_ctx != nullptr && "this function must be called within a context_guard"); + return g_ctx; + } + + // TODO: nested context guards are not always properly handled + struct context_guard { + context_guard(context & ctx) : ctx(ctx.get()) { + get_ctx_stack().stack.push(ctx.get()); + } + context_guard(const context_guard & ctx) = delete; + context_guard(context_guard && ctx) { + this->ctx = ctx.ctx; + ctx.ctx = nullptr; + } + + context_guard & operator=(const context_guard & rhs) = delete; + context_guard & operator=(context_guard && rhs) { + this->ctx = rhs.ctx; + rhs.ctx = nullptr; + return *this; + } + + ~context_guard() { + if (ctx != nullptr) { + release(); + } + } + + void release() { + GGML_ASSERT(ctx != nullptr && "this context_guard has already been released"); + GGML_ASSERT(get_ctx_stack().stack.top() == ctx && "only the top context_guard can be released"); + ctx = nullptr; + get_ctx_stack().stack.pop(); + } + + + ggml_context * ctx; + }; + + struct tensor { + tensor() : val(nullptr) {} + tensor(ggml_tensor * val) : val(val) {} + tensor(const tensor & val) = delete; // reference copies can be performed by initializing from get() + tensor(tensor && val) { + this->val = val.val; + val.val = nullptr; + } + tensor & operator=(const tensor & rhs) = delete; + tensor & operator=(tensor && rhs) { + if (this != &rhs) { + this->val = rhs.val; + rhs.val = nullptr; + } + return *this; + } + + // new tensor + tensor(ggml_type type) { + val = ggml_new_tensor_1d(ctx(), type, 1); + } + tensor(ggml_type type, int64_t ne0) { + val = ggml_new_tensor_1d(ctx(), type, ne0); + } + tensor(ggml_type type, int64_t ne0, int64_t ne1) { + val = ggml_new_tensor_2d(ctx(), type, ne0, ne1); + } + tensor(ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2) { + val = ggml_new_tensor_3d(ctx(), type, ne0, ne1, ne2); + } + tensor(ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) { + val = ggml_new_tensor_4d(ctx(), type, ne0, ne1, ne2, ne3); + } + + // new float tensor + tensor(int64_t ne0) : tensor(GGML_TYPE_F32, ne0) {} + tensor(int64_t ne0, int64_t ne1) : tensor(GGML_TYPE_F32, ne0, ne1) {} + tensor(int64_t ne0, int64_t ne1, int64_t ne2) : tensor(GGML_TYPE_F32, ne0, ne1, ne2) {} + tensor(int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) : tensor(GGML_TYPE_F32, ne0, ne1, ne2, ne3) {} + + // view + tensor view(int64_t ne0, size_t offset = 0) { + return ggml_view_1d(ctx(), get(), ne0, offset); + } + tensor view(int64_t ne0, int64_t ne1, size_t nb1, size_t offset = 0) { + return ggml_view_2d(ctx(), get(), ne0, ne1, nb1, offset); + } + tensor view(int64_t ne0, int64_t ne1, int64_t ne2, size_t nb1, size_t nb2, size_t offset = 0) { + return ggml_view_3d(ctx(), get(), ne0, ne1, ne2, nb1, nb2, offset); + } + tensor view(int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, size_t nb1, size_t nb2, size_t nb3, size_t offset = 0) { + return ggml_view_4d(ctx(), get(), ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset); + } + + // reshape + tensor reshape(int64_t ne0) { + return ggml_reshape_1d(ctx(), get(), ne0); + } + tensor reshape(int64_t ne0, int64_t ne1) { + return ggml_reshape_2d(ctx(), get(), ne0, ne1); + } + tensor reshape(int64_t ne0, int64_t ne1, int64_t ne2) { + return ggml_reshape_3d(ctx(), get(), ne0, ne1, ne2); + } + tensor reshape(int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) { + return ggml_reshape_4d(ctx(), get(), ne0, ne1, ne2, ne3); + } + + // permute + tensor permute(int axis0, int axis1, int axis2, int axis3) { + return ggml_permute(ctx(), 
get(), axis0, axis1, axis2, axis3); + } + + // cont + tensor cont() const { + return ggml_cont(ctx(), val); + } + tensor cont(int64_t ne0) const { + return ggml_cont_1d(ctx(), get(), ne0); + } + tensor cont(int64_t ne0, int64_t ne1) const { + return ggml_cont_2d(ctx(), get(), ne0, ne1); + } + tensor cont(int64_t ne0, int64_t ne1, int64_t ne2) const { + return ggml_cont_3d(ctx(), get(), ne0, ne1, ne2); + } + tensor cont(int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) const { + return ggml_cont_4d(ctx(), get(), ne0, ne1, ne2, ne3); + } + + // copy + tensor cpy(const tensor & a) { + return ggml_cpy(ctx(), a.get(), get()); + } + + tensor dup_tensor() const { + return ggml_dup_tensor(ctx(), get()); + } + + // operators + tensor operator*(const tensor & rhs) const { + if (rhs.ne(0) == 1 && rhs.ne(1) == 1 && rhs.ne(2) == 1 && rhs.ne(3) == 1) { + return ggml_scale(ctx(), get(), rhs.get()); + } + return ggml_mul(ctx(), get(), rhs.get()); + } + + tensor operator+(const tensor & rhs) const { + return ggml_add(ctx(), get(), rhs.get()); + } + + tensor operator()(const tensor & rhs) const { + return ggml_mul_mat(ctx(), get(), rhs.get()); + } + + operator bool() const { + return val != nullptr; + } + + // getters + int64_t ne(int n) const { + return get()->ne[n]; + } + + size_t nb(int n) const { + return get()->nb[n]; + } + + ggml_type type() const { + return get()->type; + } + + size_t element_size() const { + return ggml_element_size(get()); + } + + size_t nbytes() const { + return ggml_nbytes(get()); + } + + int64_t nelements() const { + return ggml_nelements(get()); + } + + void * data() { + return ggml_get_data(get()); + } + + ggml_tensor * get() const { + GGML_ASSERT(val != nullptr && "tensor not initialized"); + return val; + } + + std::string get_name() const { + return ggml_get_name(get()); + } + + // setters + void set_name(const std::string & name) { + ggml_set_name(get(), name.c_str()); + } + + ggml_tensor * val; + + // backend + void backend_set(const void * data, size_t offset, size_t nbytes) { + ggml_backend_tensor_set(get(), data, offset, nbytes); + } + + void backend_get(void * data, size_t offset, size_t nbytes) { + ggml_backend_tensor_get(get(), data, offset, nbytes); + } + + void backend_copy(tensor & dst) { + ggml_backend_tensor_copy(get(), dst.get()); + } + }; + + struct graph { + graph() { + gf = ggml_new_graph(ctx()); + } + + graph(const graph & g) = delete; + + graph(graph && g) { + this->gf = g.gf; + g.gf = nullptr; + } + + graph & operator=(const graph & rhs) = delete; + + graph & operator=(graph && rhs) { + if (this != &rhs) { + this->gf = rhs.gf; + rhs.gf = nullptr; + } + return *this; + } + + void expand(const tensor & t) { + ggml_build_forward_expand(gf, t.get()); + } + + tensor get_node(int i) { + return get()->nodes[i]; + } + + size_t n_nodes() const { + return get()->n_nodes; + } + + ggml_cgraph * get() const { + return gf; + } + + ggml_cgraph * gf; + }; + + inline tensor get_rows(const tensor & a, const tensor & b) { + return ggml_get_rows(ctx(), a.get(), b.get()); + } + + inline tensor norm(const tensor & t, float eps) { + return ggml_norm(ctx(), t.get(), eps); + } + + inline tensor diag_mask_inf(const tensor & t, int n_past) { + return ggml_diag_mask_inf(ctx(), t.get(), n_past); + } + + inline tensor soft_max(const tensor & t) { + return ggml_soft_max(ctx(), t.get()); + } + + inline tensor gelu(const tensor & t) { + return ggml_gelu(ctx(), t.get()); + } + + inline tensor mul_mat(const tensor & a, const tensor & b) { + return ggml_mul_mat(ctx(), a.get(), b.get()); 
+ } + + // backend + struct backend_buffer { + backend_buffer() : val(nullptr) {} + backend_buffer(ggml_backend_buffer_t val) : val(val) {} + backend_buffer(const backend_buffer & val) = delete; + backend_buffer(backend_buffer && val) { + this->val = val.val; + val.val = nullptr; + } + ~backend_buffer() { + free(); + } + + backend_buffer & operator=(const backend_buffer & rhs) = delete; + backend_buffer & operator=(backend_buffer && rhs) { + if (this != &rhs) { + free(); + this->val = rhs.val; + rhs.val = nullptr; + } + return *this; + } + + operator bool() const { + return val != nullptr; + } + + void free() { + ggml_backend_buffer_free(val); + val = nullptr; + } + + size_t get_alignment() const { + return ggml_backend_buffer_get_alignment(get()); + } + + void * get_base() const { + return ggml_backend_buffer_get_base(get()); + } + + size_t get_size() const { + return ggml_backend_buffer_get_size(get()); + } + + size_t get_alloc_size(tensor & tensor) const { + return ggml_backend_buffer_get_alloc_size(get(), tensor.get()); + } + + void init_tensor(tensor & tensor) { + ggml_backend_buffer_init_tensor(get(), tensor.get()); + } + + void free_tensor(tensor & tensor) { + ggml_backend_buffer_free_tensor(get(), tensor.get()); + } + + ggml_backend_buffer_t get() const { + GGML_ASSERT(val != nullptr && "backend_buffer not initialized"); + return val; + } + + ggml_backend_buffer_t val; + }; + + struct backend { + backend() : val(nullptr) {} + backend(ggml_backend_t val) : val(val) {} + backend(const backend & val) = delete; + backend(backend && val) { + this->val = val.val; + val.val = nullptr; + } + ~backend() { + free(); + } + + backend & operator=(const backend & rhs) = delete; + backend & operator=(backend && rhs) { + if (this != &rhs) { + free(); + this->val = rhs.val; + rhs.val = nullptr; + } + return *this; + } + + operator bool() const { + return val != nullptr; + } + + std::string name() const { + return ggml_backend_name(get()); + } + + void free() { + ggml_backend_free(val); + val = nullptr; + } + + size_t get_alignment() const { + return ggml_backend_get_alignment(get()); + } + + backend_buffer alloc_buffer(size_t size) { + return ggml_backend_alloc_buffer(get(), size); + } + + void graph_compute(const graph & gf) { + ggml_backend_graph_compute(get(), gf.get()); + } + + ggml_backend_t get() const { + GGML_ASSERT(val != nullptr && "backend not initialized"); + return val; + } + + ggml_backend_t val; + }; + + + struct allocr { + allocr() : val(nullptr) {} + allocr(void * data, size_t size, size_t alignment) { + val = ggml_allocr_new(data, size, alignment); + } + allocr(backend_buffer & buffer) { + val = ggml_allocr_new_from_buffer(buffer.get()); + } + allocr(ggml_allocr_t val) : val(val) {} + allocr(const allocr & val) = delete; + allocr(allocr && val) { + this->val = val.val; + val.val = nullptr; + } + ~allocr() { + free(); + } + + static allocr new_measure(size_t alignment) { + return ggml_allocr_new_measure(alignment); + } + + allocr & operator=(const allocr & rhs) = delete; + allocr & operator=(allocr && rhs) { + if (this != &rhs) { + free(); + this->val = rhs.val; + rhs.val = nullptr; + } + return *this; + } + + operator bool() const { + return val != nullptr; + } + + void free() { + ggml_allocr_free(val); + val = nullptr; + } + + bool is_measure() const { + return ggml_allocr_is_measure(get()); + } + + void reset() { + ggml_allocr_reset(get()); + } + + void alloc(tensor & tensor) { + ggml_allocr_alloc(get(), tensor.get()); + } + + size_t alloc_graph(graph & graph) { + return 
ggml_allocr_alloc_graph(get(), graph.get()); + } + + size_t max_size() const { + return ggml_allocr_max_size(get()); + } + + ggml_allocr_t get() const { + GGML_ASSERT(val != nullptr && "allocr not initialized"); + return val; + } + + + ggml_allocr_t val; + }; +} diff --git a/examples/ggml-cpp/gpt-2-cpp.cpp b/examples/ggml-cpp/gpt-2-cpp.cpp new file mode 100644 index 000000000..d5793bdc5 --- /dev/null +++ b/examples/ggml-cpp/gpt-2-cpp.cpp @@ -0,0 +1,889 @@ +#include "ggml/ggml.h" +#include "ggml/ggml-alloc.h" +#include "ggml/ggml-backend.h" + +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include "common.h" +#include "common-ggml.h" +#include "ggml-cpp.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +std::string format(const char * fmt, ...) { + va_list args; + va_start(args, fmt); + + // Get the required size + int size = std::vsnprintf(nullptr, 0, fmt, args); + va_end(args); + + if(size <= 0) { + return ""; + } + + std::string result(size, '\0'); + + va_start(args, fmt); + std::vsnprintf(&result[0], size + 1, fmt, args); + va_end(args); + + return result; +} + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +// default hparams (GPT-2 117M) +struct gpt2_hparams { + int32_t n_vocab = 50257; + int32_t n_ctx = 1024; + int32_t n_embd = 768; + int32_t n_head = 12; + int32_t n_layer = 12; + int32_t ftype = 1; + float eps = 1e-5f; +}; + +struct gpt2_layer { + gpt2_layer(ggml_type wtype, int n_embd) { + ln_1_g = ggml::tensor(GGML_TYPE_F32, n_embd); + ln_1_b = ggml::tensor(GGML_TYPE_F32, n_embd); + ln_2_g = ggml::tensor(GGML_TYPE_F32, n_embd); + ln_2_b = ggml::tensor(GGML_TYPE_F32, n_embd); + c_attn_attn_w = ggml::tensor(wtype, n_embd, 3*n_embd); + c_attn_attn_b = ggml::tensor(GGML_TYPE_F32, 3*n_embd); + c_attn_proj_w = ggml::tensor(wtype, n_embd, n_embd); + c_attn_proj_b = ggml::tensor(GGML_TYPE_F32, n_embd); + c_mlp_fc_w = ggml::tensor(wtype, n_embd, 4*n_embd); + c_mlp_fc_b = ggml::tensor(GGML_TYPE_F32, 4*n_embd); + c_mlp_proj_w = ggml::tensor(wtype, 4*n_embd, n_embd); + c_mlp_proj_b = ggml::tensor(GGML_TYPE_F32, n_embd); + } + + // normalization + ggml::tensor ln_1_g; + ggml::tensor ln_1_b; + + ggml::tensor ln_2_g; + ggml::tensor ln_2_b; + + // attention + ggml::tensor c_attn_attn_w; + ggml::tensor c_attn_attn_b; + + ggml::tensor c_attn_proj_w; + ggml::tensor c_attn_proj_b; + + // mlp + ggml::tensor c_mlp_fc_w; + ggml::tensor c_mlp_fc_b; + + ggml::tensor c_mlp_proj_w; + ggml::tensor c_mlp_proj_b; +}; + +struct gpt2_model { + gpt2_model(ggml_type wtype, gpt2_hparams hparams) : hparams(hparams) { + ctx = ggml::context(ggml_tensor_overhead()*(2 + 6 + 12*hparams.n_layer), NULL, true); + ggml::context_guard ctx_guard(ctx); + + ln_f_g = ggml::tensor(GGML_TYPE_F32, hparams.n_embd); + ln_f_b = ggml::tensor(GGML_TYPE_F32, hparams.n_embd); + + wte = ggml::tensor(wtype, hparams.n_embd, hparams.n_vocab); + wpe = ggml::tensor(GGML_TYPE_F32, hparams.n_embd, hparams.n_ctx); + lm_head = ggml::tensor(wtype, hparams.n_embd, hparams.n_vocab); + + layers.reserve(hparams.n_layer); + for (int i = 0; i < hparams.n_layer; ++i) { + layers.emplace_back(wtype, hparams.n_embd); + } + + memory_k = ggml::tensor(GGML_TYPE_F32, hparams.n_embd*hparams.n_layer*hparams.n_ctx); + 
memory_v = ggml::tensor(GGML_TYPE_F32, hparams.n_embd*hparams.n_layer*hparams.n_ctx); + + // map by name + tensors["model/ln_f/g"] = &ln_f_g; + tensors["model/ln_f/b"] = &ln_f_b; + + tensors["model/wte"] = &wte; + tensors["model/wpe"] = &wpe; + tensors["model/lm_head"] = &lm_head; + + for (int i = 0; i < hparams.n_layer; ++i) { + gpt2_layer & layer = layers[i]; + + // map by name + tensors["model/h" + std::to_string(i) + "/ln_1/g"] = &layer.ln_1_g; + tensors["model/h" + std::to_string(i) + "/ln_1/b"] = &layer.ln_1_b; + + tensors["model/h" + std::to_string(i) + "/ln_2/g"] = &layer.ln_2_g; + tensors["model/h" + std::to_string(i) + "/ln_2/b"] = &layer.ln_2_b; + + tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = &layer.c_attn_attn_w; + tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = &layer.c_attn_attn_b; + + tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = &layer.c_attn_proj_w; + tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = &layer.c_attn_proj_b; + + tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = &layer.c_mlp_fc_w; + tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = &layer.c_mlp_fc_b; + + tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = &layer.c_mlp_proj_w; + tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = &layer.c_mlp_proj_b; + } + } + + gpt2_hparams hparams; + ggml::context ctx; + + // normalization + ggml::tensor ln_f_g; + ggml::tensor ln_f_b; + + ggml::tensor wte; // position embedding + ggml::tensor wpe; // token embedding + ggml::tensor lm_head; // language model head + + std::vector layers; + + // key + value memory + ggml::tensor memory_k; + ggml::tensor memory_v; + + ggml::backend backend; + ggml::backend_buffer buffer_w; + ggml::backend_buffer buffer_kv; + + std::map tensors; +}; + +// load the model's weights from a file +gpt2_model gpt2_model_load(const std::string & fname, gpt_vocab & vocab, int n_gpu_layers) { + printf("%s: loading model from '%s'\n", __func__, fname.c_str()); + + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + throw std::runtime_error(format("failed to open '%s'", fname.c_str())); + } + + // verify magic + { + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + if (magic != GGML_FILE_MAGIC) { + throw std::runtime_error(format("invalid model file '%s' (bad magic)", fname.c_str())); + } + } + + // load hparams + gpt2_hparams hparams; + { + fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx)); + fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd)); + fin.read((char *) &hparams.n_head, sizeof(hparams.n_head)); + fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); + fin.read((char *) &hparams.ftype, sizeof(hparams.ftype)); + + const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; + + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); + printf("%s: n_head = %d\n", __func__, hparams.n_head); + printf("%s: n_layer = %d\n", __func__, hparams.n_layer); + printf("%s: ftype = %d\n", __func__, hparams.ftype); + printf("%s: qntvr = %d\n", __func__, qntvr); + + hparams.ftype %= GGML_QNT_VERSION_FACTOR; + } + + // load vocab + { + int32_t n_vocab = 0; + fin.read((char *) &n_vocab, sizeof(n_vocab)); + + if (n_vocab != hparams.n_vocab) { + throw std::runtime_error(format("invalid model file '%s' (bad vocab size %d != %d)", fname.c_str(), n_vocab, 
hparams.n_vocab)); + } + + std::string word; + std::vector buf(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + fin.read((char *) &len, sizeof(len)); + + buf.resize(len); + fin.read((char *) buf.data(), len); + word.assign(buf.data(), len); + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (hparams.ftype)); + if (wtype == GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", + __func__, fname.c_str(), hparams.ftype); + } + + // initialize the model object + gpt2_model model(wtype, hparams); + + // initialize the backend +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > 0) { + fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (n_gpu_layers > 0) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if (!model.backend) { + // fallback to CPU backend + fprintf(stderr, "%s: using CPU backend\n", __func__); + model.backend = ggml_backend_cpu_init(); + } + + if (!model.backend) { + throw std::runtime_error("ggml_backend_cpu_init() failed"); + } + + // calculate the size of the backend buffer + size_t buffer_size = 0; + { + for (auto it : model.tensors) { + buffer_size += it.second->nbytes() + 128; + } + printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + printf("%s: backend buffer size = %6.2f MB\n", __func__, buffer_size/(1024.0*1024.0)); + } + + // allocate weights buffer + model.buffer_w = model.backend.alloc_buffer(buffer_size); + + // key + value memory + { + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_mem = n_layer*n_ctx; + + const size_t memory_size = model.memory_k.nbytes() + model.memory_v.nbytes(); + + printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem); + + // create a backend buffer (can be in host or device memory) + model.buffer_kv = model.backend.alloc_buffer(memory_size + 256); + + // allocate the tensors into the backend buffer + { + ggml::allocr alloc(model.buffer_kv); + + // this updates the pointers in the tensors to point to the correct location in the buffer + // this is necessary since the ggml_context is .no_alloc == true + // note that the buffer can actually be a device buffer, depending on the backend + alloc.alloc(model.memory_k); + alloc.alloc(model.memory_v); + } + } + + // load weights + { + ggml::allocr alloc(model.buffer_w); + + size_t total_size = 0; + + bool has_lm_head = false; + + std::vector read_buf; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ttype), sizeof(ttype)); + + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + 
std::string name(length, 0); + fin.read(&name[0], length); + + if (model.tensors.find(name) == model.tensors.end()) { + throw std::runtime_error(format("unknown tensor '%s' in model file", name.c_str())); + } + + auto & tensor = *model.tensors[name]; + tensor.set_name(name); + if (tensor.nelements() != nelements) { + throw std::runtime_error(format("tensor '%s' has wrong size in model file", name.c_str())); + } + + if (tensor.ne(0) != ne[0] || tensor.ne(1) != ne[1]) { + throw std::runtime_error(format("tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]", + name.c_str(), (int) tensor.ne(0), (int) tensor.ne(1), ne[0], ne[1])); + } + + // for debugging + if (0) { + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), tensor.nbytes()/1024.0/1024.0, tensor.nbytes()); + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements*bpe)/ggml_blck_size(tensor.type()) != tensor.nbytes()) { + throw std::runtime_error(format("tensor '%s' has wrong size in model file: got %zu, expected %zu", + name.c_str(), tensor.nbytes(), nelements*bpe)); + } + + alloc.alloc(tensor); + + if (ggml_backend_is_cpu (model.backend.get()) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend.get()) +#endif + ) { + // for the CPU and Metal backend, we can read directly into the tensor + fin.read(reinterpret_cast(tensor.data()), tensor.nbytes()); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(tensor.nbytes()); + fin.read(read_buf.data(), tensor.nbytes()); + tensor.backend_set(read_buf.data(), 0, tensor.nbytes()); + } + + // GPT-2 models share the WTE tensor as the LM head + if (name == "model/wte" && has_lm_head == false) { + alloc.alloc(model.lm_head); + tensor.backend_copy(model.lm_head); + //model.lm_head = tensor; + } + + if (name == "model/lm_head") { + has_lm_head = true; + } + + total_size += tensor.nbytes(); + } + + printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0); + } + + fin.close(); + + return model; +} + +// build the computation graph +struct ggml::graph gpt2_graph( + gpt2_model & model, + ggml::allocr & allocr, + const int n_past, + const std::vector & embd_inp) { + const int N = embd_inp.size(); + + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_head = hparams.n_head; + + // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data + static size_t buf_size = ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead(); + static std::vector buf(buf_size); + + ggml::context ctx0(buf_size, buf.data(), true); + + ggml::context_guard ctx_guard(ctx0); + + ggml::graph gf; + + ggml::tensor embd(GGML_TYPE_I32, N); + allocr.alloc(embd); + + // avoid writing to tensors if we are only measuring the memory usage + if (!allocr.is_measure()) { + embd.backend_set(embd_inp.data(), 0, N*embd.element_size()); + } + + ggml::tensor position(GGML_TYPE_I32, N); + allocr.alloc(position); + if (!allocr.is_measure()) { + for (int i = 0; i < N; ++i) { + int32_t v = n_past + i; + position.backend_set(&v, i*sizeof(int32_t), sizeof(v)); + } + } + + ggml::tensor KQ_scale(GGML_TYPE_F32); + allocr.alloc(KQ_scale); + if (!allocr.is_measure()) { + float s = 1.0f/sqrtf(float(n_embd)/n_head); + KQ_scale.backend_set(&s, 0, sizeof(s)); + } + + // wte + wpe 
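+    // (inpL = token embedding + positional embedding: get_rows gathers the rows of wte
+    //  selected by the token ids and the rows of wpe selected by the positions,
+    //  each giving an [n_embd, N] tensor which are then summed element-wise)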
+ ggml::tensor inpL = ggml::get_rows(model.wte, embd) + ggml::get_rows(model.wpe, position); + + for (int il = 0; il < n_layer; ++il) { + ggml::tensor cur; + + // norm + { + // [ 768, N] + cur = ggml::norm(inpL, hparams.eps); + + // [ 768, N] + cur = cur*model.layers[il].ln_1_g + model.layers[il].ln_1_b; + } + + // attn + // [2304, 768] - model.layers[il].c_attn_attn_w + // [2304, 1] - model.layers[il].c_attn_attn_b + // [ 768, N] - cur (in) + // [2304, N] - cur (out) + // + // cur = attn_w*cur + attn_b + // [2304, N] + { + cur = ggml::mul_mat(model.layers[il].c_attn_attn_w, cur) + model.layers[il].c_attn_attn_b; + } + + // self-attention + { + ggml::tensor Qcur = cur.view(n_embd, N, cur.nb(1), 0*sizeof(float)*n_embd); + ggml::tensor Kcur = cur.view(n_embd, N, cur.nb(1), 1*sizeof(float)*n_embd); + ggml::tensor Vcur = cur.view(n_embd, N, cur.nb(1), 2*sizeof(float)*n_embd); + + // store key and value to memory + if (N >= 1) { + ggml::tensor k = model.memory_k.view(N*n_embd, model.memory_k.element_size()*n_embd*(il*n_ctx + n_past)); + ggml::tensor v = model.memory_v.view(N*n_embd, model.memory_v.element_size()*n_embd*(il*n_ctx + n_past)); + + // alternative? may be questionable use of operator overloading + // k = Kcur; + // v = Vcur; + // gf.expand(k); + // gf.expand(v); + gf.expand(k.cpy(Kcur)); + gf.expand(v.cpy(Vcur)); + } + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + // [64, N, 12] + ggml::tensor Q = Qcur.cont(n_embd/n_head, n_head, N).permute(0, 2, 1, 3); + + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + // [64, n_past + N, 12] + ggml::tensor K = model.memory_k.view((n_past + N)*n_embd, il*n_ctx*model.memory_k.element_size()*n_embd) + .reshape(n_embd/n_head, n_head, n_past + N) + .permute(0, 2, 1, 3); + + // GG: flash attention + //struct ggml_tensor * V = + // ggml_cpy(ctx0, + // ggml_permute(ctx0, + // ggml_reshape_3d(ctx0, + // ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), + // n_embd/n_head, n_head, n_past + N), + // 1, 2, 0, 3), + // ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head)); + + //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true); + + // K * Q + // [n_past + N, N, 12] + ggml::tensor KQ = ggml::mul_mat(K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + // [n_past + N, N, 12] + ggml::tensor KQ_scaled = KQ * KQ_scale; + + // KQ_masked = mask_past(KQ_scaled) + // [n_past + N, N, 12] + ggml::tensor KQ_masked = ggml::diag_mask_inf(KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + // [n_past + N, N, 12] + ggml::tensor KQ_soft_max = ggml::soft_max(KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + // [n_past + N, 64, 12] + ggml::tensor V_trans = model.memory_v.view((n_past + N)*n_embd, il*n_ctx*model.memory_v.element_size()*n_embd) + .reshape(n_embd/n_head, n_head, n_past + N) + .permute(1, 2, 0, 3) + .cont(n_past + N, n_embd/n_head, n_head); + + // KQV = transpose(V) * KQ_soft_max + // [64, N, 12] + ggml::tensor KQV = ggml::mul_mat(V_trans, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + // [64, 12, N] + ggml::tensor KQV_merged = KQV.permute(0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + // [768, N] + cur = KQV_merged.cont(n_embd, N); + } + + // projection + // [ 768, 768] - model.layers[il].c_attn_proj_w + // [ 768, 1] - model.layers[il].c_attn_proj_b + // [ 768, N] - cur (in) + // [ 768, N] - cur (out) + // + // cur = 
proj_w*cur + proj_b + // [768, N] + { + cur = ggml::mul_mat(model.layers[il].c_attn_proj_w, cur) + model.layers[il].c_attn_proj_b; + } + + // add the input + cur = cur + inpL; + + ggml::tensor inpFF = cur.get(); + + // feed-forward network + { + // norm + { + cur = ggml::norm(inpFF, hparams.eps); + + // cur = ln_2_g*cur + ln_2_b + // [ 768, N] + cur = cur*model.layers[il].ln_2_g + model.layers[il].ln_2_b; + } + + // fully connected + // [3072, 768] - model.layers[il].c_mlp_fc_w + // [3072, 1] - model.layers[il].c_mlp_fc_b + // [ 768, N] - cur (in) + // [3072, N] - cur (out) + // + // cur = fc_w*cur + fc_b + // [3072, N] + cur = ggml::mul_mat(model.layers[il].c_mlp_fc_w, cur) + model.layers[il].c_mlp_fc_b; + + // GELU activation + // [3072, N] + cur = ggml::gelu(cur); + + // projection + // [ 768, 3072] - model.layers[il].c_mlp_proj_w + // [ 768, 1] - model.layers[il].c_mlp_proj_b + // [3072, N] - cur (in) + // [ 768, N] - cur (out) + // + // cur = proj_w*cur + proj_b + // [768, N] + cur = ggml::mul_mat(model.layers[il].c_mlp_proj_w, cur) + model.layers[il].c_mlp_proj_b; + } + + // input for next layer + inpL = cur + inpFF; + } + + // norm + { + // [ 768, N] + inpL = ggml::norm(inpL, hparams.eps); + + // inpL = ln_f_g*inpL + ln_f_b + // [ 768, N] + inpL = inpL*model.ln_f_g + model.ln_f_b; + } + + // inpL = WTE * inpL + // [ 768, 50257] - model.lm_head + // [ 768, N] - inpL + inpL = ggml::mul_mat(model.lm_head, inpL); + + // logits -> probs + //inpL = ggml::soft_max(inpL); + + gf.expand(inpL); + + return gf; +} + +// evaluate the transformer +// +// - model: the model +// - allocr: ggml_allocr to use to allocate the compute buffer +// - n_threads: number of threads to use +// - n_past: the context size so far +// - embd_inp: the embeddings of the tokens in the context +// - embd_w: the predicted logits for the next token +// +bool gpt2_eval( + gpt2_model & model, + ggml::allocr & allocr, + const int n_threads, + const int n_past, + const std::vector & embd_inp, + std::vector & embd_w) { + const int N = embd_inp.size(); + + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + // reset the allocator to free all the memory allocated during the previous inference + allocr.reset(); + + ggml::graph gf = gpt2_graph(model, allocr, n_past, embd_inp); + + // allocate tensors + allocr.alloc_graph(gf); + + // run the computation + if (ggml_backend_is_cpu(model.backend.get())) { + ggml_backend_cpu_set_n_threads(model.backend.get(), n_threads); + } +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend.get())) { + ggml_backend_metal_set_n_cb(model.backend.get(), n_threads); + } +#endif + model.backend.graph_compute(gf); + + //if (n_past%100 == 0) { + // ggml_graph_print (&gf); + // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot"); + //} + + // in this case, the output tensor is the last one in the graph + ggml::tensor inpL = gf.get_node(gf.n_nodes() - 1); + + //embd_w.resize(n_vocab*N); + //inpL.backend_get(embd_w.data(), 0, sizeof(float)*n_vocab*N); + + // return result just for the last token + embd_w.resize(n_vocab); + inpL.backend_get(embd_w.data(), (n_vocab*(N-1))*sizeof(float), sizeof(float)*n_vocab); + + return true; +} + +int main(int argc, char ** argv) { + ggml_time_init(); + + const int64_t t_main_start_us = ggml_time_us(); + + gpt_params params; + params.model = "models/gpt-2-117M/ggml-model.bin"; + + if (gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.seed < 0) { + params.seed = time(NULL); + } + + printf("%s: seed = %d\n", 
__func__, params.seed); + + std::mt19937 rng(params.seed); + if (params.prompt.empty()) { + params.prompt = gpt_random_prompt(rng); + } + + int64_t t_load_us = 0; + + gpt_vocab vocab; + + // load the model + const int64_t t_start_us = ggml_time_us(); + + // ... + auto load_model = [](const std::string & model_file, gpt_vocab & vocab, const int n_gpu_layers) { + try { + return gpt2_model_load(model_file, vocab, n_gpu_layers); + } + catch (const std::exception & e) { + fprintf(stderr, "%s: failed to load model: %s\n", __func__, e.what()); + exit(1); + } + }; + + gpt2_model model = load_model(params.model, vocab, params.n_gpu_layers); + + t_load_us = ggml_time_us() - t_start_us; + + test_gpt_tokenizer(vocab, params.token_test); + + // keep this buffer alive while evaluating the model + ggml::backend_buffer buf_compute; + + ggml::allocr allocr; + // allocate the compute buffer + { + // alignment required by the backend + size_t align = model.backend.get_alignment(); + allocr = ggml::allocr::new_measure(align); + + // create the worst case graph for memory usage estimation + int n_tokens = std::min(model.hparams.n_ctx, params.n_batch); + int n_past = model.hparams.n_ctx - n_tokens; + ggml::graph gf = gpt2_graph(model, allocr, n_past, std::vector(n_tokens, 0)); + + // compute the required memory + size_t mem_size = allocr.alloc_graph(gf); + + // recreate the allocator with the required memory + buf_compute = model.backend.alloc_buffer(mem_size); + allocr = ggml::allocr(buf_compute); + + fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0); + } + + int n_past = 0; + + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + + std::vector logits; + + // tokenize the prompt + std::vector embd_inp = ::gpt_tokenize(vocab, params.prompt); + + params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); + + printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + printf("%s: number of tokens in prompt = %zu, first 8 tokens: ", __func__, embd_inp.size()); + for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) { + printf("%d ", embd_inp[i]); + } + printf("\n\n"); + + // submit the input prompt token-by-token + // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning + std::vector embd; + + for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { + // predict + if (embd.size() > 0) { + const int64_t t_start_us = ggml_time_us(); + + if (!gpt2_eval(model, allocr, params.n_threads, n_past, embd, logits)) { + printf("Failed to predict\n"); + return 1; + } + + t_predict_us += ggml_time_us() - t_start_us; + } + + n_past += embd.size(); + embd.clear(); + + if (i >= embd_inp.size()) { + // sample next token + const int top_k = params.top_k; + const float top_p = params.top_p; + const float temp = params.temp; + + const int n_vocab = model.hparams.n_vocab; + + gpt_vocab::id id = 0; + + { + const int64_t t_start_sample_us = ggml_time_us(); + + id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng); + + t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // add it to the context + embd.push_back(id); + } else { + // if here, it means we are still processing the input prompt + for (size_t k = i; k < embd_inp.size(); k++) { + embd.push_back(embd_inp[k]); + if (int32_t(embd.size()) >= params.n_batch) { + break; + } + } + i += embd.size() - 1; + } + + // display text + for (auto id : embd) { + printf("%s", 
vocab.id_to_token[id].c_str()); + } + fflush(stdout); + + // end of text token + if (embd.back() == 50256) { + break; + } + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n\n"); + printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f); + printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f); + printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); + } + + return 0; +} diff --git a/include/ggml/ggml-alloc.h b/include/ggml/ggml-alloc.h index e38758878..882045317 100644 --- a/include/ggml/ggml-alloc.h +++ b/include/ggml/ggml-alloc.h @@ -8,6 +8,8 @@ extern "C" { struct ggml_backend_buffer; +typedef struct ggml_allocr * ggml_allocr_t; + GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment); GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment); GGML_API struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer); diff --git a/src/ggml-alloc.c b/src/ggml-alloc.c index 34eba3f83..dba20a2d9 100644 --- a/src/ggml-alloc.c +++ b/src/ggml-alloc.c @@ -295,7 +295,7 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) } struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) { - struct ggml_allocr * alloc = ggml_allocr_new((void *)0x1000, (size_t)-0x1001, alignment); + struct ggml_allocr * alloc = ggml_allocr_new((void *)0x1000, (size_t)SIZE_MAX / 2, alignment); alloc->measure = true; return alloc; @@ -327,6 +327,9 @@ struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * bu } void ggml_allocr_free(struct ggml_allocr * alloc) { + if (alloc == NULL) { + return; + } if (alloc->buffer_owned) { ggml_backend_buffer_free(alloc->buffer); } diff --git a/src/ggml-backend.c b/src/ggml-backend.c index ca8d83daf..1f68634e0 100644 --- a/src/ggml-backend.c +++ b/src/ggml-backend.c @@ -33,6 +33,9 @@ ggml_backend_buffer_t ggml_backend_buffer_init( } void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { + if (buffer == NULL) { + return; + } if (buffer->iface.free_buffer != NULL) { buffer->iface.free_buffer(buffer); } @@ -81,6 +84,9 @@ const char * ggml_backend_name(ggml_backend_t backend) { } void ggml_backend_free(ggml_backend_t backend) { + if (backend == NULL) { + return; + } backend->iface.free(backend); } @@ -119,6 +125,9 @@ ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, } void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + if (plan == NULL) { + return; + } backend->iface.graph_plan_free(backend, plan); } From dd9d14e55fd61530f5099be89dded0d8574e75fd Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 15 Oct 2023 13:50:24 +0200 Subject: [PATCH 2/3] make the context stack thread_local --- examples/ggml-cpp/ggml-cpp.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ggml-cpp/ggml-cpp.h b/examples/ggml-cpp/ggml-cpp.h index ade8e6dee..b96b82ea8 100644 --- a/examples/ggml-cpp/ggml-cpp.h +++ b/examples/ggml-cpp/ggml-cpp.h @@ -58,7 +58,7 @@ namespace ggml { }; inline ctx_stack & get_ctx_stack() { - static ctx_stack s; + static thread_local ctx_stack s; return s; } From a328beae4538722d8ce49163db52028ebde6e89b Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 15 Oct 2023 13:53:05 +0200 Subject: [PATCH 3/3] avoid specifiying the namespace in function calls 
(ADL) --- examples/ggml-cpp/gpt-2-cpp.cpp | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/ggml-cpp/gpt-2-cpp.cpp b/examples/ggml-cpp/gpt-2-cpp.cpp index d5793bdc5..7435d9bd6 100644 --- a/examples/ggml-cpp/gpt-2-cpp.cpp +++ b/examples/ggml-cpp/gpt-2-cpp.cpp @@ -477,7 +477,7 @@ struct ggml::graph gpt2_graph( } // wte + wpe - ggml::tensor inpL = ggml::get_rows(model.wte, embd) + ggml::get_rows(model.wpe, position); + ggml::tensor inpL = get_rows(model.wte, embd) + get_rows(model.wpe, position); for (int il = 0; il < n_layer; ++il) { ggml::tensor cur; @@ -485,7 +485,7 @@ struct ggml::graph gpt2_graph( // norm { // [ 768, N] - cur = ggml::norm(inpL, hparams.eps); + cur = norm(inpL, hparams.eps); // [ 768, N] cur = cur*model.layers[il].ln_1_g + model.layers[il].ln_1_b; @@ -500,7 +500,7 @@ struct ggml::graph gpt2_graph( // cur = attn_w*cur + attn_b // [2304, N] { - cur = ggml::mul_mat(model.layers[il].c_attn_attn_w, cur) + model.layers[il].c_attn_attn_b; + cur = mul_mat(model.layers[il].c_attn_attn_w, cur) + model.layers[il].c_attn_attn_b; } // self-attention @@ -547,7 +547,7 @@ struct ggml::graph gpt2_graph( // K * Q // [n_past + N, N, 12] - ggml::tensor KQ = ggml::mul_mat(K, Q); + ggml::tensor KQ = mul_mat(K, Q); // KQ_scaled = KQ / sqrt(n_embd/n_head) // [n_past + N, N, 12] @@ -555,11 +555,11 @@ struct ggml::graph gpt2_graph( // KQ_masked = mask_past(KQ_scaled) // [n_past + N, N, 12] - ggml::tensor KQ_masked = ggml::diag_mask_inf(KQ_scaled, n_past); + ggml::tensor KQ_masked = diag_mask_inf(KQ_scaled, n_past); // KQ = soft_max(KQ_masked) // [n_past + N, N, 12] - ggml::tensor KQ_soft_max = ggml::soft_max(KQ_masked); + ggml::tensor KQ_soft_max = soft_max(KQ_masked); // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() // [n_past + N, 64, 12] @@ -570,7 +570,7 @@ struct ggml::graph gpt2_graph( // KQV = transpose(V) * KQ_soft_max // [64, N, 12] - ggml::tensor KQV = ggml::mul_mat(V_trans, KQ_soft_max); + ggml::tensor KQV = mul_mat(V_trans, KQ_soft_max); // KQV_merged = KQV.permute(0, 2, 1, 3) // [64, 12, N] @@ -590,7 +590,7 @@ struct ggml::graph gpt2_graph( // cur = proj_w*cur + proj_b // [768, N] { - cur = ggml::mul_mat(model.layers[il].c_attn_proj_w, cur) + model.layers[il].c_attn_proj_b; + cur = mul_mat(model.layers[il].c_attn_proj_w, cur) + model.layers[il].c_attn_proj_b; } // add the input @@ -602,7 +602,7 @@ struct ggml::graph gpt2_graph( { // norm { - cur = ggml::norm(inpFF, hparams.eps); + cur = norm(inpFF, hparams.eps); // cur = ln_2_g*cur + ln_2_b // [ 768, N] @@ -617,11 +617,11 @@ struct ggml::graph gpt2_graph( // // cur = fc_w*cur + fc_b // [3072, N] - cur = ggml::mul_mat(model.layers[il].c_mlp_fc_w, cur) + model.layers[il].c_mlp_fc_b; + cur = mul_mat(model.layers[il].c_mlp_fc_w, cur) + model.layers[il].c_mlp_fc_b; // GELU activation // [3072, N] - cur = ggml::gelu(cur); + cur = gelu(cur); // projection // [ 768, 3072] - model.layers[il].c_mlp_proj_w @@ -631,7 +631,7 @@ struct ggml::graph gpt2_graph( // // cur = proj_w*cur + proj_b // [768, N] - cur = ggml::mul_mat(model.layers[il].c_mlp_proj_w, cur) + model.layers[il].c_mlp_proj_b; + cur = mul_mat(model.layers[il].c_mlp_proj_w, cur) + model.layers[il].c_mlp_proj_b; } // input for next layer @@ -641,7 +641,7 @@ struct ggml::graph gpt2_graph( // norm { // [ 768, N] - inpL = ggml::norm(inpL, hparams.eps); + inpL = norm(inpL, hparams.eps); // inpL = ln_f_g*inpL + ln_f_b // [ 768, N] @@ -651,10 +651,10 @@ struct ggml::graph gpt2_graph( // inpL = 
WTE * inpL // [ 768, 50257] - model.lm_head // [ 768, N] - inpL - inpL = ggml::mul_mat(model.lm_head, inpL); + inpL = mul_mat(model.lm_head, inpL); // logits -> probs - //inpL = ggml::soft_max(inpL); + //inpL = soft_max(inpL); gf.expand(inpL);