[DNNL] Good performance for oneDNN backend
* Use dnnl::inner_product_forward for 2D matrix multiplication.
* The gpt-2 sample is hacked to enforce FP32 weights when GGML_USE_DNNL is defined.
rfsaliev committed Jun 7, 2024
1 parent 50e7dbd commit e2edfce
Showing 2 changed files with 152 additions and 134 deletions.
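The matmul change itself (the first bullet in the commit message) is in the other changed file, which is not rendered in this excerpt. As a rough, hypothetical sketch only, not code from this commit: assuming oneDNN 3.x, FP32 data, a row-major [M, K] source and [N, K] weights (ggml's mul_mat keeps the reduction dimension of the weight contiguous, which is exactly oneDNN's inner-product weight layout), a 2D matrix multiplication maps onto dnnl::inner_product_forward roughly as follows. The helper name matmul_f32_dnnl is made up for illustration.

    // Hypothetical sketch: dst[M, N] = src[M, K] * weights[N, K]^T via oneDNN's
    // inner_product primitive. Illustrative only; not the backend code from this commit.
    #include <dnnl.hpp>
    #include <cstdint>

    static void matmul_f32_dnnl(const float * src, const float * wei, float * dst,
                                int64_t M, int64_t N, int64_t K) {
        dnnl::engine eng(dnnl::engine::kind::cpu, 0);
        dnnl::stream strm(eng);

        // plain row-major descriptors: src [M, K], weights [N, K], dst [M, N]
        dnnl::memory::desc src_md({M, K}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
        dnnl::memory::desc wei_md({N, K}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
        dnnl::memory::desc dst_md({M, N}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);

        // oneDNN 3.x constructs the primitive descriptor directly from the memory descriptors
        dnnl::inner_product_forward::primitive_desc pd(
            eng, dnnl::prop_kind::forward_inference, src_md, wei_md, dst_md);

        dnnl::memory src_mem(src_md, eng, const_cast<float *>(src));
        dnnl::memory wei_mem(wei_md, eng, const_cast<float *>(wei));
        dnnl::memory dst_mem(dst_md, eng, dst);

        // dst[m][n] = sum_k src[m][k] * weights[n][k]
        dnnl::inner_product_forward(pd).execute(strm, {
            {DNNL_ARG_SRC,     src_mem},
            {DNNL_ARG_WEIGHTS, wei_mem},
            {DNNL_ARG_DST,     dst_mem},
        });
        strm.wait();
    }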
examples/gpt-2/main-sched.cpp (44 changes: 31 additions & 13 deletions)
@@ -221,6 +221,10 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
        }
    }

#ifdef GGML_USE_DNNL
    // oneDNN does not support FP16 weights for FP32 source and destination
    ggml_type wtype = GGML_TYPE_F32;
#else
    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
    // in order to save memory and also to speed up the computation
    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
@@ -229,6 +233,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
                __func__, fname.c_str(), model.hparams.ftype);
        return false;
    }
#endif

    auto & ctx = model.ctx_w;

@@ -424,6 +429,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
        bool has_lm_head = false;

        std::vector<char> read_buf;
        std::vector<float> conv_buf;

        while (true) {
            int32_t n_dims;
@@ -471,35 +477,47 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.c_str(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}

const size_t bpe = ggml_type_size(ggml_type(ttype));
// const size_t bpe = ggml_type_size(ggml_type(ttype));

if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
// if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
// fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
// __func__, name.c_str(), ggml_nbytes(tensor), nelements*bpe);
// return false;
// }

// allocate the tensor
ggml_backend_t backend = tensor_backends[name];
ggml_tallocr * alloc = &backend_buffers.find(backend)->second;
ggml_tallocr_alloc(alloc, tensor);
//printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());

if (ggml_backend_is_cpu(backend)
if (false && ggml_backend_is_cpu(backend)
#ifdef GGML_USE_METAL
|| ggml_backend_is_metal(backend)
#endif
#ifdef GGML_USE_DNNL
|| ggml_backend_is_dnnl(backend)
#endif
) {
// for the CPU and Metal backend, we can read directly into the tensor
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
} else {
// read into a temporary buffer first, then copy to device memory
read_buf.resize(ggml_nbytes(tensor));
fin.read(read_buf.data(), ggml_nbytes(tensor));
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
read_buf.resize(ggml_row_size(ggml_type(ttype), nelements));
fin.read(read_buf.data(), ggml_row_size(ggml_type(ttype), nelements));
void* read_buf_data = read_buf.data();

if (ggml_type(ttype) != tensor->type) {
// Convert data if needed
if (tensor->type != GGML_TYPE_F32) {
printf("%s: tensor '%s' has unsupported type: '%s'", __func__, name.c_str(), ggml_type_name(ggml_type(ttype)));
return false;
}
conv_buf.resize(nelements);
auto tt = ggml_internal_get_type_traits(ggml_type(ttype));
tt.to_float(read_buf.data(), conv_buf.data(), nelements);
assert(ggml_row_size(ggml_type(tensor->type), nelements) == ggml_nbytes(tensor));
read_buf_data = conv_buf.data();
}

ggml_backend_tensor_set(tensor, read_buf_data, 0, ggml_nbytes(tensor));
}

// GPT-2 models share the WTE tensor as the LM head
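The conversion path added above can be read in isolation: the tensor data is read from the file at its on-disk type and, if that type differs from the (FP32) tensor type, widened through ggml's type traits before being uploaded with ggml_backend_tensor_set. Below is a minimal, self-contained restatement of that step; the helper name to_f32 is made up for illustration and is not part of the sample.

    // Sketch of the loader's conversion step: widen any ggml type that exposes a
    // to_float routine (e.g. FP16) into an FP32 buffer.
    #include "ggml.h"
    #include <cstdint>
    #include <cstring>
    #include <vector>

    static std::vector<float> to_f32(const void * src, ggml_type type, int64_t n) {
        std::vector<float> dst(n);
        if (type == GGML_TYPE_F32) {
            std::memcpy(dst.data(), src, n * sizeof(float));
        } else {
            // the same mechanism the loader above relies on
            auto tt = ggml_internal_get_type_traits(type);
            tt.to_float(src, dst.data(), n);
        }
        return dst;
    }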