More fft utils #795

Merged Jan 16, 2025 (45 commits)

Changes from 8 commits

Commits
7ca1e7d
Merge branch 'master' into more_fft_utils
Fletterio Nov 11, 2024
e2df09a
Merge branch 'master' into more_fft_utils
Fletterio Nov 12, 2024
dcc537e
More changes following Bloom PR review
Fletterio Nov 13, 2024
65bbad8
Adds ternary op for complex numbers
Fletterio Nov 19, 2024
16d3261
Merge master
Fletterio Nov 19, 2024
13dd52d
Restore submodule pointer
Fletterio Nov 19, 2024
b31705d
Merge branch 'master' into more_fft_utils
Fletterio Nov 19, 2024
e618e58
Yet more utils, such as bitreversal
Fletterio Nov 23, 2024
58d8929
Add functionality for nabla unpacking trades when doing FFT of packed…
Fletterio Dec 3, 2024
cd07b9b
Share complex types between cpp and hlsl, add mirror trade functional…
Fletterio Dec 3, 2024
ea31887
Also add fast mul by i,-i to cpp
Fletterio Dec 4, 2024
53541ff
Modify ternary operator in complex and add it as a functional struct …
Fletterio Dec 5, 2024
fdba8ce
Point at examples test master to merge Nabla master
Fletterio Dec 5, 2024
620e601
Merge branch 'master' into more_fft_utils
Fletterio Dec 5, 2024
4a16b5d
padDimensions and getOutputBufferSize rewritten so they can be shared…
Fletterio Dec 6, 2024
e61ab7a
Forgot what changed
Fletterio Dec 7, 2024
20b4e3a
adds findLSB and findMSB from std.450 to glsl_compat.hlsl
Fletterio Dec 7, 2024
9abf1de
Require concepts for Accessors for FFT
Fletterio Dec 10, 2024
f3ad5e8
Change submodule pointer so it's not changed by CMake
Fletterio Dec 10, 2024
975a7b7
Fixed accessor concepts for FFT
Fletterio Dec 11, 2024
2dc70c1
Merge branch 'master' into more_fft_utils
Fletterio Dec 11, 2024
83e0cbd
- Differentiate concepts for FFT based on ElementsPerInvocationLog2,
Fletterio Dec 11, 2024
1412b01
Renamed some parameters so they better convey intent
Fletterio Dec 12, 2024
3b72975
Comment change
Fletterio Dec 13, 2024
6a22158
Update examples submodule pointer
Fletterio Dec 13, 2024
dc10958
Merge branch 'master' into more_fft_utils
Fletterio Dec 13, 2024
f44c8d9
Merge branch 'master' into more_fft_utils
Fletterio Dec 17, 2024
e9a5b8e
SharedMemAccessor concept update
Fletterio Dec 18, 2024
a933953
Merge master
Fletterio Dec 18, 2024
3359e34
- Make most of intutil shared, deprecate the versions that were in the .h
Fletterio Dec 20, 2024
18931dd
Roll back the GLSL bitreverse change, it was fine after all
Fletterio Dec 20, 2024
d1666d4
Merge branch 'master' into more_fft_utils
Fletterio Dec 20, 2024
9f0713c
Merge branch 'bitreverse_intrinsic' into more_fft_utils
Fletterio Dec 20, 2024
47f018a
Merge branch 'bitreverse_intrinsic' into more_fft_utils
Fletterio Dec 20, 2024
0e6e31a
Change fft bitReverse name, update examples submodule pointer
Fletterio Dec 20, 2024
2eb0ffd
Merge master
Fletterio Jan 6, 2025
6401e53
Addressed PR review comments
Fletterio Jan 10, 2025
fdb7904
Move some HLSL stuff to CPP-shared
Fletterio Jan 13, 2025
4463278
Merge branch 'master' into more_fft_utils
Fletterio Jan 13, 2025
6b8714d
Moved readme over
Fletterio Jan 14, 2025
d0ed313
Seeing if this fixes Markdown issue in gh readme
Fletterio Jan 14, 2025
d4dc129
No line breaks in latex math for gh readmes
Fletterio Jan 14, 2025
6edab6d
Even worse, no two $ math mode in latex readme
Fletterio Jan 14, 2025
036c7dd
Going insane at GH readme not parsing this well
Fletterio Jan 14, 2025
e8f46dd
Fixed
Fletterio Jan 14, 2025
13 changes: 13 additions & 0 deletions include/nbl/builtin/hlsl/complex.hlsl
@@ -379,6 +379,19 @@ complex_t<Scalar> rotateRight(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
return retVal;
}

// Annoyed at having to write a lot of boilerplate to do a select
// Essentially returns what you'd expect from doing `condition ? a : b`
template<typename Scalar>
complex_t<Scalar> ternaryOperator(bool condition, NBL_CONST_REF_ARG(complex_t<Scalar>) a, NBL_CONST_REF_ARG(complex_t<Scalar>) b)
{
const vector<Scalar, 2> aVector = vector<Scalar, 2>(a.real(), a.imag());
const vector<Scalar, 2> bVector = vector<Scalar, 2>(b.real(), b.imag());
const vector<Scalar, 2> resultVector = condition ? aVector : bVector;
const complex_t<Scalar> result = { resultVector.x, resultVector.y };
return result;
}


}
}

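A quick usage sketch for the new helper (the wrapper function below is invented for illustration; ternaryOperator and conj come from complex.hlsl):

// Hypothetical caller - branchless pick between a twiddle factor and its conjugate,
// equivalent to `inverse ? conj(w) : w` without writing the componentwise select by hand
template<typename Scalar>
complex_t<Scalar> pickTwiddle(bool inverse, NBL_CONST_REF_ARG(complex_t<Scalar>) w)
{
    return ternaryOperator<Scalar>(inverse, conj(w), w);
}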
61 changes: 60 additions & 1 deletion include/nbl/builtin/hlsl/fft/common.hlsl
@@ -1,9 +1,47 @@
#ifndef _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_
#define _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_

#include "nbl/builtin/hlsl/complex.hlsl"
#include "nbl/builtin/hlsl/cpp_compat.hlsl"

#ifndef __HLSL_VERSION
#include <nbl/core/math/intutil.h>

namespace nbl
{
namespace hlsl
{
namespace fft
{

static inline uint32_t3 padDimensions(uint32_t3 dimensions, std::span<uint16_t> axes, bool realFFT = false)
{
uint16_t axisCount = 0;
for (auto i : axes)
{
dimensions[i] = core::roundUpToPoT(dimensions[i]);
if (realFFT && !axisCount++)
dimensions[i] /= 2;
}
return dimensions;
}

static inline uint64_t getOutputBufferSize(const uint32_t3& inputDimensions, uint32_t numChannels, std::span<uint16_t> axes, bool realFFT = false, bool halfFloats = false)

Reviewer:

you could have just replaced uint32_t3 with vector<uint32_t,M> and made it available both for C++ and HLSL

also why is span<uint16_t> used? you could just deduce the axis count from N

roundUpToPoT could move to hlsl namespace actually

Contributor Author:

roundUpToPoT is in core/math/intutil.h. I can create a builtin/math/intutil.hlsl, copy most functions over and refactor every usage of the functions in that file

Reviewer:
ask @Przemog1 about the preferred location so it fits in with the spirit of #801

You don't actually need to do a big refactor if you provide a "legacy" alias (make the old nbl::core::roundUpToPoT call nbl::hlsl::roundUpToPoT) and add the [[deprecated]] attribute
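A minimal sketch of that suggestion (assuming the shared implementation ends up as nbl::hlsl::roundUpToPoT; illustrative, not the PR's actual code):

namespace nbl::core
{
// Legacy alias - forwards to the new shared implementation and warns at compile time
template<typename T>
[[deprecated("use nbl::hlsl::roundUpToPoT instead")]]
inline T roundUpToPoT(T value) { return nbl::hlsl::roundUpToPoT<T>(value); }
}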

Contributor Author:
I'm creating builtin/hlsl/math/intutil.hlsl, @Przemog1 lmk if you'd rather have it named differently or placed somewhere else

{
auto paddedDims = padDimensions(inputDimensions, axes, realFFT);
uint64_t numberOfComplexElements = uint64_t(paddedDims[0]) * paddedDims[1] * paddedDims[2] * numChannels;
return 2 * numberOfComplexElements * (halfFloats ? sizeof(float16_t) : sizeof(float32_t));
}


}
}
}
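As a sanity check of the sizing math, a hypothetical host-side call (dimensions and channel count invented for the example):

// Illustrative only - FFT of a 4-channel 1280x720 image along X and Y:
uint16_t axes[] = { 0, 1 };
const uint32_t3 inputDims = { 1280, 720, 1 };
// padDimensions rounds 1280 -> 2048 and 720 -> 1024 (next powers of two),
// so the padded grid holds 2048 * 1024 * 1 * 4 = 8388608 complex elements;
// at 2 floats per complex and 4 bytes per float32_t that is 67108864 bytes (64 MiB)
const uint64_t byteSize = nbl::hlsl::fft::getOutputBufferSize(inputDims, 4, axes);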

#else

#include "nbl/builtin/hlsl/complex.hlsl"
#include "nbl/builtin/hlsl/numbers.hlsl"
#include "nbl/builtin/hlsl/concepts.hlsl"

namespace nbl
{
@@ -53,8 +91,29 @@ using DIT = DIX<true, Scalar>;

template<typename Scalar>
using DIF = DIX<false, Scalar>;

// ------------------------------------------------- Utils ---------------------------------------------------------
//
// Util to unpack two values from the FFT of the packed signal x + iy - outputs are returned in the same input arguments, storing FFT(x) in lo and FFT(y) in hi
template<typename Scalar>
void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
{
complex_t<Scalar> x = (lo + conj(hi)) * Scalar(0.5);
hi = rotateRight<Scalar>(lo - conj(hi)) * Scalar(0.5);
lo = x;
}
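For reference, the identity implemented here is the standard two-for-one real FFT trick; writing Z for the FFT of the packed sequence x + iy, and reading lo and hi as the mirrored pair Z[k] and Z[N-k] (with rotateRight acting as multiplication by -i, an interpretation inferred from the code):

// FFT(x)[k] = (Z[k] + conj(Z[N-k])) / 2
// FFT(y)[k] = (Z[k] - conj(Z[N-k])) / (2i) = rotateRight(Z[k] - conj(Z[N-k])) / 2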

// Bit-reverses T as a binary string of length given by Bits
template<typename T, uint16_t Bits NBL_FUNC_REQUIRES(is_integral_v<T> && Bits <= sizeof(T) * 8)
T bitReverse(T value)
{
return glsl::bitfieldReverse<uint32_t>(value) >> (sizeof(T) * 8 - Bits);
}
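A small illustration of the convention (values picked for the example):

// Reverses the low `Bits` bits and keeps the result right-aligned:
// 0001 reversed over 4 bits is 1000
const uint32_t r = nbl::hlsl::fft::bitReverse<uint32_t, 4>(0x1u); // r == 0x8u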

}
}
}

#endif

#endif
146 changes: 103 additions & 43 deletions include/nbl/builtin/hlsl/workgroup/fft.hlsl
@@ -1,13 +1,45 @@
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_
#define _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_

#include <nbl/builtin/hlsl/cpp_compat.hlsl>
#include <nbl/builtin/hlsl/fft/common.hlsl>

#ifndef __HLSL_VERSION
#include <nbl/video/IPhysicalDevice.h>

namespace nbl
{
namespace hlsl
{
namespace workgroup
{
namespace fft
{

inline std::pair<uint16_t, uint16_t> optimalFFTParameters(const video::ILogicalDevice* device, uint32_t inputArrayLength)
{
uint32_t maxWorkgroupSize = *device->getPhysicalDevice()->getLimits().maxWorkgroupSize;
// This is the logic found in core::roundUpToPoT to get the log2
uint16_t workgroupSizeLog2 = 1u + hlsl::findMSB(core::min(inputArrayLength / 2, maxWorkgroupSize) - 1u);
uint16_t elementPerInvocationLog2 = 1u + hlsl::findMSB(core::max((inputArrayLength >> workgroupSizeLog2) - 1u, 1u));
return { elementPerInvocationLog2, workgroupSizeLog2 };
}

}
}
}
}
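Worked through by hand for a hypothetical device (illustrative numbers; `device` is assumed to be a valid video::ILogicalDevice*):

// inputArrayLength = 4096, maxWorkgroupSize = 512:
//   workgroupSizeLog2        = 1 + findMSB(min(4096 / 2, 512) - 1) = 1 + findMSB(511) = 9  -> 512 invocations
//   elementPerInvocationLog2 = 1 + findMSB(max((4096 >> 9) - 1, 1)) = 1 + findMSB(7)  = 3  -> 8 elements each
// 512 invocations * 8 elements = 4096, exactly the FFT length
const auto [elementsPerInvocationLog2, workgroupSizeLog2] = optimalFFTParameters(device, 4096);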

#else

#include "nbl/builtin/hlsl/subgroup/fft.hlsl"
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
#include "nbl/builtin/hlsl/workgroup/shuffle.hlsl"
#include "nbl/builtin/hlsl/mpl.hlsl"
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
#include "nbl/builtin/hlsl/bit.hlsl"
#include "nbl/builtin/hlsl/concepts.hlsl"

// Caveats
// - Sin and Cos in HLSL take 32-bit floats. Using this library with 64-bit floats works perfectly fine, but DXC will emit warnings
@@ -90,20 +122,7 @@ namespace impl
}
} //namespace impl

// Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
template <typename scalar_t, uint16_t WorkgroupSize>
NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) * WorkgroupSize;

// Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi
template<typename Scalar>
void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
{
complex_t<Scalar> x = (lo + conj(hi)) * Scalar(0.5);
hi = rotateRight<Scalar>(lo - conj(hi)) * Scalar(0.5);
lo = x;
}

template<uint16_t ElementsPerInvocation, uint16_t WorkgroupSize>
template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>
struct FFTIndexingUtils
{
// This function maps the index `idx` in the output array of a Nabla FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = NablaFFT[idx]`
@@ -132,16 +151,36 @@ struct FFTIndexingUtils
return getNablaIndex(getDFTMirrorIndex(getDFTIndex(idx)));
}

NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = mpl::log2<ElementsPerInvocation>::value;
NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + mpl::log2<WorkgroupSize>::value;
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t(WorkgroupSize) * uint32_t(ElementsPerInvocation);
NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + WorkgroupSizeLog2;
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t(1) << FFTSizeLog2;
};
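Since the remaining method bodies are collapsed in this view, a brief caller-side sketch of how the visible composition is meant to be used (the template arguments are invented, and the method name is inferred from the return statement above):

// Hypothetical: 2^3 elements per invocation, 2^9 invocations -> 4096-point FFT
using indexing_t = fft::FFTIndexingUtils<3, 9>;
// NablaFFT[idx] holds DFT[getDFTIndex(idx)]; composing with getDFTMirrorIndex and
// mapping back through getNablaIndex finds the element holding the mirrored frequency:
const uint32_t mirrorIdx = indexing_t::getNablaMirrorIndex(idx);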

} //namespace fft

// ----------------------------------- End Utils -----------------------------------------------
// ----------------------------------- End Utils --------------------------------------------------------------

template<uint16_t ElementsPerInvocation, bool Inverse, uint16_t WorkgroupSize, typename Scalar, class device_capabilities=void>
namespace fft
{

template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename _Scalar NBL_PRIMARY_REQUIRES(_ElementsPerInvocationLog2 > 0 && _WorkgroupSizeLog2 >= 5)
struct ConstevalParameters
{
using scalar_t = _Scalar;

NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = _ElementsPerInvocationLog2;
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
NBL_CONSTEXPR_STATIC_INLINE uint32_t TotalSize = uint32_t(1) << (ElementsPerInvocationLog2 + WorkgroupSizeLog2);

NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = uint16_t(1) << ElementsPerInvocationLog2;
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(1) << WorkgroupSizeLog2;

// Required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) << WorkgroupSizeLog2;
};

} //namespace fft
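A sketch of how a shader might wire these up (the sizes, scalar type, and accessor setup are illustrative; the actual Accessor requirements are listed just below):

// 2^1 = 2 elements per invocation, 2^9 = 512 invocations -> a 1024-point FFT
using params_t = fft::ConstevalParameters<1, 9, float32_t>;
// 512 complex float32_t elements = 1024 DWORDs of shared memory
groupshared uint32_t sharedmem[params_t::SharedMemoryDWORDs];
// ... construct `accessor` and `sharedmemAccessor` satisfying the requirements below ...
FFT<false, params_t>::__call(accessor, sharedmemAccessor);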

template<bool Inverse, typename consteval_params_t, class device_capabilities=void>
struct FFT;

// For the FFT methods below, we assume:
@@ -161,9 +200,11 @@ struct FFT;
// * void workgroupExecutionAndMemoryBarrier();

// 2 items per invocation forward specialization
template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
template<uint16_t WorkgroupSizeLog2, typename Scalar, class device_capabilities>
struct FFT<false, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device_capabilities>
{
using consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;

template<typename SharedMemoryAdaptor>
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
{
@@ -177,6 +218,8 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
template<typename Accessor, typename SharedMemoryAccessor>
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
{
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;

// Compute the indices only once
const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
const uint32_t loIx = threadID;
@@ -222,12 +265,12 @@
}
};



// 2 items per invocation inverse specialization
template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
template<uint16_t WorkgroupSizeLog2, typename Scalar, class device_capabilities>
struct FFT<true, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device_capabilities>
{
using consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;

template<typename SharedMemoryAdaptor>
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
{
@@ -241,6 +284,8 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
template<typename Accessor, typename SharedMemoryAccessor>
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
{
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;

// Compute the indices only once
const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
const uint32_t loIx = threadID;
@@ -291,17 +336,23 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
};

// Forward FFT
template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2, typename Scalar, class device_capabilities>
struct FFT<false, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>, device_capabilities>
{
using consteval_params_t = fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>;
using small_fft_consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;

template<typename Accessor, typename SharedMemoryAccessor>
static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
{
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = consteval_params_t::ElementsPerInvocation;

[unroll]
for (uint32_t stride = (K / 2) * WorkgroupSize; stride > WorkgroupSize; stride >>= 1)
for (uint32_t stride = (ElementsPerInvocation / 2) * WorkgroupSize; stride > WorkgroupSize; stride >>= 1)
{
[unroll]
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (ElementsPerInvocation / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
{
const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
const uint32_t hiIx = loIx | stride;
@@ -318,47 +369,53 @@ struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
}

// do K/2 small workgroup FFTs
// do ElementsPerInvocation/2 small workgroup FFTs
accessor_adaptors::Offset<Accessor> offsetAccessor;
offsetAccessor.accessor = accessor;
[unroll]
for (uint32_t k = 0; k < K; k += 2)
for (uint32_t k = 0; k < ElementsPerInvocation; k += 2)
{
if (k)
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
offsetAccessor.offset = WorkgroupSize*k;
FFT<2,false, WorkgroupSize, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
FFT<false, small_fft_consteval_params_t, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
}
accessor = offsetAccessor.accessor;
}
};

// Inverse FFT
template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2, typename Scalar, class device_capabilities>
struct FFT<true, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>, device_capabilities>
{
using consteval_params_t = fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>;
using small_fft_consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;

template<typename Accessor, typename SharedMemoryAccessor>
static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
{
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = consteval_params_t::ElementsPerInvocation;

// do ElementsPerInvocation/2 small workgroup FFTs
accessor_adaptors::Offset<Accessor> offsetAccessor;
offsetAccessor.accessor = accessor;
[unroll]
for (uint32_t k = 0; k < K; k += 2)
for (uint32_t k = 0; k < ElementsPerInvocation; k += 2)
{
if (k)
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
offsetAccessor.offset = WorkgroupSize*k;
FFT<2,true, WorkgroupSize, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
FFT<true, small_fft_consteval_params_t, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
}
accessor = offsetAccessor.accessor;

[unroll]
for (uint32_t stride = 2 * WorkgroupSize; stride < K * WorkgroupSize; stride <<= 1)
for (uint32_t stride = 2 * WorkgroupSize; stride < ElementsPerInvocation * WorkgroupSize; stride <<= 1)
{
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
[unroll]
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (ElementsPerInvocation / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
{
const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
const uint32_t hiIx = loIx | stride;
@@ -370,11 +427,11 @@ struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride), lo,hi);

// Divide by special factor at the end
if ( (K / 2) * WorkgroupSize == stride)
if ( (ElementsPerInvocation / 2) * WorkgroupSize == stride)
{
divides_assign< complex_t<Scalar> > divAss;
divAss(lo, K / 2);
divAss(hi, K / 2);
divAss(lo, ElementsPerInvocation / 2);
divAss(hi, ElementsPerInvocation / 2);
}

accessor.set(loIx, lo);
@@ -390,4 +447,7 @@
}
}


#endif

#endif