Skip to content

Commit

Permalink
Small code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
rfsaliev committed Jun 17, 2024
1 parent fe450a3 commit 2a22db5
Showing 1 changed file with 76 additions and 109 deletions.
185 changes: 76 additions & 109 deletions src/ggml-dnnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,25 @@ struct ggml_backend_dnnl_buffer_context
std::vector<dnnl::memory> sub_mems;
};

// Resolve the raw host address of tensor `t`'s data when it lives inside a
// DNNL-managed buffer. `t->extra` holds the dnnl_memory_t sub-memory created
// for this tensor; the `true` flag makes the C++ wrapper weak (non-owning,
// per the oneDNN handle API), so the underlying C handle is not released here.
static void* get_memory_handle(const struct ggml_tensor * t) {
auto buf_mem = dnnl::memory{(dnnl_memory_t)t->extra, true};
auto buf_md = buf_mem.get_desc();
auto buf_handle = buf_mem.get_data_handle();
// Byte offset of this sub-memory within its parent buffer.
auto buf_offset = buf_md.get_submemory_offset();
// Sanity check: t->data encodes the same offset relative to the fake
// DNNL_BUFFER_BASE base pointer used by this backend's buffers.
GGML_ASSERT((size_t)buf_offset == ((uintptr_t)t->data - DNNL_BUFFER_BASE));
// auto buf_ctx = (ggml_backend_dnnl_buffer_context*)t->buffer->context;
//auto parent_buf_handle = buf_ctx->mem.get_data_handle();

// FIXME: buf_handle + offset works for CPU only
return (char*)buf_handle + buf_offset;
}
#else
// Non-DNNL-buffer build: tensor data is an ordinary host pointer, so it can
// be returned directly.
static void* get_memory_handle(const struct ggml_tensor * t) {
return t->data;
}
#endif


namespace {
template <class T>
struct dnnl_mem_ptr {
Expand Down Expand Up @@ -91,15 +108,6 @@ static bool ggml_dnnl_type_supported(enum ggml_type type) {

static bool ggml_dnnl_tensor_supported(const struct ggml_tensor * t) {
auto type = t->type;
GGML_TENSOR_LOCALS(int64_t, ne, t, ne)
GGML_TENSOR_LOCALS(size_t, nb, t, nb)


// cannot be transposed or permuted
// GGML_ASSERT(nb0 == ggml_type_size(type));
// GGML_ASSERT(nb0 <= nb1);
// GGML_ASSERT(nb1 <= nb2);
// GGML_ASSERT(nb2 <= nb3);

if (!ggml_dnnl_type_supported(type)) {
return false;
Expand Down Expand Up @@ -166,25 +174,6 @@ dnnl::memory::desc ggml_tensor_to_dnnl_md(const struct ggml_tensor * t, bool tra
return dnnl::memory::desc{adims, dt, strides};
}

#if USE_DNNL_BUFFER
// Resolve the raw host address of tensor `t`'s data when it lives inside a
// DNNL-managed buffer. `t->extra` holds the dnnl_memory_t sub-memory created
// for this tensor; the `true` flag makes the C++ wrapper weak (non-owning,
// per the oneDNN handle API), so the underlying C handle is not released here.
static void* get_memory_handle(const struct ggml_tensor * t) {
auto buf_mem = dnnl::memory{(dnnl_memory_t)t->extra, true};
auto buf_md = buf_mem.get_desc();
auto buf_handle = buf_mem.get_data_handle();
// Byte offset of this sub-memory within its parent buffer; must match the
// offset encoded in t->data relative to the fake DNNL_BUFFER_BASE pointer.
auto buf_offset = buf_md.get_submemory_offset();
GGML_ASSERT((size_t)buf_offset == ((uintptr_t)t->data - DNNL_BUFFER_BASE));
// auto buf_ctx = (ggml_backend_dnnl_buffer_context*)t->buffer->context;
//auto parent_buf_handle = buf_ctx->mem.get_data_handle();

// FIXME: buf_handle + offset works for CPU only
return (char*)buf_handle + buf_offset;
}
#else
// Non-DNNL-buffer build: tensor data is an ordinary host pointer, returned
// as-is.
static void* get_memory_handle(const struct ggml_tensor * t) {
return t->data;
}
#endif

dnnl::memory ggml_tensor_to_dnnl_mem(ggml_backend_t backend, const struct ggml_tensor * t, bool transpose = false,
dnnl::memory::data_type convert_to = dnnl::memory::data_type::undef,
size_t ndims = GGML_MAX_DIMS) {
Expand Down Expand Up @@ -474,7 +463,7 @@ static ggml_status ggml_backend_dnnl_softmax(ggml_backend_t backend, struct ggml
{DNNL_ARG_DST, src_mem},
});
}
//float alpha = *reinterpret_cast<float*>(dst->op_params);

const int axis = src_mem.get_desc().get_dims().size() - 1;
auto pd = dnnl::softmax_forward::primitive_desc{ctx->engine, dnnl::prop_kind::forward_inference, dnnl::algorithm::softmax_accurate, src_mem.get_desc(), dst_mem.get_desc(), axis};
auto prim = dnnl::softmax_forward{pd};
Expand All @@ -493,7 +482,6 @@ static ggml_status ggml_backend_dnnl_norm(ggml_backend_t backend, struct ggml_te
auto src_mem = ggml_tensor_to_dnnl_mem(backend, src);
auto dst_mem = ggml_tensor_to_dnnl_mem(backend, dst);

//float alpha = *reinterpret_cast<float*>(dst->op_params);
float eps = ((const float *)(dst->op_params))[0];

GGML_ASSERT(eps > 0.0f);
Expand Down Expand Up @@ -805,6 +793,62 @@ static ggml_status ggml_backend_dnnl_node_compute(ggml_backend_t backend, struct
*/
}

// Reports whether the DNNL backend can execute `node`.
// The `backend` handle is not consulted: support depends only on the op kind
// and on whether the node's tensors use DNNL-compatible data types.
static bool ggml_backend_dnnl_node_supported(ggml_backend_t backend, const struct ggml_tensor * node) {
    GGML_UNUSED(backend);

    // Shared requirement for most supported ops: both the destination and the
    // first source tensor must have DNNL-supported types/layouts.
    const auto tensors_ok = [node]() {
        return ggml_dnnl_tensor_supported(node) && ggml_dnnl_tensor_supported(node->src[0]);
    };

    switch (node->op) {
        case GGML_OP_NONE:
            return true;

        case GGML_OP_MUL_MAT:
            // Matmul has its own heuristic for when DNNL is worthwhile.
            return ggml_compute_forward_mul_mat_use_dnnl(node);

        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(node)) {
                case GGML_UNARY_OP_ABS:
                case GGML_UNARY_OP_TANH:
                case GGML_UNARY_OP_ELU:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_GELU:
                case GGML_UNARY_OP_GELU_QUICK:
                case GGML_UNARY_OP_HARDSWISH:
                case GGML_UNARY_OP_HARDSIGMOID:
                    return tensors_ok();
                default:
                    // Not implemented: SGN, NEG, STEP, SILU, ...
                    return false;
            }

        // Layout, element-wise and misc ops executed directly by this backend.
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_ADD:
        case GGML_OP_ADD1:
        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_SQR:
        case GGML_OP_SQRT:
        case GGML_OP_LOG:
        case GGML_OP_CONT:
        case GGML_OP_CPY:
        case GGML_OP_DUP:
        case GGML_OP_SCALE:
        case GGML_OP_DIAG_MASK_ZERO:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:
        case GGML_OP_GET_ROWS:
            return tensors_ok();

        default:
            return false;
    }
}

// buffer interface

#if USE_DNNL_BUFFER
Expand Down Expand Up @@ -949,23 +993,6 @@ GGML_CALL static const char * ggml_backend_dnnl_buffer_type_get_name(ggml_backen
GGML_UNUSED(buft);
}

// GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
// size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
// void * data = malloc(size); // TODO: use GGML_ALIGNED_MALLOC (move to ggml-impl.h)
// if (data == NULL) {
// fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
// return NULL;
// }

// return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
// }

// GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
// return TENSOR_ALIGNMENT;

// GGML_UNUSED(buft);
// }

GGML_CALL static bool ggml_backend_dnnl_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
return ggml_backend_is_dnnl(backend) || ggml_backend_is_cpu(backend);

Expand Down Expand Up @@ -1037,67 +1064,7 @@ GGML_CALL static ggml_status ggml_backend_dnnl_graph_compute(ggml_backend_t back
}

// Backend-interface entry point: reports whether this backend supports `op`.
// NOTE(review): this body duplicates ggml_backend_dnnl_node_supported above —
// presumably one is meant to delegate to the other; confirm against the
// upstream file and deduplicate.
GGML_CALL static bool ggml_backend_dnnl_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
GGML_UNUSED(backend);
// return false;
switch (op->op) {
case GGML_OP_NONE:
return true;
// Layout, element-wise and misc ops executed directly by this backend;
// all additionally require DNNL-compatible tensor types on dst and src0.
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_LOG:
case GGML_OP_CONT:
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_SCALE:
case GGML_OP_DIAG_MASK_ZERO:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_GET_ROWS:
return ggml_dnnl_tensor_supported(op) && ggml_dnnl_tensor_supported(op->src[0]);
case GGML_OP_MUL_MAT:
// Matmul has its own heuristic for when DNNL is worthwhile.
return ggml_compute_forward_mul_mat_use_dnnl(op);
case GGML_OP_UNARY:
{
enum ggml_unary_op uop = ggml_get_unary_op(op);
switch(uop) {
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:
return ggml_dnnl_tensor_supported(op) && ggml_dnnl_tensor_supported(op->src[0]);
default:
// GGML_UNARY_OP_SGN,
// GGML_UNARY_OP_NEG,
// GGML_UNARY_OP_STEP,
// GGML_UNARY_OP_SILU,
return false;
}
}

default:
return false;
}
}

// Decides whether `op` should be offloaded to this backend; currently just
// delegates to the op-support check (no minimum batch-size threshold applied).
GGML_CALL static bool ggml_backend_dnnl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
// const int min_batch_size = 32;

return ggml_backend_dnnl_supports_op(backend, op);

// NOTE(review): everything below the return above is unreachable — this looks
// like flattened diff residue; confirm which delegate is intended and remove
// the other.
// GGML_UNUSED(backend);
return ggml_backend_dnnl_node_supported(backend, op);
}

static struct ggml_backend_i dnnl_backend_i = {
Expand All @@ -1113,7 +1080,7 @@ static struct ggml_backend_i dnnl_backend_i = {
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_dnnl_graph_compute,
/* .supports_op = */ ggml_backend_dnnl_supports_op,
/* .offload_op = */ NULL, //ggml_backend_dnnl_offload_op,
/* .offload_op = */ NULL,
/* .event_new = */ NULL,
/* .event_free = */ NULL,
/* .event_record = */ NULL,
Expand Down

0 comments on commit 2a22db5

Please sign in to comment.