From c035f4479d12b667d4a86b4a2418523e040baae7 Mon Sep 17 00:00:00 2001 From: "Sidorov, Dmitry" Date: Mon, 8 Apr 2024 05:16:52 -0700 Subject: [PATCH] Try to not use __spirv_Load/__spirv_Store Signed-off-by: Sidorov, Dmitry --- .../sycl/ext/oneapi/matrix/matrix-intel.hpp | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp index b5b38630b073e..ce6e10f979a4f 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp @@ -146,7 +146,8 @@ class wi_element { __spirv_AccessChain::value, spv_scope_traits::value>(&M.spvm, idx); - storage_element_type elem = __spirv_Load(ExtractP); + storage_element_type elem = *ExtractP; +// storage_element_type elem = __spirv_Load(ExtractP); #endif // USE_COOP_MATRIX return elem; #else @@ -169,7 +170,8 @@ class wi_element { __spirv_AccessChain::value, spv_scope_traits::value>(&M.spvm, idx); - return __spirv_Load(ExtractP) != static_cast(0); + return *ExtractP != static_cast(0); +// return __spirv_Load(ExtractP) != static_cast(0); #endif // USE_COOP_MATRIX #else throw runtime_error("joint matrix is not supported on host device.", @@ -184,7 +186,8 @@ class wi_element { M.spvm, static_cast(rhs), idx); #else T2 *InsertP = __spirv_AccessChain(&M.spvm, idx); - __spirv_Store(InsertP, static_cast(rhs)); + *InsertP = static_cast(rhs); +// __spirv_Store(InsertP, static_cast(rhs)); #endif // USE_COOP_MATRIX return *this; #else @@ -210,9 +213,13 @@ class wi_element { spv_matrix_use_traits::value, spv_scope_traits::value>( &rhs.M.spvm, rhs.idx); + T *InsertP = __spirv_AccessChain(&M.spvm, idx); + *InsertP = *ExtractP; +/* T RhsVal = __spirv_Load(ExtractP); T *InsertP = __spirv_AccessChain(&M.spvm, idx); __spirv_Store(InsertP, RhsVal); + */ #endif // USE_COOP_MATRIX return *this; #else @@ -245,10 +252,8 @@ class wi_element { spv_matrix_use_traits::value, \ spv_scope_traits::value>( \ &rhs.M.spvm, rhs.idx); \ - T RhsVal = \ - __spirv_Load(ExtractP) op static_cast(rhs); \ T *InsertP = __spirv_AccessChain(&M.spvm, idx); \ - __spirv_Store(static_cast(InsertP), RhsVal); \ + *InsertP = *ExtractP op static_cast(rhs); \ return *this; \ } #endif // USE_COOP_MATRIX @@ -315,7 +320,7 @@ class wi_element::value, spv_scope_traits::value>(&M.spvm, idx); - return __spirv_Load(ExtractP); + return *ExtractP; #endif // USE_COOP_MATRIX #else throw runtime_error("joint matrix is not supported on host device.", @@ -338,8 +343,8 @@ class wi_element::value, spv_scope_traits::value>(&M.spvm, idx); - sycl::ext::oneapi::bfloat16 Elem = - __spirv_Load(ExtractP); + sycl::ext::oneapi::bfloat16 Elem = *ExtractP; +// __spirv_Load(ExtractP); return sycl::fabs(static_cast(Elem)) >= std::numeric_limits::epsilon(); #endif // USE_COOP_MATRIX @@ -384,9 +389,11 @@ class wi_element::value, spv_scope_traits::value>(&rhs.M.spvm, rhs.idx); - sycl::ext::oneapi::bfloat16 RhsVal = __spirv_Load(ExtractP); sycl::ext::oneapi::bfloat16 *InsertP = __spirv_AccessChain(&M.spvm, idx); - __spirv_Store(InsertP, RhsVal); + *InsertP = *ExtractP; +/* sycl::ext::oneapi::bfloat16 RhsVal = __spirv_Load(ExtractP); + sycl::ext::oneapi::bfloat16 *InsertP = __spirv_AccessChain(&M.spvm, idx); + __spirv_Store(InsertP, RhsVal);*/ #endif // USE_COOP_MATRIX return *this; #else @@ -417,9 +424,8 @@ class wi_element::value, \ spv_scope_traits::value>(&M.spvm, idx); \ - sycl::ext::oneapi::bfloat16 RhsVal = __spirv_Load(ExtractP) op rhs; \ sycl::ext::oneapi::bfloat16 *InsertP = __spirv_AccessChain(&M.spvm, idx); \ - __spirv_Store(InsertP, RhsVal); \ + *InsertP = *ExtractP op rhs; \ return *this; \ } #endif // USE_COOP_MATRIX @@ -471,7 +477,7 @@ class wi_element::value, \ spv_scope_traits::value>(&lhs.M.spvm, \ lhs.idx); \ - return __spirv_Load(ExtractP) op rhs; \ + return *ExtractP op rhs; \ } \ friend type operator op( \ const sycl::ext::oneapi::bfloat16 &lhs, \ @@ -482,7 +488,7 @@ class wi_element::value, \ spv_scope_traits::value>(&rhs.M.spvm, \ rhs.idx); \ - return __spirv_Load(ExtractP) op lhs; \ + return *ExtractP op lhs; \ } #endif // USE_COOP_MATRIX OP(sycl::ext::oneapi::bfloat16, +) @@ -527,8 +533,7 @@ class wi_element::value, \ spv_scope_traits::value>(&lhs.M.spvm, \ lhs.idx); \ - return type{static_cast(__spirv_Load( \ - ExtractP)) op static_cast(rhs)}; \ + return type{static_cast(*ExtractP) op static_cast(rhs)}; \ } \ friend type operator op( \ const sycl::ext::oneapi::bfloat16 &lhs, \ @@ -539,8 +544,7 @@ class wi_element::value, \ spv_scope_traits::value>(&rhs.M.spvm, \ rhs.idx); \ - return type{static_cast(__spirv_Load( \ - ExtractP)) op static_cast(lhs)}; \ + return type{static_cast(*ExtractP) op static_cast(lhs)}; \ } #endif // USE_COOP_MATRIX OP(bool, ==)