diff --git a/cudax/include/cuda/experimental/__utility/basic_any/virtcall.cuh b/cudax/include/cuda/experimental/__utility/basic_any/virtcall.cuh index a51b36d29c9..e090e4b023c 100644 --- a/cudax/include/cuda/experimental/__utility/basic_any/virtcall.cuh +++ b/cudax/include/cuda/experimental/__utility/basic_any/virtcall.cuh @@ -22,6 +22,7 @@ #endif // no system header #include <cuda/std/__concepts/concept_macros.h> +#include <cuda/std/__type_traits/is_callable.h> #include <cuda/experimental/__utility/basic_any/access.cuh> #include <cuda/experimental/__utility/basic_any/basic_any_from.cuh> @@ -60,19 +61,13 @@ namespace cuda::experimental //! except for the virtuals map, which substitutes the correct member function //! pointer for the user so they don't have to think about it. template <auto _Mbr, auto _BoundMbr> -struct __virtuals_map_pair +struct __virtuals_map_element { // map ifoo<>::meow to itself - _CCCL_NODISCARD _CUDAX_TRIVIAL_HOST_API constexpr auto operator()(__ctag<_Mbr>) const noexcept - { - return _Mbr; - } + auto operator()(__ctag<_Mbr>) const -> __virtual_fn<_Mbr>; // map ifoo<_Super>::meow to ifoo<>::meow - _CCCL_NODISCARD _CUDAX_TRIVIAL_HOST_API constexpr auto operator()(__ctag<_BoundMbr>) const noexcept - { - return _Mbr; - } + auto operator()(__ctag<_BoundMbr>) const -> __virtual_fn<_Mbr>; }; template <class, class> @@ -80,15 +75,23 @@ struct __virtuals_map; template <class _Interface, auto... _Mbrs, class _BoundInterface, auto... _BoundMbrs> struct __virtuals_map<overrides_for<_Interface, _Mbrs...>, overrides_for<_BoundInterface, _BoundMbrs...>> - : __virtuals_map_pair<_Mbrs, _BoundMbrs>... + : __virtuals_map_element<_Mbrs, _BoundMbrs>... { - using __virtuals_map_pair<_Mbrs, _BoundMbrs>::operator()...; + using __virtuals_map_element<_Mbrs, _BoundMbrs>::operator()...; }; template <class _Interface, class _Super> using __virtuals_map_for _CCCL_NODEBUG_ALIAS = __virtuals_map<__overrides_for<_Interface>, __overrides_for<__rebind_interface<_Interface, _Super>>>; +template <auto _Mbr, class _Interface, class _Super> +extern _CUDA_VSTD::__call_result_t<__virtuals_map_for<_Interface, _Super>, __ctag<_Mbr>> __virtual_fn_for_v; + +// This alias indirects through the above variable template to cache the result +// of the virtuals map lookup. +template <auto _Mbr, class _Interface, class _Super> +using __virtual_fn_for _CCCL_NODEBUG_ALIAS = decltype(__virtual_fn_for_v<_Mbr, _Interface, _Super>); + //! //! virtcall //! @@ -109,8 +112,8 @@ _CUDAX_HOST_API auto __virtcall(_Self* __self, _Args&&... __args) // auto* __vptr = __basic_any_access::__get_vptr(*__self)->__query_interface(_Interface()); auto* __obj = __basic_any_access::__get_optr(*__self); // map the member function pointer to the correct one if necessary - constexpr auto _Mbr2 = __virtuals_map_for<_Interface, _Super>{}(__ctag<_Mbr>()); - return __vptr->__virtual_fn<_Mbr2>::__fn_(__obj, static_cast<_Args&&>(__args)...); + using __virtual_fn_t = __virtual_fn_for<_Mbr, _Interface, _Super>; + return __vptr->__virtual_fn_t::__fn_(__obj, static_cast<_Args&&>(__args)...); } _CCCL_TEMPLATE(auto _Mbr, template <class...> class _Interface, class _Super, class... _Args) diff --git a/cudax/include/cuda/experimental/__utility/basic_any/virtual_functions.cuh b/cudax/include/cuda/experimental/__utility/basic_any/virtual_functions.cuh index 32d7a77ebba..f31d396b608 100644 --- a/cudax/include/cuda/experimental/__utility/basic_any/virtual_functions.cuh +++ b/cudax/include/cuda/experimental/__utility/basic_any/virtual_functions.cuh @@ -64,8 +64,8 @@ _CUDAX_TRIVIAL_API auto __c_style_cast(_Src* __ptr) noexcept -> _DstPtr } template <class _Tp, auto _Fn, class _Ret, bool _IsConst, bool _IsNothrow, class... _Args> -_CCCL_NODISCARD _CUDAX_API auto __override_fn_([[maybe_unused]] _CUDA_VSTD::__maybe_const<_IsConst, void>* __pv, - [[maybe_unused]] _Args... __args) noexcept(_IsNothrow) -> _Ret +_CCCL_NODISCARD _CUDAX_HOST_API auto __override_fn_([[maybe_unused]] _CUDA_VSTD::__maybe_const<_IsConst, void>* __pv, + [[maybe_unused]] _Args... __args) noexcept(_IsNothrow) -> _Ret { using __value_type _CCCL_NODEBUG_ALIAS = _CUDA_VSTD::__maybe_const<_IsConst, _Tp>;