Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cpp compat intrinsics refactor #801

Merged
merged 21 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 64 additions & 25 deletions include/nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,19 @@ namespace hlsl
{

template<typename Integer>
int bitCount(NBL_CONST_REF_ARG(Integer) val)
inline int bitCount(NBL_CONST_REF_ARG(Integer) val)
{
#ifdef __HLSL_VERSION
if (sizeof(Integer) == 8u)
{
uint32_t lowBits = val;
uint32_t highBits = val >> 32u;

return countbits(lowBits) + countbits(highBits);
}

return countbits(val);

#else
return glm::bitCount(val);
#endif
Expand All @@ -49,7 +58,7 @@ T clamp(NBL_CONST_REF_ARG(T) val, NBL_CONST_REF_ARG(T) min, NBL_CONST_REF_ARG(T)
#endif
}

namespace dot_product_impl
namespace cpp_compat_intrinsics_impl
{
template<typename T>
struct dot_helper
Expand Down Expand Up @@ -100,18 +109,19 @@ DEFINE_BUILTIN_VECTOR_SPECIALIZATION(float64_t, BUILTIN_VECTOR_SPECIALIZATION_RE
template<typename T>
typename vector_traits<T>::scalar_type dot(NBL_CONST_REF_ARG(T) lhs, NBL_CONST_REF_ARG(T) rhs)
{
return dot_product_impl::dot_helper<T>::dot(lhs, rhs);
return cpp_compat_intrinsics_impl::dot_helper<T>::dot(lhs, rhs);
}

// TODO: for clearer error messages, use concepts to ensure that input type is a square matrix
// determinant is not defined because it is implemented via a hidden friend
// https://stackoverflow.com/questions/67459950/why-is-a-friend-function-not-treated-as-a-member-of-a-namespace-of-a-class-it-wa
// Determinant of a matrix. Under __HLSL_VERSION this calls spirv::determinant;
// otherwise it reinterprets the matrix as its `Base` type and calls glm::determinant.
// NOTE(review): two template/signature lines and two glm `return` lines appear back to
// back below (an N x M variant and an N x N variant). This looks like the before/after
// lines of a diff rather than a single definition - confirm which variant is current.
template<typename T, uint16_t N, uint16_t M>
inline T determinant(NBL_CONST_REF_ARG(matrix<T, N, M>) m)
template<typename T, uint16_t N>
inline T determinant(NBL_CONST_REF_ARG(matrix<T, N, N>) m)
{
#ifdef __HLSL_VERSION

// NOTE(review): the value of spirv::determinant(m) is discarded here; this path has no
// `return` statement even though the function returns T - confirm and add `return`.
spirv::determinant(m);
#else
return glm::determinant(reinterpret_cast<typename matrix<T, N, M>::Base const&>(m));
return glm::determinant(reinterpret_cast<typename matrix<T, N, N>::Base const&>(m));
#endif
}

Expand Down Expand Up @@ -169,7 +179,7 @@ int findMSB(NBL_CONST_REF_ARG(Integer) val)
}

// TODO: some of the functions in this header should move to `tgmath`
template<typename T> //requires ::nbl::hlsl::is_floating_point_v<T>
template<typename T>
inline T floor(NBL_CONST_REF_ARG(T) val)
{
#ifdef __HLSL_VERSION
Expand All @@ -191,28 +201,52 @@ inline matrix<T, N, M> inverse(NBL_CONST_REF_ARG(matrix<T, N, M>) m)
#endif
}

namespace cpp_compat_intrinsics_impl
{

// TODO: concept requiring T to be a float
template<typename T, typename U>
inline T lerp(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(U) a)
struct lerp_helper
{
static inline T lerp(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(U) a)
{
#ifdef __HLSL_VERSION
return spirv::fMix(x, y, a);
return spirv::fMix(x, y, a);
#else
if constexpr (std::is_same_v<U, bool>)
return glm::mix<T, U>(x, y, a);
#endif
}
};

template<typename T>
struct lerp_helper<T, bool>
{
static inline T lerp(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(bool) a)
{
return a ? y : x;
else
}
};

template<typename T, int N>
struct lerp_helper<vector<T, N>, vector<bool, N> >
{
using output_vec_t = vector<T, N>;

static inline output_vec_t lerp(NBL_CONST_REF_ARG(output_vec_t) x, NBL_CONST_REF_ARG(output_vec_t) y, NBL_CONST_REF_ARG(vector<bool, N>) a)
{
if constexpr (std::is_same_v<scalar_type_t<U>, bool>)
{
T retval;
// whatever has a `scalar_type` specialization should be a pure vector
for (auto i = 0; i < sizeof(a) / sizeof(scalar_type_t<U>); i++)
retval[i] = a[i] ? y[i] : x[i];
return retval;
}
else
return glm::mix<T, U>(x, y, a);
output_vec_t retval;
for (uint32_t i = 0; i < vector_traits<output_vec_t>::Dimension; i++)
retval[i] = a[i] ? y[i] : x[i];
return retval;
}
#endif
};

}

template<typename T, typename U>
inline T lerp(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(U) a)
{
return cpp_compat_intrinsics_impl::lerp_helper<T, U>::lerp(x, y, a);
}

// transpose is not defined because it is implemented via a hidden friend
Expand All @@ -232,14 +266,18 @@ inline T min(NBL_CONST_REF_ARG(T) a, NBL_CONST_REF_ARG(T) b)
#ifdef __HLSL_VERSION
min(a, b);
#else
return std::min(a, b);
return glm::min(a, b);
#endif
}

// Maximum of `a` and `b`. The preprocessor-guarded body calls max(a, b) under
// __HLSL_VERSION and glm::max(a, b) otherwise.
template<typename T>
inline T max(NBL_CONST_REF_ARG(T) a, NBL_CONST_REF_ARG(T) b)
{
// NOTE(review): as written, this unconditional return makes everything below it
// unreachable; it looks like the pre-change line of a diff - confirm it was removed.
return lerp<T>(a, b, b > a);
#ifdef __HLSL_VERSION
// NOTE(review): the result of max(a, b) is not returned on this path - confirm and add
// `return`.
max(a, b);
#else
return glm::max(a, b);
#endif
}

template<typename FloatingPoint>
Expand Down Expand Up @@ -289,6 +327,7 @@ DEFINE_EXP2_SPECIALIZATION(uint64_t)
template<typename FloatingPoint>
inline FloatingPoint rsqrt(FloatingPoint x)
{
// TODO: https://stackoverflow.com/a/62239778
#ifdef __HLSL_VERSION
return spirv::inverseSqrt(x);
#else
Expand Down
12 changes: 5 additions & 7 deletions include/nbl/builtin/hlsl/spirv_intrinsics/glsl.std.450.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@ template<typename Integral>
[[vk::ext_instruction(GLSLstd450::GLSLstd450FindILsb, "GLSL.std.450")]]
enable_if_t<is_integral_v<Integral> && (sizeof(scalar_type_t<Integral>) == 4), Integral> findILsb(Integral value);

template<typename Integral>
[[vk::ext_instruction(GLSLstd450::GLSLstd450FindSMsb, "GLSL.std.450")]]
enable_if_t<is_integral_v<Integral> && (sizeof(scalar_type_t<Integral>) == 4), Integral> findSMsb(Integral value);
int32_t findSMsb(int32_t value);

template<typename Integral>
[[vk::ext_instruction(GLSLstd450::GLSLstd450FindUMsb, "GLSL.std.450")]]
enable_if_t<is_integral_v<Integral> && (sizeof(scalar_type_t<Integral>) == 4), Integral> findUMsb(Integral value);
uint32_t findUMsb(uint32_t value);

template<typename FloatingPoint>
[[vk::ext_instruction(GLSLstd450::GLSLstd450Exp2, "GLSL.std.450")]]
Expand All @@ -44,11 +42,11 @@ enable_if_t<is_floating_point_v<FloatingPoint>, vector<FloatingPoint, 3> > cross

// Declaration bound to the GLSL.std.450 FMix extended instruction via vk::ext_instruction.
// NOTE(review): two declarations follow (one-parameter and three-parameter); this looks like
// the before/after lines of a diff - confirm only the three-parameter form is current.
// NOTE(review): the parameter names (val, min, max) do not match how the visible caller
// passes arguments, `spirv::fMix(x, y, a)` - confirm the intended naming.
template<typename FloatingPoint>
[[vk::ext_instruction(GLSLstd450::GLSLstd450FMix, "GLSL.std.450")]]
enable_if_t<is_floating_point_v<FloatingPoint>, FloatingPoint> fMix(FloatingPoint val);
enable_if_t<is_floating_point_v<FloatingPoint>, FloatingPoint> fMix(FloatingPoint val, FloatingPoint min, FloatingPoint max);

template<typename SquareMatrix>
template<typename T, int N>
[[vk::ext_instruction(GLSLstd450::GLSLstd450Determinant, "GLSL.std.450")]]
SquareMatrix determinant(in SquareMatrix mat);
T determinant(in matrix<T, N, N> mat);

}
}
Expand Down