diff --git a/src/ggml-dnnl.cpp b/src/ggml-dnnl.cpp
index b43b452693..031b3eae8a 100644
--- a/src/ggml-dnnl.cpp
+++ b/src/ggml-dnnl.cpp
@@ -877,7 +877,7 @@ static void ggml_backend_dnnl_buffer_init_tensor(ggml_backend_buffer_t buffer, g
     if (tensor->view_src != NULL && tensor->view_offs == 0) {
         tensor->extra = tensor->view_src->extra;
     } else {
-        printf(" op:%s-'%s'\n", ggml_op_desc(tensor), tensor->name);
+        //printf(" op:%s-'%s'\n", ggml_op_desc(tensor), tensor->name);
         auto buf = tensor->view_src != NULL ? tensor->view_src->buffer : tensor->buffer;
         ggml_backend_dnnl_buffer_context* ctx = (ggml_backend_dnnl_buffer_context*)buf->context;
         dnnl::memory::dim offset = (uintptr_t)tensor->data - DNNL_BUFFER_BASE;
@@ -976,7 +976,6 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_dnnl_get_default_buffer
             /* .get_alignment    = */ ggml_backend_dnnl_buffer_type_get_alignment,
             /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
             /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-            /* .supports_backend = */ ggml_backend_dnnl_buffer_type_supports_backend,
             /* .is_host          = */ NULL, // defaults to false // ggml_backend_dnnl_buffer_type_is_host,
         },
         /* .context = */ backend->context,
@@ -993,12 +992,6 @@ GGML_CALL static const char * ggml_backend_dnnl_buffer_type_get_name(ggml_backen
     GGML_UNUSED(buft);
 }
 
-GGML_CALL static bool ggml_backend_dnnl_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    return ggml_backend_is_dnnl(backend) || ggml_backend_is_cpu(backend);
-
-    GGML_UNUSED(buft);
-}
-
 GGML_CALL static bool ggml_backend_dnnl_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
     return true;
 
@@ -1015,7 +1008,6 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_dnnl_get_default_buffer
             /* .get_alignment    = */ cpu_buffer_type->iface.get_alignment, // ggml_backend_cpu_buffer_type_get_alignment,
             /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
             /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-            /* .supports_backend = */ ggml_backend_dnnl_buffer_type_supports_backend,
             /* .is_host          = */ ggml_backend_dnnl_buffer_type_is_host,
         },
         /* .context = */ NULL,
@@ -1067,6 +1059,20 @@ GGML_CALL static bool ggml_backend_dnnl_supports_op(ggml_backend_t backend, cons
     return ggml_backend_dnnl_node_supported(backend, op);
 }
 
+static bool ggml_backend_buft_is_dnnl(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_dnnl_buffer_type_get_name;
+}
+
+GGML_CALL static bool ggml_backend_dnnl_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(backend);
+#if USE_DNNL_BUFFER
+    return ggml_backend_buft_is_dnnl(buft);
+#else
+    return ggml_backend_buft_is_dnnl(buft) || ggml_backend_buft_is_host(buft);
+#endif
+}
+
+
 static struct ggml_backend_i dnnl_backend_i = {
     /* .get_name                = */ ggml_backend_dnnl_name,
     /* .free                    = */ ggml_backend_dnnl_free,
@@ -1077,15 +1083,17 @@ static struct ggml_backend_i dnnl_backend_i = {
     /* .synchronize             = */ ggml_backend_dnnl_synchronize,
     /* .graph_plan_create       = */ NULL,
     /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_dnnl_graph_compute,
     /* .supports_op             = */ ggml_backend_dnnl_supports_op,
+    /* .supports_buft           = */ ggml_backend_dnnl_supports_buft,
     /* .offload_op              = */ NULL,
     /* .event_new               = */ NULL,
     /* .event_free              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
-    /* .event_synchronize       = */ NULL,
+    /* .event_synchronize       = */ NULL
 };
 
 static ggml_guid_t ggml_backend_dnnl_guid() {