thrust/mr: fix the case of reusing a block for a smaller alloc.
Previously, the pool happily returned a pointer to an oversized block larger than the requested size, without recording that the block now backs a smaller allocation. That meant that on deallocation, it would look for the block's descriptor in the wrong place. This is now fixed by moving the descriptor to where deallocation can find it using the user-provided size, and by storing the original size so that the descriptor can be restored to its rightful place when the block is deallocated.

Also a drive-by fix for a bug where, in certain cases, a cached oversized block that was reused for a new allocation wasn't removed from the cached list. Whoops. Kinda surprised this hasn't exploded before.
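
Below is a minimal, host-only sketch of the layout invariant described above. It is illustrative only, not the pool's actual code: all names (block_descriptor, allocate_oversized, reuse_for_smaller, deallocate_oversized, block_of) are hypothetical. The point it demonstrates is that an oversized block's descriptor always sits at block + current_size, so a deallocation that only knows the user-provided size can find it, while size preserves the original upstream allocation size for the final free.

#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <new>

struct block_descriptor
{
    std::size_t size;          // bytes originally requested for the user region
    std::size_t current_size;  // bytes handed out by the most recent allocation
};

// Allocate `bytes` of user space plus a trailing descriptor.
void * allocate_oversized(std::size_t bytes)
{
    char * raw = static_cast<char *>(std::malloc(bytes + sizeof(block_descriptor)));
    if (!raw)
    {
        throw std::bad_alloc();
    }
    new (raw + bytes) block_descriptor{bytes, bytes};
    return raw;
}

// The block starts current_size bytes before its descriptor.
char * block_of(block_descriptor * desc)
{
    return reinterpret_cast<char *>(desc) - desc->current_size;
}

// Reuse a block for a smaller request: relocate the descriptor to
// block + smaller_bytes and record the new size, so that a later
// deallocation with `smaller_bytes` can find it again.
block_descriptor * reuse_for_smaller(block_descriptor * desc, std::size_t smaller_bytes)
{
    assert(smaller_bytes <= desc->size);
    char * block = block_of(desc);
    block_descriptor copy = *desc;  // copy first; the old and new slots may be close together
    copy.current_size = smaller_bytes;
    return new (block + smaller_bytes) block_descriptor(copy);
}

// Deallocate with the user-provided size n: the descriptor sits at ptr + n.
void deallocate_oversized(void * ptr, std::size_t n)
{
    char * block = static_cast<char *>(ptr);
    auto * desc = reinterpret_cast<block_descriptor *>(block + n);
    assert(desc->current_size == n);
    // The upstream allocation was desc->size + sizeof(block_descriptor) bytes.
    std::free(block);
}

int main()
{
    // Hand out an oversized block, then pretend it was cached and reused for
    // a smaller request; deallocating with the smaller size must still work.
    void * big = allocate_oversized(2048 + 16);
    auto * desc = reinterpret_cast<block_descriptor *>(static_cast<char *>(big) + 2048 + 16);
    block_descriptor * moved = reuse_for_smaller(desc, 2048);
    deallocate_oversized(block_of(moved), 2048);
    return 0;
}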
griwes committed Dec 19, 2023
1 parent 2165845 commit a6202c6
Showing 2 changed files with 104 additions and 34 deletions.
41 changes: 37 additions & 4 deletions thrust/testing/mr_pool.cu
@@ -123,23 +123,26 @@ public:

virtual tracked_pointer<void> do_allocate(std::size_t n, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT) override
{
ASSERT_EQUAL(static_cast<bool>(id_to_allocate), true);
ASSERT_EQUAL(id_to_allocate || id_to_allocate == -1u, true);

void * raw = upstream.do_allocate(n, alignment);
tracked_pointer<void> ret(raw);
ret.id = id_to_allocate;
ret.size = n;
ret.alignment = alignment;

id_to_allocate = 0;
if (id_to_allocate != -1u)
{
id_to_allocate = 0;
}

return ret;
}

virtual void do_deallocate(tracked_pointer<void> p, std::size_t n, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT) override
{
ASSERT_EQUAL(p.size, n);
ASSERT_EQUAL(p.alignment, alignment);
ASSERT_GEQUAL(p.size, n);
ASSERT_GEQUAL(p.alignment, alignment);

if (id_to_deallocate != 0)
{
@@ -318,6 +321,36 @@ void TestPoolCachingOversized()
upstream.id_to_allocate = 7;
tracked_pointer<void> a9 = pool.do_allocate(2048, 32);
ASSERT_EQUAL(a9.id, 7u);

// make sure that reusing a larger oversized block for a smaller allocation works
// this is NVIDIA/cccl#585
upstream.id_to_allocate = 8;
tracked_pointer<void> a10 = pool.do_allocate(2048 + 16, THRUST_MR_DEFAULT_ALIGNMENT);
pool.do_deallocate(a10, 2048 + 16, THRUST_MR_DEFAULT_ALIGNMENT);
tracked_pointer<void> a11 = pool.do_allocate(2048, THRUST_MR_DEFAULT_ALIGNMENT);
ASSERT_EQUAL(a11.ptr, a10.ptr);
pool.do_deallocate(a11, 2048, THRUST_MR_DEFAULT_ALIGNMENT);

// original minimized reproducer from NVIDIA/cccl#585:
{
upstream.id_to_allocate = -1u;

auto ptr1 = pool.allocate(43920240);
auto ptr2 = pool.allocate(2465264);
pool.deallocate(ptr1, 43920240);
pool.deallocate(ptr2, 2465264);
auto ptr3 = pool.allocate(4930528);
pool.deallocate(ptr3, 4930528);
auto ptr4 = pool.allocate(14640080);
std::memset(thrust::raw_pointer_cast(ptr4), 0xff, 14640080);

auto crash = pool.allocate(4930528);

pool.deallocate(crash, 4930528);
pool.deallocate(ptr4, 14640080);

upstream.id_to_allocate = 0;
}
}

void TestUnsynchronizedPoolCachingOversized()
97 changes: 67 additions & 30 deletions thrust/thrust/mr/pool.h
@@ -194,6 +194,7 @@ class unsynchronized_pool_resource final
oversized_block_descriptor_ptr prev;
oversized_block_descriptor_ptr next;
oversized_block_descriptor_ptr next_cached;
std::size_t current_size;
};

struct pool
@@ -249,12 +250,15 @@ class unsynchronized_pool_resource final
oversized_block_descriptor_ptr alloc = m_oversized;
m_oversized = thrust::raw_reference_cast(*m_oversized).next;

oversized_block_descriptor desc =
thrust::raw_reference_cast(*alloc);

void_ptr p = static_cast<void_ptr>(
static_cast<char_ptr>(
static_cast<void_ptr>(alloc)
) - thrust::raw_reference_cast(*alloc).size
);
m_upstream->do_deallocate(p, thrust::raw_reference_cast(*alloc).size + sizeof(oversized_block_descriptor), thrust::raw_reference_cast(*alloc).alignment);
static_cast<char_ptr>(static_cast<void_ptr>(alloc)) -
desc.current_size);
m_upstream->do_deallocate(
p, desc.size + sizeof(oversized_block_descriptor),
desc.alignment);
}

m_cached_oversized = oversized_block_descriptor_ptr();
@@ -305,23 +309,43 @@
{
if (previous != &m_cached_oversized)
{
oversized_block_descriptor previous_desc = **previous;
previous_desc.next_cached = desc.next_cached;
**previous = previous_desc;
*previous = desc.next_cached;
}
else
{
m_cached_oversized = desc.next_cached;
}

desc.next_cached = oversized_block_descriptor_ptr();

auto ret =
static_cast<char_ptr>(static_cast<void_ptr>(ptr)) -
desc.size;

if (bytes != desc.size) {
desc.current_size = bytes;

ptr = static_cast<oversized_block_descriptor_ptr>(
static_cast<void_ptr>(ret + bytes));

if (detail::pointer_traits<
oversized_block_descriptor_ptr>::
get(desc.prev)) {
thrust::raw_reference_cast(*desc.prev).next = ptr;
} else {
m_oversized = ptr;
}

if (detail::pointer_traits<
oversized_block_descriptor_ptr>::
get(desc.next)) {
thrust::raw_reference_cast(*desc.next).prev = ptr;
}
}

*ptr = desc;

return static_cast<void_ptr>(
static_cast<char_ptr>(
static_cast<void_ptr>(ptr)
) - desc.size
);
return static_cast<void_ptr>(ret);
}

previous = &thrust::raw_reference_cast(*ptr).next_cached;
@@ -343,6 +367,7 @@ class unsynchronized_pool_resource final
desc.prev = oversized_block_descriptor_ptr();
desc.next = m_oversized;
desc.next_cached = oversized_block_descriptor_ptr();
desc.current_size = bytes;
*block = desc;
m_oversized = block;

@@ -451,35 +476,47 @@
);

oversized_block_descriptor desc = *block;
assert(desc.current_size == n);
assert(desc.alignment == alignment);

if (m_options.cache_oversized)
{
desc.next_cached = m_cached_oversized;
*block = desc;

if (desc.size != n) {
desc.current_size = desc.size;
block = static_cast<oversized_block_descriptor_ptr>(
static_cast<void_ptr>(static_cast<char_ptr>(p) +
desc.size));
if (detail::pointer_traits<
oversized_block_descriptor_ptr>::get(desc.prev)) {
thrust::raw_reference_cast(*desc.prev).next = block;
} else {
m_oversized = block;
}

if (detail::pointer_traits<
oversized_block_descriptor_ptr>::get(desc.next)) {
thrust::raw_reference_cast(*desc.next).prev = block;
}
}

m_cached_oversized = block;
*block = desc;

return;
}

if (!detail::pointer_traits<oversized_block_descriptor_ptr>::get(desc.prev))
{
assert(m_oversized == block);
if (detail::pointer_traits<oversized_block_descriptor_ptr>::get(
desc.prev)) {
thrust::raw_reference_cast(*desc.prev).next = desc.next;
} else {
m_oversized = desc.next;
}
else
{
oversized_block_descriptor prev = *desc.prev;
assert(prev.next == block);
prev.next = desc.next;
*desc.prev = prev;
}

if (detail::pointer_traits<oversized_block_descriptor_ptr>::get(desc.next))
{
oversized_block_descriptor next = *desc.next;
assert(next.prev == block);
next.prev = desc.prev;
*desc.next = next;
if (detail::pointer_traits<oversized_block_descriptor_ptr>::get(
desc.next)) {
thrust::raw_reference_cast(*desc.next).prev = desc.prev;
}

m_upstream->do_deallocate(p, desc.size + sizeof(oversized_block_descriptor), desc.alignment);
