Skip to content

Commit

Permalink
Miscellaneous pre-stream cleanup
Browse files Browse the repository at this point in the history
- Remove BLAS compaction; would rather implement it on stream and
  integrated into buildBLAS
- Add lodIndex in case we want to adjust BLAS LODs
- Add geometry buffer size info
  • Loading branch information
zeux committed Nov 30, 2024
1 parent d7afecc commit 60d1d02
Showing 1 changed file with 12 additions and 97 deletions.
109 changes: 12 additions & 97 deletions src/niagara.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ void buildBLAS(VkDevice device, const std::vector<Mesh>& meshes, const Buffer& v
const size_t kDefaultScratch = 32 * 1024 * 1024; // 32 MB scratch by default

size_t totalAccelerationSize = 0;
size_t totalPrimitiveCount = 0;
size_t maxScratchSize = 0;

std::vector<size_t> accelerationOffsets(meshes.size());
Expand All @@ -151,7 +152,9 @@ void buildBLAS(VkDevice device, const std::vector<Mesh>& meshes, const Buffer& v
VkAccelerationStructureGeometryKHR& geo = geometries[i];
VkAccelerationStructureBuildGeometryInfoKHR& buildInfo = buildInfos[i];

primitiveCounts[i] = mesh.lods[0].indexCount / 3;
unsigned int lodIndex = 0;

primitiveCounts[i] = mesh.lods[lodIndex].indexCount / 3;

geo.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_KHR;
geo.geometryType = VK_GEOMETRY_TYPE_TRIANGLES_KHR;
Expand All @@ -165,7 +168,7 @@ void buildBLAS(VkDevice device, const std::vector<Mesh>& meshes, const Buffer& v
geo.geometry.triangles.vertexStride = sizeof(Vertex);
geo.geometry.triangles.maxVertex = mesh.vertexCount - 1;
geo.geometry.triangles.indexType = VK_INDEX_TYPE_UINT32;
geo.geometry.triangles.indexData.deviceAddress = ibAddress + mesh.lods[0].indexOffset * sizeof(uint32_t);
geo.geometry.triangles.indexData.deviceAddress = ibAddress + mesh.lods[lodIndex].indexOffset * sizeof(uint32_t);

buildInfo.sType = VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR;
buildInfo.type = VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR;
Expand All @@ -182,7 +185,7 @@ void buildBLAS(VkDevice device, const std::vector<Mesh>& meshes, const Buffer& v
scratchSizes[i] = sizeInfo.buildScratchSize;

totalAccelerationSize = (totalAccelerationSize + sizeInfo.accelerationStructureSize + kAlignment - 1) & ~(kAlignment - 1);

totalPrimitiveCount += primitiveCounts[i];
maxScratchSize = std::max(maxScratchSize, size_t(sizeInfo.buildScratchSize));
}

Expand All @@ -191,7 +194,7 @@ void buildBLAS(VkDevice device, const std::vector<Mesh>& meshes, const Buffer& v
Buffer scratchBuffer;
createBuffer(scratchBuffer, device, memoryProperties, std::max(kDefaultScratch, maxScratchSize), VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);

printf("BLAS accelerationStructureSize: %.2f MB, scratchSize: %.2f MB (max %.2f MB)\n", double(totalAccelerationSize) / 1e6, double(scratchBuffer.size) / 1e6, double(maxScratchSize) / 1e6);
printf("BLAS accelerationStructureSize: %.2f MB, scratchSize: %.2f MB (max %.2f MB), %.3fM triangles\n", double(totalAccelerationSize) / 1e6, double(scratchBuffer.size) / 1e6, double(maxScratchSize) / 1e6, double(totalPrimitiveCount) / 1e6);

VkDeviceAddress scratchAddress = getBufferAddress(scratchBuffer, device);

Expand Down Expand Up @@ -258,95 +261,6 @@ void buildBLAS(VkDevice device, const std::vector<Mesh>& meshes, const Buffer& v
destroyBuffer(scratchBuffer, device);
}

void compactBLAS(VkDevice device, std::vector<VkAccelerationStructureKHR>& blas, Buffer& blasBuffer, VkCommandPool commandPool, VkCommandBuffer commandBuffer, VkQueue queue, const VkPhysicalDeviceMemoryProperties& memoryProperties)
{
const size_t kAlignment = 256; // required by spec for acceleration structures

VkQueryPoolCreateInfo createInfo = { VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO };
createInfo.queryType = VK_QUERY_TYPE_ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR;
createInfo.queryCount = blas.size();

VkQueryPool queryPool = 0;
VK_CHECK(vkCreateQueryPool(device, &createInfo, 0, &queryPool));

VK_CHECK(vkResetCommandPool(device, commandPool, 0));

VkCommandBufferBeginInfo beginInfo = { VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO };
beginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;

VK_CHECK(vkBeginCommandBuffer(commandBuffer, &beginInfo));

vkCmdResetQueryPool(commandBuffer, queryPool, 0, blas.size());
vkCmdWriteAccelerationStructuresPropertiesKHR(commandBuffer, blas.size(), blas.data(), VK_QUERY_TYPE_ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR, queryPool, 0);

VK_CHECK(vkEndCommandBuffer(commandBuffer));

VkSubmitInfo submitInfo = { VK_STRUCTURE_TYPE_SUBMIT_INFO };
submitInfo.commandBufferCount = 1;
submitInfo.pCommandBuffers = &commandBuffer;

VK_CHECK(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE));
VK_CHECK(vkDeviceWaitIdle(device));

std::vector<VkDeviceSize> compactedSizes(blas.size());

VK_CHECK(vkGetQueryPoolResults(device, queryPool, 0, blas.size(), blas.size() * sizeof(VkDeviceSize), compactedSizes.data(), sizeof(VkDeviceSize), VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT));
vkDestroyQueryPool(device, queryPool, 0);

size_t totalCompactedSize = 0;
std::vector<size_t> compactedOffsets(blas.size());

for (size_t i = 0; i < blas.size(); ++i)
{
compactedOffsets[i] = totalCompactedSize;
totalCompactedSize = (totalCompactedSize + compactedSizes[i] + kAlignment - 1) & ~(kAlignment - 1);
}

printf("BLAS compacted accelerationStructureSize: %.2f MB\n", double(totalCompactedSize) / 1e6);

Buffer compactedBuffer;
createBuffer(compactedBuffer, device, memoryProperties, totalCompactedSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);

std::vector<VkAccelerationStructureKHR> compactedBlas(blas.size());

for (size_t i = 0; i < blas.size(); ++i)
{
VkAccelerationStructureCreateInfoKHR accelerationInfo = { VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR };
accelerationInfo.buffer = compactedBuffer.buffer;
accelerationInfo.offset = compactedOffsets[i];
accelerationInfo.size = compactedSizes[i];
accelerationInfo.type = VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR;

VK_CHECK(vkCreateAccelerationStructureKHR(device, &accelerationInfo, nullptr, &compactedBlas[i]));
}

VK_CHECK(vkResetCommandPool(device, commandPool, 0));
VK_CHECK(vkBeginCommandBuffer(commandBuffer, &beginInfo));

for (size_t i = 0; i < blas.size(); ++i)
{
VkCopyAccelerationStructureInfoKHR copyInfo = { VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_INFO_KHR };
copyInfo.src = blas[i];
copyInfo.dst = compactedBlas[i];
copyInfo.mode = VK_COPY_ACCELERATION_STRUCTURE_MODE_COMPACT_KHR;

vkCmdCopyAccelerationStructureKHR(commandBuffer, &copyInfo);
}

VK_CHECK(vkEndCommandBuffer(commandBuffer));
VK_CHECK(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE));
VK_CHECK(vkDeviceWaitIdle(device));

for (size_t i = 0; i < blas.size(); ++i)
{
vkDestroyAccelerationStructureKHR(device, blas[i], nullptr);
blas[i] = compactedBlas[i];
}

destroyBuffer(blasBuffer, device);
blasBuffer = compactedBuffer;
}

VkAccelerationStructureKHR buildTLAS(VkDevice device, Buffer& tlasBuffer, const std::vector<MeshDraw>& draws, const std::vector<VkAccelerationStructureKHR>& blas, VkCommandPool commandPool, VkCommandBuffer commandBuffer, VkQueue queue, const VkPhysicalDeviceMemoryProperties& memoryProperties)
{
Buffer instances;
Expand Down Expand Up @@ -862,6 +776,11 @@ int main(int argc, const char** argv)
return 1;
}

printf("Geometry: VB %.2f MB, IB %.2f MB, meshlets %.2f MB\n",
double(geometry.vertices.size() * sizeof(Vertex)) / 1e6,
double(geometry.indices.size() * sizeof(uint32_t)) / 1e6,
double(geometry.meshlets.size() * sizeof(Meshlet) + geometry.meshletdata.size() * sizeof(uint32_t)) / 1e6);

if (draws.empty())
{
rngstate.state = 0x42;
Expand Down Expand Up @@ -986,10 +905,6 @@ int main(int argc, const char** argv)
if (raytracingSupported)
{
buildBLAS(device, geometry.meshes, vb, ib, blas, blasBuffer, commandPool, commandBuffer, queue, memoryProperties);

if (!fastMode)
compactBLAS(device, blas, blasBuffer, commandPool, commandBuffer, queue, memoryProperties);

tlas = buildTLAS(device, tlasBuffer, draws, blas, commandPool, commandBuffer, queue, memoryProperties);
}

Expand Down

0 comments on commit 60d1d02

Please sign in to comment.