Work around change in cuTensorMapEncode (#1567)
Previously we could pass `nullptr` as the stride array when there were no strides (`num_dims == 1`). However, `cuTensorMapEncodeTiled` now performs bounds checking that fails in that case.

Fixes nvbug4575531
miscco authored Mar 26, 2024
1 parent 95a6620 commit 3a80d7d
Showing 1 changed file with 4 additions and 3 deletions.
@@ -216,10 +216,11 @@ CUtensorMap map_encode(T *tensor_ptr, const cuda::std::array<uint64_t, num_dims>

 // The stride is the number of bytes to traverse from the first element of one row to the next.
 // It must be a multiple of 16.
-constexpr int num_strides = num_dims - 1;
-cuda::std::array<uint64_t, num_strides> stride;
+// cuTensorMapEncodeTiled requires that the stride array is a valid pointer, so we add one superfluous element
+// This is necessary for num_dims == 1
+cuda::std::array<uint64_t, num_dims> stride;
 uint64_t base_stride = sizeof(T);
-for (size_t i = 0; i < stride.size(); ++i) {
+for (size_t i = 0; i < stride.size() - 1; ++i) {
 base_stride *= gmem_dims[i];
 stride[i] = base_stride;
 }
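
The stride computation in this hunk can be sketched as a standalone host-side function. This is only an illustration under stated assumptions: `compute_strides` is a hypothetical helper that does not appear in the commit, `std::array` stands in for `cuda::std::array`, and the call to `cuTensorMapEncodeTiled` is omitted. It shows why the array is sized `num_dims` rather than `num_dims - 1`: the extra slot keeps `data()` a valid, non-null pointer even when there are no meaningful strides.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not part of the commit): computes the byte strides of a
// densely packed tensor whose extents are given in gmem_dims.
// The result has num_dims slots so that data() is never null, even for
// num_dims == 1; only the first num_dims - 1 entries are meaningful.
template <typename T, std::size_t num_dims>
std::array<std::uint64_t, num_dims> compute_strides(
    const std::array<std::uint64_t, num_dims>& gmem_dims) {
  static_assert(num_dims >= 1, "at least one dimension is required");

  // Zero-initialize so the superfluous last element has a defined value.
  std::array<std::uint64_t, num_dims> stride{};

  std::uint64_t base_stride = sizeof(T);
  // Fill num_dims - 1 entries; written as i + 1 < num_dims to avoid
  // unsigned wrap-around concerns.
  for (std::size_t i = 0; i + 1 < num_dims; ++i) {
    base_stride *= gmem_dims[i];
    stride[i] = base_stride;
  }
  return stride;
}
```

For a one-dimensional tensor the loop body never runs and the array holds a single padding element, so a pointer obtained from `stride.data()` can be handed to an API that rejects `nullptr`.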