Commit 48f1ddf: address reviews

mfbalin committed Jan 18, 2024
1 parent 669158e commit 48f1ddf
Showing 2 changed files with 32 additions and 28 deletions.
44 changes: 23 additions & 21 deletions graphbolt/src/cuda/gpu_cache.cu
@@ -19,40 +19,42 @@ GpuCache::GpuCache(const std::vector<int64_t> &shape, torch::ScalarType dtype) {
       std::accumulate(shape.begin() + 1, shape.end(), 1ll, std::multiplies<>());
   const int element_size =
       torch::empty(1, torch::TensorOptions().dtype(dtype)).element_size();
-  num_bytes = num_feats * element_size;
-  num_float_feats = (num_bytes + sizeof(float) - 1) / sizeof(float);
-  cache = std::make_unique<gpu_cache_t>(
-      (num_items + bucket_size - 1) / bucket_size, num_float_feats);
-  this->shape = shape;
-  this->shape[0] = -1;
-  this->dtype = dtype;
-  device_id = cuda::GetCurrentStream().device_index();
+  num_bytes_ = num_feats * element_size;
+  num_float_feats_ = (num_bytes_ + sizeof(float) - 1) / sizeof(float);
+  cache_ = std::make_unique<gpu_cache_t>(
+      (num_items + bucket_size - 1) / bucket_size, num_float_feats_);
+  shape_ = shape;
+  shape_[0] = -1;
+  dtype_ = dtype;
+  device_id_ = cuda::GetCurrentStream().device_index();
 }
 
 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
     torch::Tensor keys) {
   TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
   TORCH_CHECK(
-      keys.device().index() == device_id,
+      keys.device().index() == device_id_,
       "Keys should be on the correct CUDA device.");
   TORCH_CHECK(keys.sizes().size() == 1, "Keys should be a 1D tensor.");
   keys = keys.to(torch::kLong);
   auto values = torch::empty(
-      {keys.size(0), num_float_feats}, keys.options().dtype(torch::kFloat));
+      {keys.size(0), num_float_feats_}, keys.options().dtype(torch::kFloat));
   auto missing_index =
       torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
   auto missing_keys =
       torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
   cuda::CopyScalar<size_t> missing_len;
   auto stream = cuda::GetCurrentStream();
-  cache->Query(
+  cache_->Query(
       reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
       values.data_ptr<float>(),
       reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
       reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(),
       stream);
-  values =
-      values.view(torch::kByte).slice(1, 0, num_bytes).view(dtype).view(shape);
+  values = values.view(torch::kByte)
+               .slice(1, 0, num_bytes_)
+               .view(dtype_)
+               .view(shape_);
   // To safely read missing_len, we synchronize
   stream.synchronize();
   missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
@@ -63,36 +65,36 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
 void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {
   TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
   TORCH_CHECK(
-      keys.device().index() == device_id,
+      keys.device().index() == device_id_,
       "Keys should be on the correct CUDA device.");
   TORCH_CHECK(values.device().is_cuda(), "Values should be on a CUDA device.");
   TORCH_CHECK(
-      values.device().index() == device_id,
+      values.device().index() == device_id_,
       "Values should be on the correct CUDA device.");
   TORCH_CHECK(
       keys.size(0) == values.size(0),
       "The first dimensions of keys and values must match.");
   TORCH_CHECK(
-      std::equal(shape.begin() + 1, shape.end(), values.sizes().begin() + 1),
+      std::equal(shape_.begin() + 1, shape_.end(), values.sizes().begin() + 1),
       "Values should have the correct dimensions.");
   TORCH_CHECK(
-      values.scalar_type() == dtype, "Values should have the correct dtype.");
+      values.scalar_type() == dtype_, "Values should have the correct dtype.");
   keys = keys.to(torch::kLong);
   torch::Tensor float_values;
-  if (num_bytes % sizeof(float) != 0) {
+  if (num_bytes_ % sizeof(float) != 0) {
     float_values = torch::empty(
-        {values.size(0), num_float_feats},
+        {values.size(0), num_float_feats_},
         values.options().dtype(torch::kFloat));
     float_values.view(torch::kByte)
-        .slice(1, 0, num_bytes)
+        .slice(1, 0, num_bytes_)
         .copy_(values.view(torch::kByte).view({values.size(0), -1}));
   } else {
     float_values = values.view(torch::kByte)
                        .view({values.size(0), -1})
                        .view(torch::kFloat)
                        .contiguous();
   }
-  cache->Replace(
+  cache_->Replace(
       reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
       float_values.data_ptr<float>(), cuda::GetCurrentStream());
 }
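
Note on the arithmetic above: `num_float_feats_ = (num_bytes_ + sizeof(float) - 1) / sizeof(float)` is a ceiling division. Each row of `num_bytes_` raw feature bytes is padded up to a whole number of 4-byte float slots, since the underlying HugeCTR cache stores rows as floats. A minimal standalone sketch of that calculation (the values are illustrative, not from this commit):

#include <cstdint>
#include <cstdio>

int main() {
  // Say each cached row holds 5 half-precision features: 5 * 2 = 10 bytes.
  const int64_t num_bytes = 10;
  // Same ceiling division as in the GpuCache constructor.
  const int64_t num_float_feats =
      (num_bytes + sizeof(float) - 1) / sizeof(float);
  // Prints "10 bytes -> 3 floats": 2 bytes of per-row padding, which
  // Query() later trims off again with .slice(1, 0, num_bytes_).
  std::printf(
      "%lld bytes -> %lld floats\n", static_cast<long long>(num_bytes),
      static_cast<long long>(num_float_feats));
  return 0;
}

This padding is also why Replace() branches on `num_bytes_ % sizeof(float) != 0`: when a row's byte width is not float-aligned, the values are first staged into a padded float buffer before being handed to the cache.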
16 changes: 9 additions & 7 deletions graphbolt/src/cuda/gpu_cache.h
@@ -46,17 +46,19 @@ class GpuCache : public torch::CustomClassHolder {
       const std::vector<int64_t>& shape, torch::ScalarType dtype);
 
  private:
-  std::vector<int64_t> shape;
-  torch::ScalarType dtype;
-  std::unique_ptr<gpu_cache_t> cache;
-  int64_t num_bytes;
-  int64_t num_float_feats;
-  torch::DeviceIndex device_id;
+  std::vector<int64_t> shape_;
+  torch::ScalarType dtype_;
+  std::unique_ptr<gpu_cache_t> cache_;
+  int64_t num_bytes_;
+  int64_t num_float_feats_;
+  torch::DeviceIndex device_id_;
 };
 
 // The cu file in HugeCTR gpu cache uses unsigned int and long long.
 // Changing to int64_t results in a mismatch of template arguments.
-static_assert(sizeof(long long) == 8);  // NOLINT
+static_assert(
+    sizeof(long long) == sizeof(int64_t),
+    "long long and int64_t need to have the same size.");  // NOLINT
 
 }  // namespace cuda
 }  // namespace graphbolt
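
For context on how these members fit together, here is a hedged sketch of the query-then-replace pattern the API above implies; the construction arguments and the `full_features` backing tensor are illustrative assumptions, not part of this commit:

// Sketch only: a cache of up to 1M rows of 128 float16 features each.
graphbolt::cuda::GpuCache cache({1000000, 128}, torch::kFloat16);

auto [values, missing_index, missing_keys] = cache.Query(keys);
// Fill the rows that missed from the (hypothetical) backing store,
// then warm the cache so the same keys hit next time.
auto missing_values =
    full_features.index_select(0, missing_keys.cpu()).to(keys.device());
values.index_copy_(0, missing_index, missing_values);
cache.Replace(missing_keys, missing_values);

The `static_assert` above is also what licenses the `reinterpret_cast`s in Query() and Replace(): keys arrive as torch::kLong (`int64_t`) tensors, while HugeCTR's templates are instantiated with `long long`, so the two types must agree in size.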
Expand Down
