Skip to content

Commit

Permalink
Merge remote-tracking branch 'remotes/origin/more_fft_utils'
Browse files Browse the repository at this point in the history
  • Loading branch information
devsh committed Jan 16, 2025
2 parents d7a9e13 + e8f46dd commit da80234
Show file tree
Hide file tree
Showing 16 changed files with 956 additions and 173 deletions.
47 changes: 47 additions & 0 deletions include/nbl/builtin/hlsl/bitreverse.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#ifndef _NBL_BUILTIN_HLSL_BITREVERSE_INCLUDED_
#define _NBL_BUILTIN_HLSL_BITREVERSE_INCLUDED_


#include <nbl/builtin/hlsl/cpp_compat.hlsl>

namespace nbl
{
namespace hlsl
{

template<typename T, uint16_t Bits NBL_FUNC_REQUIRES(is_unsigned_v<T>&& Bits <= sizeof(T) * 8)
/**
* @brief Takes the binary representation of `value` as a string of `Bits` bits and returns a value of the same type resulting from reversing the string
*
* @tparam T Type of the value to operate on.
* @tparam Bits The length of the string of bits used to represent `value`.
*
* @param [in] value The value to bitreverse.
*/
T bitReverseAs(T value)
{
return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof(T) * 8 - Bits));
}

template<typename T NBL_FUNC_REQUIRES(is_unsigned_v<T>)
/**
* @brief Takes the binary representation of `value` and returns a value of the same type resulting from reversing the string of bits as if it was `bits` long.
* Keep in mind `bits` cannot exceed `8 * sizeof(T)`.
*
* @tparam T type of the value to operate on.
*
* @param [in] value The value to bitreverse.
* @param [in] bits The length of the string of bits used to represent `value`.
*/
T bitReverseAs(T value, uint16_t bits)
{
return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof(T) * 8 - bits));
}


}
}



#endif
70 changes: 68 additions & 2 deletions include/nbl/builtin/hlsl/complex.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,47 @@
#ifndef _NBL_BUILTIN_HLSL_COMPLEX_INCLUDED_
#define _NBL_BUILTIN_HLSL_COMPLEX_INCLUDED_

#include "nbl/builtin/hlsl/functional.hlsl"
#include "nbl/builtin/hlsl/cpp_compat/promote.hlsl"
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
#include <nbl/builtin/hlsl/functional.hlsl>

using namespace nbl::hlsl;

// -------------------------------------- CPP VERSION ------------------------------------
#ifndef __HLSL_VERSION

#include <complex>

namespace nbl
{
namespace hlsl
{

template<typename Scalar>
using complex_t = std::complex<Scalar>;

// Fast mul by i
template<typename Scalar>
complex_t<Scalar> rotateLeft(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
{
complex_t<Scalar> retVal = { -value.imag(), value.real() };
return retVal;
}

// Fast mul by -i
template<typename Scalar>
complex_t<Scalar> rotateRight(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
{
complex_t<Scalar> retVal = { value.imag(), -value.real() };
return retVal;
}

}
}

// -------------------------------------- END CPP VERSION ------------------------------------

// -------------------------------------- HLSL VERSION ---------------------------------------
#else

namespace nbl
{
Expand Down Expand Up @@ -126,6 +165,8 @@ struct complex_t
template<typename Scalar>
struct plus< complex_t<Scalar> >
{
using type_t = complex_t<Scalar>;

complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
{
return lhs + rhs;
Expand All @@ -137,6 +178,8 @@ struct plus< complex_t<Scalar> >
template<typename Scalar>
struct minus< complex_t<Scalar> >
{
using type_t = complex_t<Scalar>;

complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
{
return lhs - rhs;
Expand All @@ -148,6 +191,8 @@ struct minus< complex_t<Scalar> >
template<typename Scalar>
struct multiplies< complex_t<Scalar> >
{
using type_t = complex_t<Scalar>;

complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
{
return lhs * rhs;
Expand All @@ -164,6 +209,8 @@ struct multiplies< complex_t<Scalar> >
template<typename Scalar>
struct divides< complex_t<Scalar> >
{
using type_t = complex_t<Scalar>;

complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
{
return lhs / rhs;
Expand Down Expand Up @@ -379,6 +426,22 @@ complex_t<Scalar> rotateRight(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
return retVal;
}

template<typename Scalar>
struct ternary_operator< complex_t<Scalar> >
{
using type_t = complex_t<Scalar>;

complex_t<Scalar> operator()(bool condition, NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
{
const vector<Scalar, 2> lhsVector = vector<Scalar, 2>(lhs.real(), lhs.imag());
const vector<Scalar, 2> rhsVector = vector<Scalar, 2>(rhs.real(), rhs.imag());
const vector<Scalar, 2> resultVector = condition ? lhsVector : rhsVector;
const complex_t<Scalar> result = { resultVector.x, resultVector.y };
return result;
}
};


}
}

Expand All @@ -396,4 +459,7 @@ NBL_REGISTER_OBJ_TYPE(complex_t<float64_t2>,::nbl::hlsl::alignment_of_v<float64_
NBL_REGISTER_OBJ_TYPE(complex_t<float64_t3>,::nbl::hlsl::alignment_of_v<float64_t3>)
NBL_REGISTER_OBJ_TYPE(complex_t<float64_t4>,::nbl::hlsl::alignment_of_v<float64_t4>)

// -------------------------------------- END HLSL VERSION ---------------------------------------
#endif

#endif
87 changes: 87 additions & 0 deletions include/nbl/builtin/hlsl/concepts/accessors/fft.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_FFT_INCLUDED_
#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_FFT_INCLUDED_

#include "nbl/builtin/hlsl/concepts.hlsl"
#include "nbl/builtin/hlsl/fft/common.hlsl"

namespace nbl
{
namespace hlsl
{
namespace workgroup
{
namespace fft
{
// The SharedMemoryAccessor MUST provide the following methods:
// * void get(uint32_t index, inout uint32_t value);
// * void set(uint32_t index, in uint32_t value);
// * void workgroupExecutionAndMemoryBarrier();

#define NBL_CONCEPT_NAME FFTSharedMemoryAccessor
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
#define NBL_CONCEPT_PARAM_0 (accessor, T)
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
#define NBL_CONCEPT_PARAM_2 (val, uint32_t)
NBL_CONCEPT_BEGIN(3)
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
NBL_CONCEPT_END(
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<uint32_t, uint32_t>(index, val)), is_same_v, void))
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<uint32_t, uint32_t>(index, val)), is_same_v, void))
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void))
);
#undef val
#undef index
#undef accessor
#include <nbl/builtin/hlsl/concepts/__end.hlsl>


// The Accessor (for a small FFT) MUST provide the following methods:
// * void get(uint32_t index, inout complex_t<Scalar> value);
// * void set(uint32_t index, in complex_t<Scalar> value);

#define NBL_CONCEPT_NAME SmallFFTAccessor
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
#define NBL_CONCEPT_PARAM_0 (accessor, T)
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
#define NBL_CONCEPT_PARAM_2 (val, complex_t<Scalar>)
NBL_CONCEPT_BEGIN(3)
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
NBL_CONCEPT_END(
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index, val)), is_same_v, void))
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index, val)), is_same_v, void))
);
#undef val
#undef index
#undef accessor
#include <nbl/builtin/hlsl/concepts/__end.hlsl>


// The Accessor MUST provide the following methods:
// * void get(uint32_t index, inout complex_t<Scalar> value);
// * void set(uint32_t index, in complex_t<Scalar> value);
// * void memoryBarrier();

#define NBL_CONCEPT_NAME FFTAccessor
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
#define NBL_CONCEPT_PARAM_0 (accessor, T)
NBL_CONCEPT_BEGIN(1)
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
NBL_CONCEPT_END(
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.memoryBarrier()), is_same_v, void))
) && SmallFFTAccessor<T, Scalar>;
#undef accessor
#include <nbl/builtin/hlsl/concepts/__end.hlsl>

}
}
}
}

#endif
2 changes: 2 additions & 0 deletions include/nbl/builtin/hlsl/cpp_compat/basic.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ inline To _static_cast(From v)
#define NBL_CONSTEXPR_STATIC constexpr static
#define NBL_CONSTEXPR_STATIC_INLINE constexpr static inline
#define NBL_CONSTEXPR_INLINE_FUNC constexpr inline
#define NBL_CONSTEXPR_FORCED_INLINE_FUNC NBL_FORCE_INLINE constexpr
#define NBL_CONST_MEMBER_FUNC const

namespace nbl::hlsl
Expand Down Expand Up @@ -70,6 +71,7 @@ namespace nbl::hlsl
#define NBL_CONSTEXPR_STATIC const static
#define NBL_CONSTEXPR_STATIC_INLINE const static
#define NBL_CONSTEXPR_INLINE_FUNC inline
#define NBL_CONSTEXPR_FORCED_INLINE_FUNC inline
#define NBL_CONST_MEMBER_FUNC

namespace nbl
Expand Down
Loading

0 comments on commit da80234

Please sign in to comment.