Skip to content

Commit

Permalink
Use a small dispatch() helper when it's convenient
Browse files Browse the repository at this point in the history
We repeat the program name too much during dispatches, and most of them
have the same structure; a simple dispatch helper can make the code more
concise and less error-prone here.
  • Loading branch information
zeux committed Dec 27, 2024
1 parent c0f9659 commit 17d800b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 26 deletions.
48 changes: 25 additions & 23 deletions src/niagara.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,21 @@ VkQueryPool createQueryPool(VkDevice device, uint32_t queryCount, VkQueryType qu
return queryPool;
}

template <typename PushConstants, size_t PushDescriptors>
void dispatch(VkCommandBuffer commandBuffer, const Program& program, uint32_t threadCountX, uint32_t threadCountY, const PushConstants& pushConstants, const DescriptorInfo (&pushDescriptors)[PushDescriptors])
{
assert(program.pushConstantSize == sizeof(pushConstants));
assert(program.pushDescriptorCount == PushDescriptors);

if (program.pushConstantStages)
vkCmdPushConstants(commandBuffer, program.layout, program.pushConstantStages, 0, sizeof(pushConstants), &pushConstants);

if (program.pushDescriptorCount)
vkCmdPushDescriptorSetWithTemplateKHR(commandBuffer, program.updateTemplate, program.layout, 0, pushDescriptors);

vkCmdDispatch(commandBuffer, getGroupCount(threadCountX, program.localSizeX), getGroupCount(threadCountY, program.localSizeY), 1);
}

struct MeshDrawCommand
{
uint32_t drawId;
Expand Down Expand Up @@ -1161,10 +1176,8 @@ int main(int argc, const char** argv)

DescriptorInfo pyramidDesc(depthSampler, depthPyramid.imageView, VK_IMAGE_LAYOUT_GENERAL);
DescriptorInfo descriptors[] = { db.buffer, mb.buffer, dcb.buffer, dccb.buffer, dvb.buffer, pyramidDesc };
vkCmdPushDescriptorSetWithTemplateKHR(commandBuffer, drawcullProgram.updateTemplate, drawcullProgram.layout, 0, descriptors);

vkCmdPushConstants(commandBuffer, drawcullProgram.layout, drawcullProgram.pushConstantStages, 0, sizeof(cullData), &passData);
vkCmdDispatch(commandBuffer, getGroupCount(uint32_t(draws.size()), drawcullProgram.localSizeX), 1, 1);
dispatch(commandBuffer, drawcullProgram, uint32_t(draws.size()), 1, passData, descriptors);
}

if (taskSubmit)
Expand Down Expand Up @@ -1372,15 +1385,13 @@ int main(int argc, const char** argv)
: DescriptorInfo(depthSampler, depthPyramidMips[i - 1], VK_IMAGE_LAYOUT_GENERAL);

DescriptorInfo descriptors[] = { { depthPyramidMips[i], VK_IMAGE_LAYOUT_GENERAL }, sourceDepth };
vkCmdPushDescriptorSetWithTemplateKHR(commandBuffer, depthreduceProgram.updateTemplate, depthreduceProgram.layout, 0, descriptors);

uint32_t levelWidth = std::max(1u, depthPyramidWidth >> i);
uint32_t levelHeight = std::max(1u, depthPyramidHeight >> i);

vec4 reduceData = vec4(levelWidth, levelHeight, 0, 0);

vkCmdPushConstants(commandBuffer, depthreduceProgram.layout, depthreduceProgram.pushConstantStages, 0, sizeof(reduceData), &reduceData);
vkCmdDispatch(commandBuffer, getGroupCount(levelWidth, depthreduceProgram.localSizeX), getGroupCount(levelHeight, depthreduceProgram.localSizeY), 1);
dispatch(commandBuffer, depthreduceProgram, levelWidth, levelHeight, reduceData, descriptors);

VkImageMemoryBarrier2 reduceBarrier = imageBarrier(depthPyramid.image,
VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_ACCESS_SHADER_WRITE_BIT, VK_IMAGE_LAYOUT_GENERAL,
Expand Down Expand Up @@ -1480,20 +1491,18 @@ int main(int argc, const char** argv)
{
vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, shadowQuality == 0 ? shadowlqPipeline : shadowhqPipeline);

DescriptorInfo descriptors[] = { { shadowTarget.imageView, VK_IMAGE_LAYOUT_GENERAL }, { readSampler, depthTarget.imageView, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL }, tlas, db.buffer, mb.buffer, mtb.buffer, vb.buffer, ib.buffer, textureSampler };
vkCmdPushDescriptorSetWithTemplateKHR(commandBuffer, shadowProgram.updateTemplate, shadowProgram.layout, 0, descriptors);

vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, shadowProgram.layout, 1, 1, &textureSet.second, 0, nullptr);

DescriptorInfo descriptors[] = { { shadowTarget.imageView, VK_IMAGE_LAYOUT_GENERAL }, { readSampler, depthTarget.imageView, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL }, tlas, db.buffer, mb.buffer, mtb.buffer, vb.buffer, ib.buffer, textureSampler };

ShadowData shadowData = {};
shadowData.sunDirection = sunDirection;
shadowData.sunJitter = shadowblurEnabled ? 1e-2f : 0;
shadowData.inverseViewProjection = inverse(projection * view);
shadowData.imageSize = vec2(float(swapchain.width), float(swapchain.height));
shadowData.checkerboard = shadowCheckerboard;

vkCmdPushConstants(commandBuffer, shadowProgram.layout, shadowProgram.pushConstantStages, 0, sizeof(shadowData), &shadowData);
vkCmdDispatch(commandBuffer, getGroupCount(shadowWidthCB, shadowProgram.localSizeX), getGroupCount(swapchain.height, shadowProgram.localSizeY), 1);
dispatch(commandBuffer, shadowProgram, shadowWidthCB, swapchain.height, shadowData, descriptors);
}

vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, queryPoolTimestamp, timestamp + 1);
Expand All @@ -1509,12 +1518,10 @@ int main(int argc, const char** argv)
vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, shadowfillPipeline);

DescriptorInfo descriptors[] = { { shadowTarget.imageView, VK_IMAGE_LAYOUT_GENERAL }, { readSampler, depthTarget.imageView, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL } };
vkCmdPushDescriptorSetWithTemplateKHR(commandBuffer, shadowfillProgram.updateTemplate, shadowfillProgram.layout, 0, descriptors);

vec4 fillData = vec4(float(swapchain.width), float(swapchain.height), 0, 0);

vkCmdPushConstants(commandBuffer, shadowProgram.layout, shadowProgram.pushConstantStages, 0, sizeof(fillData), &fillData);
vkCmdDispatch(commandBuffer, getGroupCount(shadowWidthCB, shadowProgram.localSizeX), getGroupCount(swapchain.height, shadowProgram.localSizeY), 1);
dispatch(commandBuffer, shadowProgram, shadowWidthCB, swapchain.height, fillData, descriptors);
}

for (int pass = 0; pass < (shadowblurEnabled ? 2 : 0); ++pass)
Expand All @@ -1536,12 +1543,10 @@ int main(int argc, const char** argv)
vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, shadowblurPipeline);

DescriptorInfo descriptors[] = { { blurTo.imageView, VK_IMAGE_LAYOUT_GENERAL }, { readSampler, blurFrom.imageView, VK_IMAGE_LAYOUT_GENERAL }, { readSampler, depthTarget.imageView, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL } };
vkCmdPushDescriptorSetWithTemplateKHR(commandBuffer, shadowblurProgram.updateTemplate, shadowblurProgram.layout, 0, descriptors);

vec4 blurData = vec4(float(swapchain.width), float(swapchain.height), pass == 0 ? 1 : 0, camera.znear);

vkCmdPushConstants(commandBuffer, shadowblurProgram.layout, shadowblurProgram.pushConstantStages, 0, sizeof(blurData), &blurData);
vkCmdDispatch(commandBuffer, getGroupCount(swapchain.width, shadowblurProgram.localSizeX), getGroupCount(swapchain.height, shadowblurProgram.localSizeY), 1);
dispatch(commandBuffer, shadowblurProgram, swapchain.width, swapchain.height, blurData, descriptors);
}

VkImageMemoryBarrier2 postblurBarrier =
Expand All @@ -1557,16 +1562,14 @@ int main(int argc, const char** argv)
vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, shadePipeline);

DescriptorInfo descriptors[] = { { swapchainImageViews[imageIndex], VK_IMAGE_LAYOUT_GENERAL }, { readSampler, gbufferTargets[0].imageView, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL }, { readSampler, gbufferTargets[1].imageView, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL }, { readSampler, depthTarget.imageView, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL }, { readSampler, shadowTarget.imageView, VK_IMAGE_LAYOUT_GENERAL } };
vkCmdPushDescriptorSetWithTemplateKHR(commandBuffer, shadeProgram.updateTemplate, shadeProgram.layout, 0, descriptors);

ShadeData shadeData = {};
shadeData.cameraPosition = camera.position;
shadeData.sunDirection = sunDirection;
shadeData.inverseViewProjection = inverse(projection * view);
shadeData.imageSize = vec2(float(swapchain.width), float(swapchain.height));

vkCmdPushConstants(commandBuffer, shadeProgram.layout, shadeProgram.pushConstantStages, 0, sizeof(shadeData), &shadeData);
vkCmdDispatch(commandBuffer, getGroupCount(swapchain.width, shadeProgram.localSizeX), getGroupCount(swapchain.height, shadeProgram.localSizeY), 1);
dispatch(commandBuffer, shadeProgram, swapchain.width, swapchain.height, shadeData, descriptors);
}

vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, queryPoolTimestamp, timestamp + 3);
Expand All @@ -1582,11 +1585,10 @@ int main(int argc, const char** argv)
vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, blitPipeline);

DescriptorInfo descriptors[] = { { swapchainImageViews[imageIndex], VK_IMAGE_LAYOUT_GENERAL }, { readSampler, gbufferTargets[0].imageView, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL } };
vkCmdPushDescriptorSetWithTemplateKHR(commandBuffer, blitProgram.updateTemplate, blitProgram.layout, 0, descriptors);

vec4 blitData = vec4(float(swapchain.width), float(swapchain.height), 0, 0);
vkCmdPushConstants(commandBuffer, blitProgram.layout, blitProgram.pushConstantStages, 0, sizeof(blitData), &blitData);
vkCmdDispatch(commandBuffer, getGroupCount(swapchain.width, blitProgram.localSizeX), getGroupCount(swapchain.height, blitProgram.localSizeY), 1);

dispatch(commandBuffer, blitProgram, swapchain.width, swapchain.height, blitData, descriptors);

vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, queryPoolTimestamp, timestamp + 3);
}
Expand Down
8 changes: 5 additions & 3 deletions src/shaders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <dirent.h>
#endif


#include <string>
#include <vector>

Expand Down Expand Up @@ -355,7 +354,7 @@ static VkPipelineLayout createPipelineLayout(VkDevice device, VkDescriptorSetLay
return layout;
}

static VkDescriptorUpdateTemplate createUpdateTemplate(VkDevice device, VkPipelineBindPoint bindPoint, VkPipelineLayout layout, Shaders shaders)
static VkDescriptorUpdateTemplate createUpdateTemplate(VkDevice device, VkPipelineBindPoint bindPoint, VkPipelineLayout layout, Shaders shaders, uint32_t* pushDescriptorCount)
{
std::vector<VkDescriptorUpdateTemplateEntry> entries;

Expand All @@ -376,6 +375,8 @@ static VkDescriptorUpdateTemplate createUpdateTemplate(VkDevice device, VkPipeli
entries.push_back(entry);
}

*pushDescriptorCount = uint32_t(entries.size());

VkDescriptorUpdateTemplateCreateInfo createInfo = { VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO };

createInfo.descriptorUpdateEntryCount = uint32_t(entries.size());
Expand Down Expand Up @@ -656,10 +657,11 @@ Program createProgram(VkDevice device, VkPipelineBindPoint bindPoint, Shaders sh
program.layout = createPipelineLayout(device, program.setLayout, arrayLayout, pushConstantStages, pushConstantSize);
assert(program.layout);

program.updateTemplate = createUpdateTemplate(device, bindPoint, program.layout, shaders);
program.updateTemplate = createUpdateTemplate(device, bindPoint, program.layout, shaders, &program.pushDescriptorCount);
assert(program.updateTemplate);

program.pushConstantStages = pushConstantStages;
program.pushConstantSize = uint32_t(pushConstantSize);

const Shader* shader = shaders.size() == 1 ? *shaders.begin() : nullptr;

Expand Down
3 changes: 3 additions & 0 deletions src/shaders.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ struct Program
VkPipelineLayout layout;
VkDescriptorSetLayout setLayout;
VkDescriptorUpdateTemplate updateTemplate;

VkShaderStageFlags pushConstantStages;
uint32_t pushConstantSize;
uint32_t pushDescriptorCount;

uint32_t localSizeX;
uint32_t localSizeY;
Expand Down

0 comments on commit 17d800b

Please sign in to comment.