diff --git a/cudax/include/cuda/experimental/__utility/basic_any/virtcall.cuh b/cudax/include/cuda/experimental/__utility/basic_any/virtcall.cuh index a51b36d29c9..e090e4b023c 100644 --- a/cudax/include/cuda/experimental/__utility/basic_any/virtcall.cuh +++ b/cudax/include/cuda/experimental/__utility/basic_any/virtcall.cuh @@ -22,6 +22,7 @@ #endif // no system header #include +#include #include #include @@ -60,19 +61,13 @@ namespace cuda::experimental //! except for the virtuals map, which substitutes the correct member function //! pointer for the user so they don't have to think about it. template -struct __virtuals_map_pair +struct __virtuals_map_element { // map ifoo<>::meow to itself - _CCCL_NODISCARD _CUDAX_TRIVIAL_HOST_API constexpr auto operator()(__ctag<_Mbr>) const noexcept - { - return _Mbr; - } + auto operator()(__ctag<_Mbr>) const -> __virtual_fn<_Mbr>; // map ifoo<_Super>::meow to ifoo<>::meow - _CCCL_NODISCARD _CUDAX_TRIVIAL_HOST_API constexpr auto operator()(__ctag<_BoundMbr>) const noexcept - { - return _Mbr; - } + auto operator()(__ctag<_BoundMbr>) const -> __virtual_fn<_Mbr>; }; template @@ -80,15 +75,23 @@ struct __virtuals_map; template struct __virtuals_map, overrides_for<_BoundInterface, _BoundMbrs...>> - : __virtuals_map_pair<_Mbrs, _BoundMbrs>... + : __virtuals_map_element<_Mbrs, _BoundMbrs>... { - using __virtuals_map_pair<_Mbrs, _BoundMbrs>::operator()...; + using __virtuals_map_element<_Mbrs, _BoundMbrs>::operator()...; }; template using __virtuals_map_for _CCCL_NODEBUG_ALIAS = __virtuals_map<__overrides_for<_Interface>, __overrides_for<__rebind_interface<_Interface, _Super>>>; +template +extern _CUDA_VSTD::__call_result_t<__virtuals_map_for<_Interface, _Super>, __ctag<_Mbr>> __virtual_fn_for_v; + +// This alias indirects through the above variable template to cache the result +// of the virtuals map lookup. +template +using __virtual_fn_for _CCCL_NODEBUG_ALIAS = decltype(__virtual_fn_for_v<_Mbr, _Interface, _Super>); + //! //! virtcall //! @@ -109,8 +112,8 @@ _CUDAX_HOST_API auto __virtcall(_Self* __self, _Args&&... __args) // auto* __vptr = __basic_any_access::__get_vptr(*__self)->__query_interface(_Interface()); auto* __obj = __basic_any_access::__get_optr(*__self); // map the member function pointer to the correct one if necessary - constexpr auto _Mbr2 = __virtuals_map_for<_Interface, _Super>{}(__ctag<_Mbr>()); - return __vptr->__virtual_fn<_Mbr2>::__fn_(__obj, static_cast<_Args&&>(__args)...); + using __virtual_fn_t = __virtual_fn_for<_Mbr, _Interface, _Super>; + return __vptr->__virtual_fn_t::__fn_(__obj, static_cast<_Args&&>(__args)...); } _CCCL_TEMPLATE(auto _Mbr, template class _Interface, class _Super, class... _Args) diff --git a/cudax/include/cuda/experimental/__utility/basic_any/virtual_functions.cuh b/cudax/include/cuda/experimental/__utility/basic_any/virtual_functions.cuh index 32d7a77ebba..f31d396b608 100644 --- a/cudax/include/cuda/experimental/__utility/basic_any/virtual_functions.cuh +++ b/cudax/include/cuda/experimental/__utility/basic_any/virtual_functions.cuh @@ -64,8 +64,8 @@ _CUDAX_TRIVIAL_API auto __c_style_cast(_Src* __ptr) noexcept -> _DstPtr } template -_CCCL_NODISCARD _CUDAX_API auto __override_fn_([[maybe_unused]] _CUDA_VSTD::__maybe_const<_IsConst, void>* __pv, - [[maybe_unused]] _Args... __args) noexcept(_IsNothrow) -> _Ret +_CCCL_NODISCARD _CUDAX_HOST_API auto __override_fn_([[maybe_unused]] _CUDA_VSTD::__maybe_const<_IsConst, void>* __pv, + [[maybe_unused]] _Args... __args) noexcept(_IsNothrow) -> _Ret { using __value_type _CCCL_NODEBUG_ALIAS = _CUDA_VSTD::__maybe_const<_IsConst, _Tp>;