Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More fft utils #795

Merged
merged 45 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
7ca1e7d
Merge branch 'master' into more_fft_utils
Fletterio Nov 11, 2024
e2df09a
Merge branch 'master' into more_fft_utils
Fletterio Nov 12, 2024
dcc537e
More changes following Bloom PR review
Fletterio Nov 13, 2024
65bbad8
Adds ternary op for complex numbers
Fletterio Nov 19, 2024
16d3261
Merge master
Fletterio Nov 19, 2024
13dd52d
Restore submodule pointer
Fletterio Nov 19, 2024
b31705d
Merge branch 'master' into more_fft_utils
Fletterio Nov 19, 2024
e618e58
Yet more utils, such as bitreversal
Fletterio Nov 23, 2024
58d8929
Add functionality for nabla unpacking trades when doing FFT of packed…
Fletterio Dec 3, 2024
cd07b9b
Share complex types between cpp and hlsl, add mirror trade functional…
Fletterio Dec 3, 2024
ea31887
Also add fast mul by i,-i to cpp
Fletterio Dec 4, 2024
53541ff
Modify ternary operator in complex and add it as a functional struct …
Fletterio Dec 5, 2024
fdba8ce
Point at examples test master to merge Nabla master
Fletterio Dec 5, 2024
620e601
Merge branch 'master' into more_fft_utils
Fletterio Dec 5, 2024
4a16b5d
padDimensions and getOutputBufferSize rewritten so they can be shared…
Fletterio Dec 6, 2024
e61ab7a
Forgot what changed
Fletterio Dec 7, 2024
20b4e3a
adds findLSB and findMSB from std.450 to glsl_compat.hlsl
Fletterio Dec 7, 2024
9abf1de
Require concepts for Accessors for FFT
Fletterio Dec 10, 2024
f3ad5e8
Change submodule pointer so it's not changed by CMake
Fletterio Dec 10, 2024
975a7b7
Fixed accessor concepts for FFT
Fletterio Dec 11, 2024
2dc70c1
Merge branch 'master' into more_fft_utils
Fletterio Dec 11, 2024
83e0cbd
- Differentiate concepts for FFT based on ElementsPerInvocationLog2,
Fletterio Dec 11, 2024
1412b01
Renamed some parameters so they better convey intent
Fletterio Dec 12, 2024
3b72975
Comment change
Fletterio Dec 13, 2024
6a22158
Update examples submodule pointer
Fletterio Dec 13, 2024
dc10958
Merge branch 'master' into more_fft_utils
Fletterio Dec 13, 2024
f44c8d9
Merge branch 'master' into more_fft_utils
Fletterio Dec 17, 2024
e9a5b8e
SharedMemAccessor concept update
Fletterio Dec 18, 2024
a933953
Merge master
Fletterio Dec 18, 2024
3359e34
- Make most of intutil shared, deprecate the versions that were in the.h
Fletterio Dec 20, 2024
18931dd
Roll back the GLSL bitreverse change, it was fine after all
Fletterio Dec 20, 2024
d1666d4
Merge branch 'master' into more_fft_utils
Fletterio Dec 20, 2024
9f0713c
Merge branch 'bitreverse_intrinsic' into more_fft_utils
Fletterio Dec 20, 2024
47f018a
Merge branch 'bitreverse_intrinsic' into more_fft_utils
Fletterio Dec 20, 2024
0e6e31a
Change fft bitReverse name, update examples pointer submodule
Fletterio Dec 20, 2024
2eb0ffd
Merge master
Fletterio Jan 6, 2025
6401e53
Addressed PR review comments
Fletterio Jan 10, 2025
fdb7904
Move some HLSL stuff to CPP-shared
Fletterio Jan 13, 2025
4463278
Merge branch 'master' into more_fft_utils
Fletterio Jan 13, 2025
6b8714d
Moved readme over
Fletterio Jan 14, 2025
d0ed313
Seeing iof this fixes Markdown issue in gh readme
Fletterio Jan 14, 2025
d4dc129
No line break sin latex math for gh readmes
Fletterio Jan 14, 2025
6edab6d
Even worse, no two $ math mode in latex readme
Fletterio Jan 14, 2025
036c7dd
Going insane at GH readme not parding this well
Fletterio Jan 14, 2025
e8f46dd
Fixed
Fletterio Jan 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions include/nbl/builtin/hlsl/complex.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,47 @@
#ifndef _NBL_BUILTIN_HLSL_COMPLEX_INCLUDED_
#define _NBL_BUILTIN_HLSL_COMPLEX_INCLUDED_

#include "nbl/builtin/hlsl/functional.hlsl"
#include "nbl/builtin/hlsl/cpp_compat/promote.hlsl"
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
#include <nbl/builtin/hlsl/functional.hlsl>

using namespace nbl::hlsl;

// -------------------------------------- CPP VERSION ------------------------------------
#ifndef __HLSL_VERSION

#include <complex>

namespace nbl
{
namespace hlsl
{

template<typename Scalar>
using complex_t = std::complex<Scalar>;

// Fast mul by i
template<typename Scalar>
complex_t<Scalar> rotateLeft(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
{
complex_t<Scalar> retVal = { -value.imag(), value.real() };
return retVal;
}

// Fast mul by -i
template<typename Scalar>
complex_t<Scalar> rotateRight(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
{
complex_t<Scalar> retVal = { value.imag(), -value.real() };
return retVal;
}

}
}

// -------------------------------------- END CPP VERSION ------------------------------------

// -------------------------------------- HLSL VERSION ---------------------------------------
#else

namespace nbl
{
Expand Down Expand Up @@ -126,6 +165,8 @@ struct complex_t
template<typename Scalar>
struct plus< complex_t<Scalar> >
{
using type_t = complex_t<Scalar>;

complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
{
return lhs + rhs;
Expand All @@ -137,6 +178,8 @@ struct plus< complex_t<Scalar> >
template<typename Scalar>
struct minus< complex_t<Scalar> >
{
using type_t = complex_t<Scalar>;

complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
{
return lhs - rhs;
Expand All @@ -148,6 +191,8 @@ struct minus< complex_t<Scalar> >
template<typename Scalar>
struct multiplies< complex_t<Scalar> >
{
using type_t = complex_t<Scalar>;

complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
{
return lhs * rhs;
Expand All @@ -164,6 +209,8 @@ struct multiplies< complex_t<Scalar> >
template<typename Scalar>
struct divides< complex_t<Scalar> >
{
using type_t = complex_t<Scalar>;

complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
{
return lhs / rhs;
Expand Down Expand Up @@ -379,6 +426,22 @@ complex_t<Scalar> rotateRight(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
return retVal;
}

template<typename Scalar>
struct ternary_operator< complex_t<Scalar> >
{
using type_t = complex_t<Scalar>;

complex_t<Scalar> operator()(bool condition, NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
{
const vector<Scalar, 2> lhsVector = vector<Scalar, 2>(lhs.real(), lhs.imag());
const vector<Scalar, 2> rhsVector = vector<Scalar, 2>(rhs.real(), rhs.imag());
const vector<Scalar, 2> resultVector = condition ? lhsVector : rhsVector;
const complex_t<Scalar> result = { resultVector.x, resultVector.y };
return result;
}
};


}
}

Expand All @@ -396,4 +459,7 @@ NBL_REGISTER_OBJ_TYPE(complex_t<float64_t2>,::nbl::hlsl::alignment_of_v<float64_
NBL_REGISTER_OBJ_TYPE(complex_t<float64_t3>,::nbl::hlsl::alignment_of_v<float64_t3>)
NBL_REGISTER_OBJ_TYPE(complex_t<float64_t4>,::nbl::hlsl::alignment_of_v<float64_t4>)

// -------------------------------------- END HLSL VERSION ---------------------------------------
#endif

#endif
87 changes: 87 additions & 0 deletions include/nbl/builtin/hlsl/concepts/accessors/fft.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_FFT_INCLUDED_
#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_FFT_INCLUDED_

#include "nbl/builtin/hlsl/concepts.hlsl"
#include "nbl/builtin/hlsl/fft/common.hlsl"

namespace nbl
{
namespace hlsl
{
namespace workgroup
{
namespace fft
{
// The SharedMemoryAccessor MUST provide the following methods:
// * void get(uint32_t index, inout uint32_t value);
// * void set(uint32_t index, in uint32_t value);
// * void workgroupExecutionAndMemoryBarrier();

#define NBL_CONCEPT_NAME FFTSharedMemoryAccessor
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
#define NBL_CONCEPT_PARAM_0 (accessor, T)
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
#define NBL_CONCEPT_PARAM_2 (val, uint32_t)
NBL_CONCEPT_BEGIN(3)
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
NBL_CONCEPT_END(
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<uint32_t, uint32_t>(index, val)), is_same_v, void))
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<uint32_t, uint32_t>(index, val)), is_same_v, void))
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void))
);
#undef val
#undef index
#undef accessor
#include <nbl/builtin/hlsl/concepts/__end.hlsl>


// The Accessor (for a small FFT) MUST provide the following methods:
// * void get(uint32_t index, inout complex_t<Scalar> value);
// * void set(uint32_t index, in complex_t<Scalar> value);

#define NBL_CONCEPT_NAME SmallFFTAccessor
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
#define NBL_CONCEPT_PARAM_0 (accessor, T)
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
#define NBL_CONCEPT_PARAM_2 (val, complex_t<Scalar>)
NBL_CONCEPT_BEGIN(3)
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
NBL_CONCEPT_END(
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index, val)), is_same_v, void))
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index, val)), is_same_v, void))
);
#undef val
#undef index
#undef accessor
#include <nbl/builtin/hlsl/concepts/__end.hlsl>


// The Accessor MUST provide the following methods:
// * void get(uint32_t index, inout complex_t<Scalar> value);
// * void set(uint32_t index, in complex_t<Scalar> value);
// * void memoryBarrier();

#define NBL_CONCEPT_NAME FFTAccessor
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
#define NBL_CONCEPT_PARAM_0 (accessor, T)
NBL_CONCEPT_BEGIN(1)
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
NBL_CONCEPT_END(
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.memoryBarrier()), is_same_v, void))
) && SmallFFTAccessor<T, Scalar>;
#undef accessor
#include <nbl/builtin/hlsl/concepts/__end.hlsl>

}
}
}
}

#endif
2 changes: 2 additions & 0 deletions include/nbl/builtin/hlsl/cpp_compat/basic.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ inline To _static_cast(From v)
#define NBL_CONSTEXPR_STATIC constexpr static
#define NBL_CONSTEXPR_STATIC_INLINE constexpr static inline
#define NBL_CONSTEXPR_INLINE_FUNC constexpr inline
#define NBL_CONSTEXPR_FORCED_INLINE_FUNC NBL_FORCE_INLINE constexpr
#define NBL_CONST_MEMBER_FUNC const

namespace nbl::hlsl
Expand Down Expand Up @@ -70,6 +71,7 @@ namespace nbl::hlsl
#define NBL_CONSTEXPR_STATIC const static
#define NBL_CONSTEXPR_STATIC_INLINE const static
#define NBL_CONSTEXPR_INLINE_FUNC inline
#define NBL_CONSTEXPR_FORCED_INLINE_FUNC inline
#define NBL_CONST_MEMBER_FUNC

namespace nbl
Expand Down
83 changes: 66 additions & 17 deletions include/nbl/builtin/hlsl/fft/common.hlsl
Original file line number Diff line number Diff line change
@@ -1,58 +1,107 @@
#ifndef _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_
#define _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_

#include "nbl/builtin/hlsl/complex.hlsl"
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
#include "nbl/builtin/hlsl/numbers.hlsl"
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
#include <nbl/builtin/hlsl/complex.hlsl>
#include <nbl/builtin/hlsl/concepts.hlsl>
#include <nbl/builtin/hlsl/math/intutil.hlsl>
#include <nbl/builtin/hlsl/numbers.hlsl>

namespace nbl
namespace nbl
{
namespace hlsl
{
namespace fft
namespace fft
{

// template parameter N controls the number of dimensions of the input
// template parameter M controls the number of dimensions to pad up to PoT
// "axes" indicates which dimensions to pad up to PoT
template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES(M <= N)
inline vector<uint64_t, 3> padDimensions(NBL_CONST_REF_ARG(vector<uint32_t, N>) dimensions, NBL_CONST_REF_ARG(vector<uint16_t, M>) axes, bool realFFT = false)
{
vector<uint32_t, N> newDimensions = dimensions;
uint16_t axisCount = 0;
for (uint16_t i = 0u; i < M; i++)
{
newDimensions[i] = hlsl::roundUpToPoT(newDimensions[i]);
if (realFFT && !axisCount++)
newDimensions[i] /= 2;
}
return newDimensions;
}

// template parameter N controls the number of dimensions of the input
// template parameter M controls the number of dimensions we run an FFT along AND store the result
// "axes" indicates which dimensions we run an FFT along AND store the result
template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES(M <= N)
inline uint64_t getOutputBufferSize(NBL_CONST_REF_ARG(vector<uint32_t, N>) inputDimensions, uint32_t numChannels, NBL_CONST_REF_ARG(vector<uint16_t, M>) axes, bool realFFT = false, bool halfFloats = false)
{
const vector<uint64_t, 3> paddedDims = padDimensions<N, M>(inputDimensions, axes);
const uint64_t numberOfComplexElements = paddedDims[0] * paddedDims[1] * paddedDims[2] * uint64_t(numChannels);
return numberOfComplexElements * (halfFloats ? sizeof(complex_t<float16_t>) : sizeof(complex_t<float32_t>));
}

// Computes the kth element in the group of N roots of unity
// Notice 0 <= k < N/2, rotating counterclockwise in the forward (DIF) transform and clockwise in the inverse (DIT)
template<bool inverse, typename Scalar>
complex_t<Scalar> twiddle(uint32_t k, uint32_t halfN)
{
complex_t<Scalar> retVal;
const Scalar kthRootAngleRadians = numbers::pi<Scalar> * Scalar(k) / Scalar(halfN);
retVal.real( cos(kthRootAngleRadians) );
if (! inverse)
retVal.imag( sin(-kthRootAngleRadians) );
const Scalar kthRootAngleRadians = numbers::pi<Scalar> *Scalar(k) / Scalar(halfN);
retVal.real(cos(kthRootAngleRadians));
if (!inverse)
retVal.imag(sin(-kthRootAngleRadians));
else
retVal.imag( sin(kthRootAngleRadians) );
return retVal;
retVal.imag(sin(kthRootAngleRadians));
return retVal;
}

template<bool inverse, typename Scalar>
struct DIX
{
template<bool inverse, typename Scalar>
struct DIX
{
static void radix2(NBL_CONST_REF_ARG(complex_t<Scalar>) twiddle, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
{
plus_assign< complex_t<Scalar> > plusAss;
//Decimation in time - inverse
if (inverse) {
complex_t<Scalar> wHi = twiddle * hi;
hi = lo - wHi;
plusAss(lo, wHi);
plusAss(lo, wHi);
}
//Decimation in frequency - forward
else {
complex_t<Scalar> diff = lo - hi;
plusAss(lo, hi);
hi = twiddle * diff;
hi = twiddle * diff;
}
}
}
};

template<typename Scalar>
using DIT = DIX<true, Scalar>;

template<typename Scalar>
using DIF = DIX<false, Scalar>;

// ------------------------------------------------- Utils ---------------------------------------------------------
//
// Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi
template<typename Scalar>
void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
{
complex_t<Scalar> x = (lo + conj(hi)) * Scalar(0.5);
hi = rotateRight<Scalar>(lo - conj(hi)) * Scalar(0.5);
lo = x;
}

// Bit-reverses T as a binary string of length given by Bits
template<typename T, uint16_t Bits NBL_FUNC_REQUIRES(is_integral_v<T> && Bits <= sizeof(T) * 8)
T bitReverseAs(T value)
{
return hlsl::bitReverse<uint32_t>(value) >> (sizeof(T) * 8 - Bits);
}

}
}
}
Expand Down
13 changes: 12 additions & 1 deletion include/nbl/builtin/hlsl/functional.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ COMPOUND_ASSIGN(divides)

// ----------------- End of compound assignment ops ----------------

// Min and Max don't use ALIAS_STD because they don't exist in STD
// Min, Max and Ternary Operator don't use ALIAS_STD because they don't exist in STD
// TODO: implement as mix(rhs<lhs,lhs,rhs) (SPIR-V intrinsic from the extended set & glm on C++)
template<typename T>
struct minimum
Expand Down Expand Up @@ -195,6 +195,17 @@ struct maximum
NBL_CONSTEXPR_STATIC_INLINE T identity = numeric_limits<scalar_t>::lowest; // TODO: `all_components<T>`
};

template<typename T>
struct ternary_operator
{
using type_t = T;

T operator()(bool condition, NBL_CONST_REF_ARG(T) lhs, NBL_CONST_REF_ARG(T) rhs)
{
return condition ? lhs : rhs;
}
};

}
}

Expand Down
Loading