diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index d8f57ef60a..b9f0d16c33 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -251,6 +251,7 @@ In chronological order: * Ye Tao * [2025-02-03] Optimize SBGEMM kernel on NEOVERSEV1 * [2025-02-27] Add sbgemv_n_neon kernel + * [2025-05-17] Impl prototype of BGEMM inferface * Abhishek Kumar - * [2025-04-22] Optimise dot kernel for NEOVERSE V1 \ No newline at end of file + * [2025-04-22] Optimise dot kernel for NEOVERSE V1 diff --git a/Makefile.system b/Makefile.system index 38646c3c6b..ff6b875554 100644 --- a/Makefile.system +++ b/Makefile.system @@ -1544,6 +1544,9 @@ ifeq ($(USE_TLS), 1) CCOMMON_OPT += -DUSE_TLS endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +CCOMMON_OPT += -DBUILD_BFLOAT16_ONLY +endif ifeq ($(BUILD_BFLOAT16), 1) CCOMMON_OPT += -DBUILD_BFLOAT16 endif @@ -1888,6 +1891,7 @@ export FUNCTION_PROFILE export TARGET_CORE export NO_AVX512 export NO_AVX2 +export BUILD_BFLOAT16_ONLY export BUILD_BFLOAT16 export NO_LSX export NO_LASX @@ -1912,7 +1916,7 @@ export ZGEMM3M_UNROLL_M export ZGEMM3M_UNROLL_N export XGEMM3M_UNROLL_M export XGEMM3M_UNROLL_N - +# Todo: add bgemm unroll factors ifdef USE_CUDA export CUDADIR diff --git a/Makefile.tail b/Makefile.tail index 54ba649dbf..5b5de184af 100644 --- a/Makefile.tail +++ b/Makefile.tail @@ -11,7 +11,7 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX)) HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX)) -BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) +BLASOBJS = $(SBEXTOBJS) $(BBLASOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS) BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P) ifdef EXPRECISION @@ -24,6 +24,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS) BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P) endif +$(BBLASOBJS) : override CFLAGS += -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX -UBFLOAT16 -USMALL_MATRIX_OPT $(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX @@ -42,6 +43,7 @@ $(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) $(SBEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF) + libs :: $(BLASOBJS) $(COMMONOBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ diff --git a/cblas.h b/cblas.h index 83686f7433..25de498b04 100644 --- a/cblas.h +++ b/cblas.h @@ -1,3 +1,31 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + #ifndef CBLAS_H #define CBLAS_H @@ -446,6 +474,9 @@ void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum C void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array, OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array, OPENBLAS_CONST blasint * M_array, OPENBLAS_CONST blasint * N_array, OPENBLAS_CONST blasint * K_array, OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size); +void cblas_bgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, + OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, bfloat16 *C, OPENBLAS_CONST blasint ldc); + #ifdef __cplusplus } #endif /* __cplusplus */ diff --git a/common.h b/common.h index 8d002c4aa0..66f5634574 100644 --- a/common.h +++ b/common.h @@ -1,5 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ +/* Copyright 2025 The OpenBLAS Project. */ /* All rights reserved. */ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -306,6 +307,13 @@ typedef int blasint; #define SIZE 8 #define BASE_SHIFT 3 #define ZBASE_SHIFT 4 +#elif defined(BFLOAT16_ONLY) +#define IFLOAT bfloat16 +#define XFLOAT IFLOAT +#define FLOAT bfloat16 +#define SIZE 2 +#define BASE_SHIFT 1 +#define ZBASE_SHIFT 2 #elif defined(BFLOAT16) #define IFLOAT bfloat16 #define XFLOAT IFLOAT diff --git a/common_b.h b/common_b.h new file mode 100644 index 0000000000..4ac8f5cb35 --- /dev/null +++ b/common_b.h @@ -0,0 +1,86 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#ifndef COMMON_B_H +#define COMMON_B_H + +// for now, only support DYNAMIC_ARCH = 0 case. +#ifndef DYNAMIC_ARCH +#define BGEMM_ONCOPY bgemm_oncopy +#define BGEMM_OTCOPY bgemm_otcopy +#define BGEMM_INCOPY bgemm_incopy +#define BGEMM_ITCOPY bgemm_itcopy + +#define BGEMM_BETA bgemm_beta +#define BGEMM_KERNEL bgemm_kernel + +#else + +#define BGEMM_ONCOPY gotoblas -> bgemm_oncopy +#define BGEMM_OTCOPY gotoblas -> bgemm_otcopy +#define BGEMM_INCOPY gotoblas -> bgemm_incopy +#define BGEMM_ITCOPY gotoblas -> bgemm_itcopy +#define BGEMM_BETA gotoblas -> bgemm_beta +#define BGEMM_KERNEL gotoblas -> bgemm_kernel + +#endif + +#define BGEMM_NN bgemm_nn +#define BGEMM_CN bgemm_tn +#define BGEMM_TN bgemm_tn +#define BGEMM_NC bgemm_nt +#define BGEMM_NT bgemm_nt +#define BGEMM_CC bgemm_tt +#define BGEMM_CT bgemm_tt +#define BGEMM_TC bgemm_tt +#define BGEMM_TT bgemm_tt +#define BGEMM_NR bgemm_nn +#define BGEMM_TR bgemm_tn +#define BGEMM_CR bgemm_tn +#define BGEMM_RN bgemm_nn +#define BGEMM_RT bgemm_nt +#define BGEMM_RC bgemm_nt +#define BGEMM_RR bgemm_nn + +#define BGEMM_THREAD_NN bgemm_thread_nn +#define BGEMM_THREAD_CN bgemm_thread_tn +#define BGEMM_THREAD_TN bgemm_thread_tn +#define BGEMM_THREAD_NC bgemm_thread_nt +#define BGEMM_THREAD_NT bgemm_thread_nt +#define BGEMM_THREAD_CC bgemm_thread_tt +#define BGEMM_THREAD_CT bgemm_thread_tt +#define BGEMM_THREAD_TC bgemm_thread_tt +#define BGEMM_THREAD_TT bgemm_thread_tt +#define BGEMM_THREAD_NR bgemm_thread_nn +#define BGEMM_THREAD_TR bgemm_thread_tn +#define BGEMM_THREAD_CR bgemm_thread_tn +#define BGEMM_THREAD_RN bgemm_thread_nn +#define BGEMM_THREAD_RT bgemm_thread_nt +#define BGEMM_THREAD_RC bgemm_thread_nt +#define BGEMM_THREAD_RR bgemm_thread_nn +#endif \ No newline at end of file diff --git a/common_interface.h b/common_interface.h index efd3c6649d..ae4786acb1 100644 --- a/common_interface.h +++ b/common_interface.h @@ -1,5 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ +/* Copyright 2025 The OpenBLAS Project. */ /* All rights reserved. */ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -480,7 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint xdouble *, blasint *, xdouble *, xdouble *, blasint *); /* Level 3 routines */ - +void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, float *, + bfloat16 *, blasint *, bfloat16 *, blasint *, float *, bfloat16 *, blasint *); void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *, bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *); void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *, diff --git a/common_level3.h b/common_level3.h index d370c1f96a..eaa33b2a26 100644 --- a/common_level3.h +++ b/common_level3.h @@ -54,7 +54,8 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K, int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); - +int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, + bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG); int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, @@ -78,6 +79,12 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG); #endif +// add bgemm copy functions +int bgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int bgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int bgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); +int bgemm_otcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); + int sbgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); int sbgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b); @@ -505,6 +512,8 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag); +// add bgemm kernel +int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG); int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG); int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG); @@ -657,6 +666,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG); int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG); +int bgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); + int sbgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); @@ -754,6 +768,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG); #endif +int bgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); +int bgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); + int sbgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); int sbgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG); diff --git a/common_macro.h b/common_macro.h index 820cb472a6..48e78564d1 100644 --- a/common_macro.h +++ b/common_macro.h @@ -1,5 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ +/* Copyright 2025 The OpenBLAS Project. */ /* All rights reserved. */ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -39,6 +40,7 @@ #ifndef COMMON_MACRO #define COMMON_MACRO +#include "common_b.h" #include "common_sb.h" #include "common_s.h" #include "common_d.h" @@ -657,8 +659,52 @@ #define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN #define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT -#elif defined(BFLOAT16) +#elif defined(BFLOAT16_ONLY) +#define GEMM_BETA BGEMM_BETA +#define GEMM_KERNEL_N BGEMM_KERNEL +#define GEMM_KERNEL_L BGEMM_KERNEL +#define GEMM_KERNEL_R BGEMM_KERNEL +#define GEMM_KERNEL_B BGEMM_KERNEL + +#define GEMM_NN BGEMM_NN +#define GEMM_CN BGEMM_TN +#define GEMM_TN BGEMM_TN +#define GEMM_NC BGEMM_NT +#define GEMM_NT BGEMM_NT +#define GEMM_CC BGEMM_TT +#define GEMM_CT BGEMM_TT +#define GEMM_TC BGEMM_TT +#define GEMM_TT BGEMM_TT +#define GEMM_NR BGEMM_NN +#define GEMM_TR BGEMM_TN +#define GEMM_CR BGEMM_TN +#define GEMM_RN BGEMM_NN +#define GEMM_RT BGEMM_NT +#define GEMM_RC BGEMM_NT +#define GEMM_RR BGEMM_NN +#define GEMM_ONCOPY BGEMM_ONCOPY +#define GEMM_OTCOPY BGEMM_OTCOPY +#define GEMM_INCOPY BGEMM_INCOPY +#define GEMM_ITCOPY BGEMM_ITCOPY + +#define GEMM_THREAD_NN BGEMM_THREAD_NN +#define GEMM_THREAD_CN BGEMM_THREAD_TN +#define GEMM_THREAD_TN BGEMM_THREAD_TN +#define GEMM_THREAD_NC BGEMM_THREAD_NT +#define GEMM_THREAD_NT BGEMM_THREAD_NT +#define GEMM_THREAD_CC BGEMM_THREAD_TT +#define GEMM_THREAD_CT BGEMM_THREAD_TT +#define GEMM_THREAD_TC BGEMM_THREAD_TT +#define GEMM_THREAD_TT BGEMM_THREAD_TT +#define GEMM_THREAD_NR BGEMM_THREAD_NN +#define GEMM_THREAD_TR BGEMM_THREAD_TN +#define GEMM_THREAD_CR BGEMM_THREAD_TN +#define GEMM_THREAD_RN BGEMM_THREAD_NN +#define GEMM_THREAD_RT BGEMM_THREAD_NT +#define GEMM_THREAD_RC BGEMM_THREAD_NT +#define GEMM_THREAD_RR BGEMM_THREAD_NN +#elif defined(BFLOAT16) #define D_TO_BF16_K SBDTOBF16_K #define D_BF16_TO_K DBF16TOD_K #define S_TO_BF16_K SBSTOBF16_K @@ -2618,6 +2664,9 @@ || defined(ARCH_LOONGARCH64) || defined(ARCH_E2K) || defined(ARCH_ALPHA)) extern BLASLONG gemm_offset_a; extern BLASLONG gemm_offset_b; +extern BLASLONG bgemm_p; +extern BLASLONG bgemm_q; +extern BLASLONG bgemm_r; extern BLASLONG sbgemm_p; extern BLASLONG sbgemm_q; extern BLASLONG sbgemm_r; diff --git a/common_param.h b/common_param.h index 2d771a27da..b2b4b9cf64 100644 --- a/common_param.h +++ b/common_param.h @@ -1,6 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ -/* Copyright 2023 The OpenBLAS Project. */ +/* Copyright 2023, 2025 The OpenBLAS Project. */ /* All rights reserved. */ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -49,6 +49,21 @@ typedef struct { int switch_ratio; int offsetA, offsetB, align; +#if BUILD_BFLOAT16_ONLY == 1 + int bgemm_p, bgemm_q, bgemm_r; + int bgemm_unroll_m, bgemm_unroll_n, bgemm_unroll_mn; + int bgemm_align_k; + + int (*bgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG); + int (*bgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG); + + int (*bgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*bgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*bgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + int (*bgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *); + +#endif + #if BUILD_BFLOAT16 == 1 int sbgemm_p, sbgemm_q, sbgemm_r; int sbgemm_unroll_m, sbgemm_unroll_n, sbgemm_unroll_mn; @@ -1229,6 +1244,15 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 gotoblas -> exclusive_cache +#if (BUILD_BFLOAT16_ONLY==1) +#define BGEMM_P gotoblas -> bgemm_p +#define BGEMM_Q gotoblas -> bgemm_q +#define BGEMM_R gotoblas -> bgemm_r +#define BGEMM_UNROLL_M gotoblas -> bgemm_unroll_m +#define BGEMM_UNROLL_N gotoblas -> bgemm_unroll_n +#define BGEMM_UNROLL_MN gotoblas -> bgemm_unroll_mn +#endif + #if (BUILD_BFLOAT16==1) #define SBGEMM_P gotoblas -> sbgemm_p #define SBGEMM_Q gotoblas -> sbgemm_q @@ -1357,6 +1381,19 @@ extern gotoblas_t *gotoblas; #define HAVE_EX_L2 0 #endif +#if (BUILD_BFLOAT16_ONLY == 1) +#define BGEMM_P BGEMM_DEFAULT_P +#define BGEMM_Q BGEMM_DEFAULT_Q +#define BGEMM_R BGEMM_DEFAULT_R +#define BGEMM_UNROLL_M BGEMM_DEFAULT_UNROLL_M +#define BGEMM_UNROLL_N BGEMM_DEFAULT_UNROLL_N +#ifdef BGEMM_DEFAULT_UNROLL_MN +#define BGEMM_UNROLL_MN BGEMM_DEFAULT_UNROLL_MN +#else +#define BGEMM_UNROLL_MN MAX((BGEMM_UNROLL_M), (BGEMM_UNROLL_N)) +#endif +#endif + #if (BUILD_BFLOAT16 == 1) #define SBGEMM_P SBGEMM_DEFAULT_P #define SBGEMM_Q SBGEMM_DEFAULT_Q @@ -1517,6 +1554,18 @@ extern gotoblas_t *gotoblas; #define GEMM_DEFAULT_R SBGEMM_DEFAULT_R #define GEMM_DEFAULT_UNROLL_M SBGEMM_DEFAULT_UNROLL_M #define GEMM_DEFAULT_UNROLL_N SBGEMM_DEFAULT_UNROLL_N +#elif defined(BFLOAFT16_ONLY) +#define GEMM_P BGEMM_P +#define GEMM_Q BGEMM_Q +#define GEMM_R BGEMM_R +#define GEMM_UNROLL_M BGEMM_UNROLL_M +#define GEMM_UNROLL_N BGEMM_UNROLL_N +#define GEMM_UNROLL_MN BGEMM_UNROLL_MN +#define GEMM_DEFAULT_P BGEMM_DEFAULT_P +#define GEMM_DEFAULT_Q BGEMM_DEFAULT_Q +#define GEMM_DEFAULT_R BGEMM_DEFAULT_R +#define GEMM_DEFAULT_UNROLL_M BGEMM_DEFAULT_UNROLL_M +#define GEMM_DEFAULT_UNROLL_N BGEMM_DEFAULT_UNROLL_N #else #define GEMM_P SGEMM_P #define GEMM_Q SGEMM_Q diff --git a/driver/level3/Makefile b/driver/level3/Makefile index c304838423..02c5ab92ef 100644 --- a/driver/level3/Makefile +++ b/driver/level3/Makefile @@ -1,3 +1,32 @@ +############################################################################### +# Copyright (c) 2025, The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +############################################################################### + TOPDIR = ../.. include ../../Makefile.system @@ -23,6 +52,13 @@ ifeq ($(BUILD_BFLOAT16),1) SBBLASOBJS += sbgemm_nn.$(SUFFIX) sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX) endif +ifeq ($(BUILD_BFLOAT16_ONLY),1) +BBLASOBJS += bgemm_nn.$(SUFFIX) bgemm_nt.$(SUFFIX) bgemm_tn.$(SUFFIX) bgemm_tt.$(SUFFIX) +endif + +BLASOBJS += \ + gemm_nn.$(SUFFIX) gemm_nt.$(SUFFIX) gemm_tn.$(SUFFIX) gemm_tt.$(SUFFIX) + SBLASOBJS += \ sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \ strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \ @@ -207,6 +243,10 @@ COMMONOBJS += gemm_thread_m.$(SUFFIX) gemm_thread_n.$(SUFFIX) gemm_thread_mn.$( COMMONOBJS += syrk_thread.$(SUFFIX) ifneq ($(USE_SIMPLE_THREADED_LEVEL3), 1) +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +BBLASOBJS += bgemm_thread_nn.$(SUFFIX) bgemm_thread_nt.$(SUFFIX) bgemm_thread_tn.$(SUFFIX) bgemm_thread_tt.$(SUFFIX) +endif + ifeq ($(BUILD_BFLOAT16),1) SBBLASOBJS += sbgemm_thread_nn.$(SUFFIX) sbgemm_thread_nt.$(SUFFIX) sbgemm_thread_tn.$(SUFFIX) sbgemm_thread_tt.$(SUFFIX) endif @@ -343,6 +383,18 @@ endif all :: +bgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +bgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +bgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +bgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sbgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) @@ -399,8 +451,8 @@ cgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h cgemm_nr.$(SUFFIX) : gemm.c level3.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNR $< -o $(@F) - cgemm_nc.$(SUFFIX) : gemm.c level3.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNC $< -o $(@F) cgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h @@ -550,6 +602,18 @@ gemm_thread_variable.$(SUFFIX) : gemm_thread_variable.c ../../common.h beta_thread.$(SUFFIX) : beta_thread.c ../../common.h $(CC) -c $(CFLAGS) $< -o $(@F) +bgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) + +bgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F) + +bgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F) + +bgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h + $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F) + sbgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F) diff --git a/driver/level3/level3.c b/driver/level3/level3.c index b7328876b4..4596a8c12d 100644 --- a/driver/level3/level3.c +++ b/driver/level3/level3.c @@ -1,5 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ +/* Copyright (c) 2025, The OpenBLAS Project */ /* All rights reserved. */ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -42,9 +43,11 @@ #if !defined(XDOUBLE) || !defined(QUAD_PRECISION) #ifndef COMPLEX #define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \ - GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ - BETA[0], NULL, 0, NULL, 0, \ - (FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC) + do { \ + GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ + BETA[0], NULL, 0, NULL, 0, \ + (FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC); \ + } while (0) #else #define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \ GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \ @@ -169,10 +172,30 @@ #define STOP_RPCC(COUNTER) #endif +#if defined(HALF) +#if defined(DYNAMIC_ARCH) + #if defined(BUILD_BFLOAT16) + #define HALF_DTYPE_ALIGN_K gotoblas->sbgemm_align_k + #else + #define HALF_DTYPE_ALIGN_K gotoblas->bgemm_align_k + #endif +#else + #if defined(BUILD_BFLOAT16) + #define HALF_DTYPE_ALIGN_K SBGEMM_ALIGN_K + #else + #define HALF_DTYPE_ALIGN_K BGEMM_ALIGN_K + #endif +#endif +#endif + int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){ BLASLONG k, lda, ldb, ldc; +#if defined(BUILD_BFLOAT16_ONLY) + float *alpha, *beta; +#else FLOAT *alpha, *beta; +#endif IFLOAT *a, *b; FLOAT *c; BLASLONG m_from, m_to, n_from, n_to; @@ -207,8 +230,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, ldb = LDB; ldc = LDC; +#if defined(BUILD_BFLOAT16_ONLY) + alpha = (float *)args -> alpha; + beta = (float *)args -> beta; +#else alpha = (FLOAT *)args -> alpha; beta = (FLOAT *)args -> beta; +#endif + m_from = 0; m_to = M; @@ -305,12 +334,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, } BLASLONG pad_min_l = min_l; + #if defined(HALF) -#if defined(DYNAMIC_ARCH) - pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1); -#else - pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);; -#endif + pad_min_l = (min_l + HALF_DTYPE_ALIGN_K - 1) & ~(HALF_DTYPE_ALIGN_K - 1); #endif /* First, we have to move data A to L2 cache */ diff --git a/driver/level3/level3_thread.c b/driver/level3/level3_thread.c index db3bffc10a..64be542c2a 100644 --- a/driver/level3/level3_thread.c +++ b/driver/level3/level3_thread.c @@ -1,6 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ -/* Copyright 2023 The OpenBLAS Project. */ +/* Copyright 2023, 2025 The OpenBLAS Project. */ /* All rights reserved. */ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -216,6 +216,22 @@ typedef struct { #define STOP_RPCC(COUNTER) #endif +#if defined(HALF) +#if defined(DYNAMIC_ARCH) + #if defined(BUILD_BFLOAT16) + #define HALF_DTYPE_ALIGN_K gotoblas->sbgemm_align_k + #else + #define HALF_DTYPE_ALIGN_K gotoblas->bgemm_align_k + #endif +#else + #if defined(BUILD_BFLOAT16) + #define HALF_DTYPE_ALIGN_K SBGEMM_ALIGN_K + #else + #define HALF_DTYPE_ALIGN_K BGEMM_ALIGN_K + #endif +#endif +#endif + static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){ IFLOAT *buffer[DIVIDE_RATE]; @@ -223,7 +239,11 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, BLASLONG k, lda, ldb, ldc; BLASLONG m_from, m_to, n_from, n_to; +#if defined(BUILD_BFLOAT16_ONLY) + float *alpha, *beta; +#else FLOAT *alpha, *beta; +#endif IFLOAT *a, *b; FLOAT *c; job_t *job = (job_t *)args -> common; @@ -261,8 +281,14 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, ldb = LDB; ldc = LDC; +#if defined(BUILD_BFLOAT16_ONLY) + alpha = (float *)args -> alpha; + beta = (float *)args -> beta; +#else alpha = (FLOAT *)args -> alpha; beta = (FLOAT *)args -> beta; +#endif + /* Initialize 2D CPU distribution */ nthreads_m = args -> nthreads; @@ -325,11 +351,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, BLASLONG pad_min_l = min_l; #if defined(HALF) -#if defined(DYNAMIC_ARCH) - pad_min_l = (min_l + gotoblas->sbgemm_align_k - 1) & ~(gotoblas->sbgemm_align_k-1); -#else - pad_min_l = (min_l + SBGEMM_ALIGN_K - 1) & ~(SBGEMM_ALIGN_K - 1);; -#endif + pad_min_l = (min_l + HALF_DTYPE_ALIGN_K - 1) & ~(HALF_DTYPE_ALIGN_K - 1); #endif /* Determine step size in m diff --git a/getarch_2nd.c b/getarch_2nd.c index dd1f830895..1df6cc3c7c 100644 --- a/getarch_2nd.c +++ b/getarch_2nd.c @@ -1,3 +1,31 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + #include #ifndef BUILD_KERNEL #include "config.h" @@ -17,6 +45,10 @@ typedef unsigned long BLASULONG; int main(int argc, char **argv) { if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) { + printf("BGEMM_UNROLL_M=%d\n", BGEMM_DEFAULT_UNROLL_M); + printf("BGEMM_UNROLL_N=%d\n", BGEMM_DEFAULT_UNROLL_N); + printf("BGEMM_UNROLL_M=%d\n", BGEMM_DEFAULT_UNROLL_M); + printf("BGEMM_UNROLL_N=%d\n", BGEMM_DEFAULT_UNROLL_N); printf("SBGEMM_UNROLL_M=%d\n", SBGEMM_DEFAULT_UNROLL_M); printf("SBGEMM_UNROLL_N=%d\n", SBGEMM_DEFAULT_UNROLL_N); printf("SGEMM_UNROLL_M=%d\n", SGEMM_DEFAULT_UNROLL_M); diff --git a/interface/Makefile b/interface/Makefile index f09a6f46b9..adf4a9a6f7 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -53,6 +53,10 @@ SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +BBLAS3OBJ = bgemm.$(SUFFIX) +endif + DBLAS1OBJS = \ daxpy.$(SUFFIX) dswap.$(SUFFIX) \ dcopy.$(SUFFIX) dscal.$(SUFFIX) \ @@ -291,6 +295,10 @@ CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX) cblas_sbgemmt.$(SUFFIX) cblas_sbgemmtr.$(S CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX) endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +CBBLAS3OBJS = cblas_bgemm.$(SUFFIX) +endif + CDBLAS1OBJS = \ cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \ cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \ @@ -388,6 +396,7 @@ SBLAS3OBJS += $(CSBLAS3OBJS) SBBLAS1OBJS += $(CSBBLAS1OBJS) SBBLAS2OBJS += $(CSBBLAS2OBJS) SBBLAS3OBJS += $(CSBBLAS3OBJS) +BBLAS3OBJ += $(CBBLAS3OBJS) DBLAS1OBJS += $(CDBLAS1OBJS) DBLAS2OBJS += $(CDBLAS2OBJS) DBLAS3OBJS += $(CDBLAS3OBJS) @@ -403,6 +412,7 @@ SBEXTOBJS += $(CSBEXTOBJS) CBAUXOBJS += $(CXERBLAOBJ) endif +BBLASOBJS = $(BBLAS3OBJ) SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS) SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS) DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS) @@ -550,7 +560,7 @@ level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $ level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ -level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) +level3 : $(BBLAS3OBJ) $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS) $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^ aux : $(CBAUXOBJS) @@ -1309,6 +1319,11 @@ sbgemmtr.$(SUFFIX) sbgemmtr.$(PSUFFIX) : sbgemmt.c ../param.h $(CC) -c $(CFLAGS) -DRNAME $< -o $(@F) endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +bgemm.$(SUFFIX) : gemm.c ../param.h + $(CC) -c $(CFLAGS) $< -o $(@F) +endif + sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -c $(CFLAGS) $< -o $(@F) @@ -1968,6 +1983,11 @@ cblas_sbgemm.$(SUFFIX) cblas_sbgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) endif +ifeq ($(BUILD_BFLOAT16_ONLY),1) +cblas_bgemm.$(SUFFIX) cblas_bgemm.$(PSUFFIX) : gemm.c ../param.h + $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) +endif + cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F) diff --git a/interface/gemm.c b/interface/gemm.c index 54e5604fd3..d08304b575 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -56,6 +56,9 @@ #elif defined(BFLOAT16) #define ERROR_NAME "SBGEMM " #define GEMV BLASFUNC(sbgemv) +#elif defined(BFLOAT16_ONLY) +#define ERROR_NAME "BGEMM " +#undef GEMM_GEMV_FORWARD #else #define ERROR_NAME "SGEMM " #define GEMV BLASFUNC(sgemv) @@ -247,6 +250,15 @@ static inline int get_gemm_optimal_nthreads(double MNK) { #ifndef CBLAS +#ifdef BFLOAT16_ONLY +void NAME(char *TRANSA, char *TRANSB, + blasint *M, blasint *N, blasint *K, + float *alpha, + IFLOAT *a, blasint *ldA, + IFLOAT *b, blasint *ldB, + float *beta, + FLOAT *c, blasint *ldC){ +#else void NAME(char *TRANSA, char *TRANSB, blasint *M, blasint *N, blasint *K, FLOAT *alpha, @@ -254,7 +266,7 @@ void NAME(char *TRANSA, char *TRANSB, IFLOAT *b, blasint *ldB, FLOAT *beta, FLOAT *c, blasint *ldC){ - +#endif blas_arg_t args; int transa, transb, nrowa, nrowb; @@ -363,11 +375,19 @@ void NAME(char *TRANSA, char *TRANSB, void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint m, blasint n, blasint k, #ifndef COMPLEX + #ifdef BFLOAT16_ONLY + float alpha, + IFLOAT *a, blasint lda, + IFLOAT *b, blasint ldb, + float beta, + FLOAT *c, blasint ldc) { + #else FLOAT alpha, IFLOAT *a, blasint lda, IFLOAT *b, blasint ldb, FLOAT beta, FLOAT *c, blasint ldc) { + #endif #else void *valpha, void *va, blasint lda, diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 2bd6b294fb..fea5b0ea9f 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -1,3 +1,30 @@ +############################################################################### +# Copyright (c) 2025, The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +############################################################################### USE_GEMM3M = 0 OS := $(shell uname) @@ -109,6 +136,25 @@ endif endif endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +ifndef BGEMMKERNEL +BGEMM_BETA = ../generic/gemm_beta.c +BGEMMKERNEL = ../generic/gemmkernel_2x2.c +BGEMMINCOPY = ../generic/gemm_ncopy_2.c +BGEMMITCOPY = ../generic/gemm_tcopy_2.c +BGEMMONCOPY = ../generic/gemm_ncopy_2.c +BGEMMOTCOPY = ../generic/gemm_tcopy_2.c +BGEMMINCOPYOBJ = bgemm_incopy$(TSUFFIX).$(SUFFIX) +BGEMMITCOPYOBJ = bgemm_itcopy$(TSUFFIX).$(SUFFIX) +BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX) +BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX) +endif +BKERNELOBJS += \ + bgemm_kernel$(TSUFFIX).$(SUFFIX) \ + $(BGEMMINCOPYOBJ) $(BGEMMITCOPYOBJ) \ + $(BGEMMONCOPYOBJ) $(BGEMMOTCOPYOBJ) +endif + ifeq ($(BUILD_BFLOAT16), 1) ifndef SBGEMMKERNEL SBGEMM_BETA = ../generic/gemm_beta.c @@ -189,6 +235,11 @@ XKERNELOBJS += \ $(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \ $(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ) + +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +BBLASOBJS += $(BKERNELOBJS) +endif + ifeq ($(BUILD_BFLOAT16),1) SBBLASOBJS += $(SBKERNELOBJS) endif @@ -199,6 +250,10 @@ CBLASOBJS += $(CKERNELOBJS) ZBLASOBJS += $(ZKERNELOBJS) XBLASOBJS += $(XKERNELOBJS) +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +BBLASOBJS += bgemm_beta$(TSUFFIX).$(SUFFIX) +endif + ifeq ($(BUILD_BFLOAT16),1) SBBLASOBJS += sbgemm_beta$(TSUFFIX).$(SUFFIX) endif @@ -624,6 +679,11 @@ XGEMMITCOPYOBJ_P = $(XGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX)) +ifeq ($(BUILD_BFLOAT16_ONLY),1) +$(KDIR)bgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMM_BETA) + $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ +endif + ifeq ($(BUILD_BFLOAT16),1) $(KDIR)sbgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_BETA) $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ @@ -651,6 +711,21 @@ ifeq ($(ARCH), E2K) USE_TRMM = 1 endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) + +$(KDIR)$(BGEMMONCOPYOBJ) : $(KERNELDIR)/$(BGEMMONCOPY) + $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(BGEMMOTCOPYOBJ) : $(KERNELDIR)/$(BGEMMOTCOPY) + $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(BGEMMINCOPYOBJ) : $(KERNELDIR)/$(BGEMMINCOPY) + $(CC) $(CFLAGS) -c -DDBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ + +$(KDIR)$(BGEMMITCOPYOBJ) : $(KERNELDIR)/$(BGEMMITCOPY) + $(CC) $(CFLAGS) -c -DDBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ + +endif ifeq ($(BUILD_BFLOAT16), 1) @@ -660,7 +735,7 @@ $(KDIR)$(SBGEMMONCOPYOBJ) : $(KERNELDIR)/$(SBGEMMONCOPY) $(KDIR)$(SBGEMMOTCOPYOBJ) : $(KERNELDIR)/$(SBGEMMOTCOPY) $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ -ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N)) +#ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N)) $(KDIR)$(SBGEMMINCOPYOBJ) : $(KERNELDIR)/$(SBGEMMINCOPY) $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ @@ -668,7 +743,7 @@ $(KDIR)$(SBGEMMINCOPYOBJ) : $(KERNELDIR)/$(SBGEMMINCOPY) $(KDIR)$(SBGEMMITCOPYOBJ) : $(KERNELDIR)/$(SBGEMMITCOPY) $(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@ -endif +#endif endif $(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY) @@ -847,6 +922,11 @@ endif endif endif +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +$(KDIR)bgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BGEMMKERNEL) + $(CC) $(CFLAGS) -c -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX $< -o $@ +endif + ifeq ($(BUILD_BFLOAT16), 1) $(KDIR)sbgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMMKERNEL) $(SBGEMMDEPEND) diff --git a/kernel/arm64/KERNEL.NEOVERSEV1 b/kernel/arm64/KERNEL.NEOVERSEV1 index 3e622bcbfb..3388247653 100644 --- a/kernel/arm64/KERNEL.NEOVERSEV1 +++ b/kernel/arm64/KERNEL.NEOVERSEV1 @@ -21,4 +21,18 @@ SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX) SBGEMVNKERNEL = sbgemv_n_neon.c SBGEMVTKERNEL = sbgemv_t_bfdot.c +endif + +ifeq ($(BUILD_BFLOAT16_ONLY), 1) +BGEMM_BETA = bgemm_beta_neon.c +BGEMMKERNEL = bgemm_kernel_$(BGEMM_UNROLL_M)x$(BGEMM_UNROLL_N)_neoversev1.c +BGEMMINCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_M)_neoversev1.c +BGEMMITCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_M)_neoversev1.c +BGEMMINCOPYOBJ = bgemm_incopy$(TSUFFIX).$(SUFFIX) +BGEMMITCOPYOBJ = bgemm_itcopy$(TSUFFIX).$(SUFFIX) +BGEMMONCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_N)_neoversev1.c +BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversev1.c +BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX) +BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX) + endif \ No newline at end of file diff --git a/kernel/arm64/bgemm_beta_neon.c b/kernel/arm64/bgemm_beta_neon.c new file mode 100644 index 0000000000..3b2b704191 --- /dev/null +++ b/kernel/arm64/bgemm_beta_neon.c @@ -0,0 +1,103 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include "common.h" + +#include + +int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, float beta, IFLOAT *dummy2, + BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, FLOAT *c, + BLASLONG ldc) { + BLASLONG i, j; + BLASLONG chunk, remain; + + bfloat16_t *ptr_c, *ptr_c0; + + bfloat16x8_t x0, z0; + float32x4_t y0, y1; + + float x, z; + + bfloat16_t zero_bf16 = vcvth_bf16_f32(0.0f); + bfloat16x8_t zeros = vdupq_n_bf16(zero_bf16); + + float32x4_t beta_neon = vdupq_n_f32(beta); + + ptr_c = (bfloat16_t *)c; + + chunk = m >> 3; + remain = m & 7; + + if (beta == 0.0f){ + for (j = 0; j < n; j ++){ + ptr_c0 = ptr_c; + ptr_c += ldc; + + for (i = 0; i < chunk; i ++){ + vst1q_bf16(ptr_c0, zeros); + ptr_c0 += 8; + } + + for (i = 0; i < remain; i ++){ + ptr_c0[0] = zero_bf16; + ptr_c0 ++; + } + } + } else { + for (j = 0; j < n; j ++){ + ptr_c0 = ptr_c; + ptr_c += ldc; + + for (i = 0; i < chunk; i ++){ + x0 = vld1q_bf16(ptr_c0); + + y0 = vcvtq_low_f32_bf16(x0); + y1 = vcvtq_high_f32_bf16(x0); + + y0 = vmulq_f32(y0, beta_neon); + y1 = vmulq_f32(y1, beta_neon); + + z0 = vcvtq_low_bf16_f32(y0); + z0 = vcvtq_high_bf16_f32(z0, y1); + + vst1q_bf16(ptr_c0, z0); + + ptr_c0 += 8; + } + + for (i = 0; i < remain; i ++){ + x = vcvtah_f32_bf16(ptr_c0[0]); + z = vcvth_bf16_f32(x * beta); + + ptr_c0[0] = z; + ptr_c0 ++; + } + } + } + return 0; +}; diff --git a/kernel/arm64/bgemm_kernel.c b/kernel/arm64/bgemm_kernel.c new file mode 100644 index 0000000000..223cdc39c4 --- /dev/null +++ b/kernel/arm64/bgemm_kernel.c @@ -0,0 +1,37 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include +#include "common.h" + +int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 alpha, IFLOAT *A, IFLOAT *B, + FLOAT *C, BLASLONG ldc) { + printf("running bgemm_kernel...\n"); + return 0; +} + diff --git a/kernel/arm64/bgemm_kernel_4x4_neoversev1.c b/kernel/arm64/bgemm_kernel_4x4_neoversev1.c new file mode 100644 index 0000000000..413342851d --- /dev/null +++ b/kernel/arm64/bgemm_kernel_4x4_neoversev1.c @@ -0,0 +1,46 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include + +#include "common.h" + +#define ALPHA_ONE +#include "bgemm_kernel_4x4_neoversev1_impl.c" +#undef ALPHA_ONE +#include "bgemm_kernel_4x4_neoversev1_impl.c" + +int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, IFLOAT *A, IFLOAT *B, + FLOAT *C, BLASLONG ldc) { + if (alpha == 1.0f) + return bgemm_kernel_neoversev1_alpha_one(m, n, k, alpha, A, B, C, ldc); + else + return bgemm_kernel_neoversev1_alpha(m, n, k, alpha, A, B, C, ldc); + return 0; +} + diff --git a/kernel/arm64/bgemm_kernel_4x4_neoversev1_impl.c b/kernel/arm64/bgemm_kernel_4x4_neoversev1_impl.c new file mode 100644 index 0000000000..2eb754c004 --- /dev/null +++ b/kernel/arm64/bgemm_kernel_4x4_neoversev1_impl.c @@ -0,0 +1,431 @@ +/*************************************************************************** + * Copyright (c) 2025, The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include + +#include "common.h" + +#define INIT_C(M, N) mc##M##N = svdup_f32(0); + +#define MATMUL(M, N) mc##M##N = svbfmmla(mc##M##N, ma##M, mb##N); + +#define INIT_C_4x4 \ + do { \ + INIT_C(0, 0); \ + INIT_C(0, 1); \ + INIT_C(1, 0); \ + INIT_C(1, 1); \ + } while (0); + +#ifdef ALPHA_ONE +#define UPDATE_C(PG16, PG32, PTR, TMP32, TMP16, SRC32) \ + do { \ + TMP16 = svld1_bf16((PG16), (PTR)); \ + TMP16 = svzip1_bf16(BF16_ZEROS, TMP16); \ + TMP32 = svreinterpret_f32(TMP16); \ + TMP32 = svadd_z((PG32), SRC32, TMP32); \ + TMP16 = svcvt_bf16_f32_z((PG32), TMP32); \ + TMP16 = svuzp1_bf16(TMP16, TMP16); \ + svst1_bf16((PG16), (PTR), TMP16); \ + } while (0) +#else +#define UPDATE_C(PG16, PG32, PTR, TMP32, TMP16, SRC32) \ + do { \ + TMP16 = svld1_bf16((PG16), (PTR)); \ + TMP16 = svzip1_bf16(BF16_ZEROS, TMP16); \ + TMP32 = svreinterpret_f32(TMP16); \ + TMP32 = svmad_z((PG32), svalpha, SRC32, TMP32); \ + TMP16 = svcvt_bf16_f32_z((PG32), TMP32); \ + TMP16 = svuzp1_bf16(TMP16, TMP16); \ + svst1_bf16((PG16), (PTR), TMP16); \ + } while (0) +#endif + +#define ZIP_EVEN_ELEMENTS(PG, mc0, mc1, tmp, vc) \ + do { \ + (tmp) = svuzp1_f32((mc0), (mc1)); \ + (vc) = svcompact_f32((PG), (tmp)); \ + } while (0) + +#define ZIP_ODD_ELEMENTS(PG, mc0, mc1, tmp, vc) \ + do { \ + (tmp) = svuzp2_f32((mc0), (mc1)); \ + (vc) = svcompact_f32((PG), (tmp)); \ + } while (0) + +#define ACCUMULATE_LAST4_TO_FIRST4(M, N, TMP) \ + do { \ + TMP = svext_f32(mc##M##N, mc##M##N, 4); \ + mc##M##N = svadd_f32_z(svptrue_b32(), mc##M##N, (TMP)); \ + } while (0) + +#ifdef ALPHA_ONE +int bgemm_kernel_neoversev1_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, + float alpha, IFLOAT *A, IFLOAT *B, + FLOAT *C, BLASLONG ldc) +#else +int bgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k, + float alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, + BLASLONG ldc) +#endif +{ + BLASLONG pad_k = (k + 7) & ~7; + svbfloat16_t ma0, ma1, mb0, mb1; + svfloat32_t mc00, mc01, mc10, mc11, vc0, vc1, vc2, vc3; + svfloat32_t tmp; + svfloat32_t svalpha = svdup_f32(alpha); + + svbool_t pg16_all = svptrue_b16(); + + svbool_t pg32_first_1 = svwhilelt_b32(0, 1); + svbool_t pg32_first_2 = svwhilelt_b32(0, 2); + svbool_t pg32_first_4 = svwhilelt_b32(0, 4); + + svbool_t pg16_first_1 = svwhilelt_b16(0, 1); + svbool_t pg16_first_2 = svwhilelt_b16(0, 2); + svbool_t pg16_first_4 = svwhilelt_b16(0, 4); + + svbool_t pg32_select_first_2_per_quadword = svdupq_b32(1, 1, 0, 0); + + bfloat16_t *ptr_a = (bfloat16_t *)A; + bfloat16_t *ptr_b = (bfloat16_t *)B; + bfloat16_t *ptr_c = (bfloat16_t *)C; + + bfloat16_t *ptr_a0; + bfloat16_t *ptr_b0; + bfloat16_t *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3; + + svfloat32_t tmp32; + svbfloat16_t tmp16; + + svbfloat16_t BF16_ZEROS = svdup_n_bf16(0.0); + + for (BLASLONG j = 0; j < n / 4; j++) { + ptr_c0 = ptr_c; + ptr_c1 = ptr_c0 + ldc; + ptr_c2 = ptr_c1 + ldc; + ptr_c3 = ptr_c2 + ldc; + ptr_c += 4 * ldc; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 4; i++) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + + ptr_b0 = ptr_b; + + INIT_C_4x4; + + for (BLASLONG p = 0; p < pad_k; p += 8) { + ma0 = svld1_bf16(pg16_all, ptr_a0); + ma1 = svld1_bf16(pg16_all, ptr_a0 + 16); + + mb0 = svld1_bf16(pg16_all, ptr_b0); + mb1 = svld1_bf16(pg16_all, ptr_b0 + 16); + + MATMUL(0, 0); + MATMUL(0, 1); + MATMUL(1, 0); + MATMUL(1, 1); + + ptr_a0 += 32; + ptr_b0 += 32; + } + + ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); + ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp); + ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp); + ACCUMULATE_LAST4_TO_FIRST4(1, 1, tmp); + + ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0); + ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc1); + + ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc01, mc11, tmp, vc2); + ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc01, mc11, tmp, vc3); + + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, tmp32, tmp16, vc1); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2, tmp32, tmp16, vc2); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3, tmp32, tmp16, vc3); + + ptr_c0 += 4; + ptr_c1 += 4; + ptr_c2 += 4; + ptr_c3 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + + ptr_b0 = ptr_b; + INIT_C(0, 0); + INIT_C(0, 1); + for (BLASLONG p = 0; p < pad_k; p += 8) { + ma0 = svld1_bf16(pg16_all, ptr_a0); + mb0 = svld1_bf16(pg16_all, ptr_b0); + mb1 = svld1_bf16(pg16_all, ptr_b0 + 16); + + MATMUL(0, 0); + MATMUL(0, 1); + + ptr_a0 += 16; + ptr_b0 += 32; + } + + ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); + ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp); + + vc0 = svuzp1(mc00, mc00); + vc1 = svuzp2(mc00, mc00); + vc2 = svuzp1(mc01, mc01); + vc3 = svuzp2(mc01, mc01); + + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, tmp32, tmp16, vc1); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c2, tmp32, tmp16, vc2); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c3, tmp32, tmp16, vc3); + + ptr_c0 += 2; + ptr_c1 += 2; + ptr_c2 += 2; + ptr_c3 += 2; + } + + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + + INIT_C(0, 0); + INIT_C(0, 1); + for (BLASLONG p = 0; p < pad_k; p += 8) { + ma0 = svld1_bf16(pg16_all, ptr_a0); + mb0 = svld1_bf16(pg16_all, ptr_b0); + mb1 = svld1_bf16(pg16_all, ptr_b0 + 16); + + MATMUL(0, 0); + MATMUL(0, 1); + + ptr_a0 += 16; + ptr_b0 += 32; + } + + ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); + ACCUMULATE_LAST4_TO_FIRST4(0, 1, tmp); + + // use compact is more straightforward + vc1 = svuzp2(mc00, mc00); + vc3 = svuzp2(mc01, mc01); + + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, tmp32, tmp16, vc1); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c2, tmp32, tmp16, mc01); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c3, tmp32, tmp16, vc3); + } + + ptr_b += 4 * pad_k; + } + + if (n & 2) { + ptr_c0 = ptr_c; + ptr_c1 = ptr_c0 + ldc; + ptr_c += 2 * ldc; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 4; i++) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + + ptr_b0 = ptr_b; + + INIT_C(0, 0); + INIT_C(1, 0); + + for (BLASLONG p = 0; p < pad_k; p += 8) { + ma0 = svld1_bf16(pg16_all, ptr_a0); + ma1 = svld1_bf16(pg16_all, ptr_a0 + 16); + + mb0 = svld1_bf16(pg16_all, ptr_b0); + + MATMUL(0, 0); + MATMUL(1, 0); + + ptr_a0 += 32; + ptr_b0 += 16; + } + + ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); + ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp); + + ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0); + ZIP_ODD_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc2); + + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, tmp32, tmp16, vc2); + + ptr_c0 += 4; + ptr_c1 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); + + for (BLASLONG p = 0; p < pad_k; p += 8) { + ma0 = svld1_bf16(pg16_all, ptr_a0); + mb0 = svld1_bf16(pg16_all, ptr_b0); + + MATMUL(0, 0); + + ptr_a0 += 16; + ptr_b0 += 16; + } + + ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); + vc0 = svuzp1(mc00, mc00); + vc1 = svuzp2(mc00, mc00); + + + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, tmp32, tmp16, vc1); + + ptr_c0 += 2; + ptr_c1 += 2; + } + + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + INIT_C(0, 0); + for (BLASLONG p = 0; p < pad_k; p += 8) { + ma0 = svld1_bf16(pg16_all, ptr_a0); + mb0 = svld1_bf16(pg16_all, ptr_b0); + MATMUL(0, 0); + ptr_a0 += 16; + ptr_b0 += 16; + } + + ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); + vc1 = svuzp2(mc00, mc00); + + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, tmp32, tmp16, vc1); + } + + ptr_b += 2 * pad_k; + } + + if (n & 1) { // TODO: this case seems a overhead. find out whether it's in our + // case. + ptr_c0 = ptr_c; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 4; i++) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + + ptr_b0 = ptr_b; + + INIT_C(0, 0); + INIT_C(1, 0); + + for (BLASLONG p = 0; p < pad_k; p += 8) { + ma0 = svld1_bf16(pg16_all, ptr_a0); + ma1 = svld1_bf16(pg16_all, ptr_a0 + 16); + + mb0 = svld1_bf16(pg16_all, ptr_b0); + + MATMUL(0, 0); + MATMUL(1, 0); + + ptr_a0 += 32; + ptr_b0 += 16; + } + + ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); + ACCUMULATE_LAST4_TO_FIRST4(1, 0, tmp); + + ZIP_EVEN_ELEMENTS(pg32_select_first_2_per_quadword, mc00, mc10, tmp, vc0); + + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, tmp32, tmp16, vc0); + + ptr_c0 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); + + for (BLASLONG p = 0; p < pad_k; p += 8) { + ma0 = svld1_bf16(pg16_all, ptr_a0); + mb0 = svld1_bf16(pg16_all, ptr_b0); + + MATMUL(0, 0); + + ptr_a0 += 16; + ptr_b0 += 16; + } + + ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); + + vc0 = svuzp1(mc00, mc00); + + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, tmp32, tmp16, vc0); + + ptr_c0 += 2; + } + + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + + INIT_C(0, 0); + for (BLASLONG p = 0; p < pad_k; p += 8) { + + ma0 = svld1_bf16(pg16_all, ptr_a0); + mb0 = svld1_bf16(pg16_all, ptr_b0); + + MATMUL(0, 0); + ptr_a0 += 16; + ptr_b0 += 16; + } + + ACCUMULATE_LAST4_TO_FIRST4(0, 0, tmp); + + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, tmp32, tmp16, mc00); + } + } + + return 0; +} diff --git a/kernel/arm64/sbgemm_kernel_4x4_neoversev1.c b/kernel/arm64/sbgemm_kernel_4x4_neoversev1.c index 889b5fc5b8..772e45da98 100644 --- a/kernel/arm64/sbgemm_kernel_4x4_neoversev1.c +++ b/kernel/arm64/sbgemm_kernel_4x4_neoversev1.c @@ -35,7 +35,7 @@ #undef ALPHA_ONE #include "sbgemm_kernel_4x4_neoversev1_impl.c" -int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, +int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) { if (alpha == 1.0f) return sbgemm_kernel_neoversev1_alpha_one(m, n, k, alpha, A, B, C, ldc); diff --git a/kernel/arm64/sbgemm_kernel_4x4_neoversev1_impl.c b/kernel/arm64/sbgemm_kernel_4x4_neoversev1_impl.c index b6d9e9816c..02b101f112 100644 --- a/kernel/arm64/sbgemm_kernel_4x4_neoversev1_impl.c +++ b/kernel/arm64/sbgemm_kernel_4x4_neoversev1_impl.c @@ -78,11 +78,11 @@ #ifdef ALPHA_ONE int sbgemm_kernel_neoversev1_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, - FLOAT alpha, IFLOAT *A, IFLOAT *B, + float alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) #else int sbgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k, - FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, + float alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc) #endif { diff --git a/kernel/generic/gemm_beta.c b/kernel/generic/gemm_beta.c index ccb772cc7d..36522ad222 100644 --- a/kernel/generic/gemm_beta.c +++ b/kernel/generic/gemm_beta.c @@ -1,5 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ +/* Copyright 2025 The OpenBLAS Project. */ /* All rights reserved. */ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -38,10 +39,47 @@ #include "common.h" +#if (defined(BFLOAT16) || defined(BFLOAT16_ONLY)) && defined(BFLOAT16CONVERSION) +static float +bfloat16tof32 (bfloat16 f16) +{ + float result = 0; + unsigned short* q = (unsigned short*)(&result); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + q[0] = f16; +#else + q[1] = f16; +#endif + return result; +} + +static bfloat16 +f32tobfloat16(float f32) +{ + unsigned short* q = (unsigned short*)(&f32); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + return q[0]; +#else + return q[1]; +#endif +} + +#define BF16TOF32(x) (bfloat16tof32(x)) +#define F32TOBF16(x) (f32tobfloat16(x)) +#else +#define BF16TOF32(x) x +#define F32TOBF16(x) x +#endif + +#if defined(BFLOAT16_ONLY) +int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, float beta, + IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, + FLOAT *c, BLASLONG ldc){ +#else int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, FLOAT *c, BLASLONG ldc){ - +#endif BLASLONG i, j; BLASLONG chunk, remain; @@ -49,23 +87,24 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, c_offset = c; chunk = m >> 3; remain = m & 7; + if (beta == ZERO){ for(j=n; j>0; j--){ c_offset1 = c_offset; c_offset += ldc; for(i=chunk; i>0; i--){ - *(c_offset1 + 0) = ZERO; - *(c_offset1 + 1) = ZERO; - *(c_offset1 + 2) = ZERO; - *(c_offset1 + 3) = ZERO; - *(c_offset1 + 4) = ZERO; - *(c_offset1 + 5) = ZERO; - *(c_offset1 + 6) = ZERO; - *(c_offset1 + 7) = ZERO; + *(c_offset1 + 0) = F32TOBF16(ZERO); + *(c_offset1 + 1) = F32TOBF16(ZERO); + *(c_offset1 + 2) = F32TOBF16(ZERO); + *(c_offset1 + 3) = F32TOBF16(ZERO); + *(c_offset1 + 4) = F32TOBF16(ZERO); + *(c_offset1 + 5) = F32TOBF16(ZERO); + *(c_offset1 + 6) = F32TOBF16(ZERO); + *(c_offset1 + 7) = F32TOBF16(ZERO); c_offset1 += 8; } for(i=remain; i>0; i--){ - *c_offset1 = ZERO; + *c_offset1 = F32TOBF16(ZERO); c_offset1 ++; } } @@ -74,18 +113,18 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, c_offset1 = c_offset; c_offset += ldc; for(i=chunk; i>0; i--){ - *(c_offset1 + 0) *= beta; - *(c_offset1 + 1) *= beta; - *(c_offset1 + 2) *= beta; - *(c_offset1 + 3) *= beta; - *(c_offset1 + 4) *= beta; - *(c_offset1 + 5) *= beta; - *(c_offset1 + 6) *= beta; - *(c_offset1 + 7) *= beta; + *(c_offset1 + 0) = F32TOBF16(beta * BF16TOF32(c_offset1[0])); + *(c_offset1 + 1) = F32TOBF16(beta * BF16TOF32(c_offset1[1])); + *(c_offset1 + 2) = F32TOBF16(beta * BF16TOF32(c_offset1[2])); + *(c_offset1 + 3) = F32TOBF16(beta * BF16TOF32(c_offset1[3])); + *(c_offset1 + 4) = F32TOBF16(beta * BF16TOF32(c_offset1[4])); + *(c_offset1 + 5) = F32TOBF16(beta * BF16TOF32(c_offset1[5])); + *(c_offset1 + 6) = F32TOBF16(beta * BF16TOF32(c_offset1[6])); + *(c_offset1 + 7) = F32TOBF16(beta * BF16TOF32(c_offset1[7])); c_offset1 += 8; } for(i=remain; i>0; i--){ - *c_offset1 *= beta; + *c_offset1 = F32TOBF16(beta * BF16TOF32(c_offset1[0])); c_offset1 ++; } } diff --git a/kernel/generic/gemmkernel_2x2.c b/kernel/generic/gemmkernel_2x2.c index bf1c3ae381..3cf2d928fd 100644 --- a/kernel/generic/gemmkernel_2x2.c +++ b/kernel/generic/gemmkernel_2x2.c @@ -1,5 +1,32 @@ +/*************************************************************************** +Copyright (c) 2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + #include "common.h" -#if defined(BFLOAT16) && defined(BFLOAT16CONVERSION) +#if (defined(BFLOAT16) || defined(BFLOAT16_ONLY))&& defined(BFLOAT16CONVERSION) static float bfloat16tof32 (bfloat16 f16) { @@ -12,12 +39,29 @@ bfloat16tof32 (bfloat16 f16) #endif return result; } + +static bfloat16 f32tobfloat16(float f32) { + unsigned short *q = (unsigned short *)(&f32); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + return q[0]; +#else + return q[1]; +#endif +} + #define BF16TOF32(x) (bfloat16tof32(x)) +#define F32TOBF16(x) (f32tobfloat16(x)) #else #define BF16TOF32(x) x +#define F32TOBF16(x) x #endif + +#ifdef BFLOAT16_ONLY +int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk, float alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc +#else int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc -#ifdef TRMMKERNEL +#endif + #ifdef TRMMKERNEL ,BLASLONG offset #endif ) @@ -90,13 +134,17 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+2; } res0 = res0*alpha; - C0[0] = C0[0]+res0; + C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); res1 = res1*alpha; - C0[1] = C0[1]+res1; + C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[1])); res2 = res2*alpha; - C1[0] = C1[0]+res2; + C1[0] = F32TOBF16(BF16TOF32(C1[0])+res2); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[0])); res3 = res3*alpha; - C1[1] = C1[1]+res3; + C1[1] = F32TOBF16(BF16TOF32(C1[1])+res3); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[1])); C0 = C0+2; C1 = C1+2; } @@ -116,9 +164,11 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+2; } res0 = res0*alpha; - C0[0] = C0[0]+res0; + C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); res1 = res1*alpha; - C1[0] = C1[0]+res1; + C1[0] = F32TOBF16(BF16TOF32(C1[1])+res1); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C1[0])); C0 = C0+1; C1 = C1+1; } @@ -147,9 +197,11 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+1; } res0 = res0*alpha; - C0[0] = C0[0]+res0; + C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); res1 = res1*alpha; - C0[1] = C0[1]+res1; + C0[1] = F32TOBF16(BF16TOF32(C0[1])+res1); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[1])); C0 = C0+2; } for (i=0; i<(bm&1); i+=1) @@ -165,7 +217,8 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb, ptrbb = ptrbb+1; } res0 = res0*alpha; - C0[0] = C0[0]+res0; + C0[0] = F32TOBF16(BF16TOF32(C0[0])+res0); + //printf("i = %d, j = %d, r = %.2f\n", i, j , BF16TOF32(C0[0])); C0 = C0+1; } k = (bk<<0); diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 5a5045ce23..ea05e8dc9f 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -1,6 +1,6 @@ /*********************************************************************/ /* Copyright 2009, 2010 The University of Texas at Austin. */ -/* Copyright 2023 The OpenBLAS Project. */ +/* Copyright 2023, 2025 The OpenBLAS Project. */ /* All rights reserved. */ /* */ /* Redistribution and use in source and binary forms, with or */ @@ -56,6 +56,20 @@ gotoblas_t TABLE_NAME = { GEMM_DEFAULT_OFFSET_A, GEMM_DEFAULT_OFFSET_B, GEMM_DEFAULT_ALIGN, +#ifdef BUILD_BFLOAT16_ONLY + 0, 0, 0, + BGEMM_DEFAULT_UNROLL_M, BGEMM_DEFAULT_UNROLL_N, +#ifdef BGEMM_DEFAULT_UNROLL_MN + BGEMM_DEFAULT_UNROLL_MN, +#else + MAX(BGEMM_DEFAULT_UNROLL_M, BGEMM_DEFAULT_UNROLL_N), +#endif + BGEMM_ALIGN_K, + bgemm_kernelTS, bgemm_betaTS, + bgemm_incopyTS, bgemm_itcopyTS, + bgemm_oncopyTS, bgemm_otcopyTS, +#endif + #ifdef BUILD_BFLOAT16 0, 0, 0, SBGEMM_DEFAULT_UNROLL_M, SBGEMM_DEFAULT_UNROLL_N, diff --git a/param.h b/param.h index 48b64fd2ae..fe2239c774 100644 --- a/param.h +++ b/param.h @@ -1,5 +1,5 @@ /***************************************************************************** -Copyright (c) 2011-2023, The OpenBLAS Project +Copyright (c) 2011-2023, 2025 The OpenBLAS Project All rights reserved. Redistribution and use in source and binary forms, with or without @@ -72,10 +72,17 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #ifndef PARAM_H #define PARAM_H +#define BGEMM_DEFAULT_UNROLL_N 4 +#define BGEMM_DEFAULT_UNROLL_M 8 +#define BGEMM_DEFAULT_UNROLL_MN 32 +#define BGEMM_DEFAULT_P 256 +#define BGEMM_DEFAULT_R 256 +#define BGEMM_DEFAULT_Q 256 +#define BGEMM_ALIGN_K 1 // must be 2^x #define SBGEMM_DEFAULT_UNROLL_N 4 -#define SBGEMM_DEFAULT_UNROLL_M 8 -#define SBGEMM_DEFAULT_UNROLL_MN 32 +#define SBGEMM_DEFAULT_UNROLL_M 4 +#define SBGEMM_DEFAULT_UNROLL_MN 4 #define SBGEMM_DEFAULT_P 256 #define SBGEMM_DEFAULT_R 256 #define SBGEMM_DEFAULT_Q 256 @@ -3556,6 +3563,13 @@ is a big desktop or server with abundant cache rather than a phone or embedded d #define GEMM_PREFERED_SIZE 8 #endif +#undef BGEMM_ALIGN_K +#undef BGEMM_DEFAULT_UNROLL_M +#undef BGEMM_DEFAULT_UNROLL_N +#define BGEMM_ALIGN_K 8 +#define BGEMM_DEFAULT_UNROLL_M 4 +#define BGEMM_DEFAULT_UNROLL_N 4 + #undef SBGEMM_ALIGN_K #undef SBGEMM_DEFAULT_UNROLL_M #undef SBGEMM_DEFAULT_UNROLL_N diff --git a/test/Makefile b/test/Makefile index 9ba88988b2..e9a77dc056 100644 --- a/test/Makefile +++ b/test/Makefile @@ -1,3 +1,31 @@ +############################################################################### +# Copyright (c) 2025, The OpenBLAS Project +# All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# 3. Neither the name of the OpenBLAS project nor the names of +# its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +############################################################################### TOPDIR = .. include ../Makefile.system ifeq ($(F_COMPILER),GFORTRAN) @@ -164,6 +192,9 @@ endif endif endif +ifeq ($(BUILD_BFLOAT16_ONLY),1) +BF3= test_bgemm +endif ifeq ($(BUILD_BFLOAT16),1) B3= test_sbgemm endif @@ -192,11 +223,15 @@ endif ifeq ($(SUPPORT_GEMM3M),1) level3: $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m else -level3: $(B3) $(S3) $(D3) $(C3) $(Z3) +level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) endif ifneq ($(CROSS), 1) rm -f ?BLAT3.SUMM +ifeq ($(BUILD_BFLOAT16_ONLY),1) + OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_bgemm > BBLAT3.SUMM + @$(GREP) -q FATAL SBBLAT3.SUMM && cat BBLAT3.SUMM || exit 0 +endif ifeq ($(BUILD_BFLOAT16),1) OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemm > SBBLAT3.SUMM @$(GREP) -q FATAL SBBLAT3.SUMM && cat SBBLAT3.SUMM || exit 0 @@ -366,6 +401,11 @@ zblat3 : zblat3.$(SUFFIX) ../$(LIBNAME) $(FC) $(FLDFLAGS) -o zblat3 zblat3.$(SUFFIX) ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) endif +ifeq ($(BUILD_BFLOAT16_ONLY),1) +test_bgemm : compare_sgemm_bgemm.c ../$(LIBNAME) + $(CC) $(CLDFLAGS) -o test_bgemm compare_sgemm_bgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) +endif + ifeq ($(BUILD_BFLOAT16),1) test_sbgemm : compare_sgemm_sbgemm.c ../$(LIBNAME) $(CC) $(CLDFLAGS) -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) @@ -387,7 +427,7 @@ clean: @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ sblat1 dblat1 cblat1 zblat1 \ sblat2 dblat2 cblat2 zblat2 \ - test_sbgemm sblat3 dblat3 cblat3 zblat3 \ + test_bgemm test_sbgemm sblat3 dblat3 cblat3 zblat3 \ sblat1p dblat1p cblat1p zblat1p \ sblat2p dblat2p cblat2p zblat2p \ sblat3p dblat3p cblat3p zblat3p \ diff --git a/test/compare_sgemm_bgemm.c b/test/compare_sgemm_bgemm.c new file mode 100644 index 0000000000..711d389f0b --- /dev/null +++ b/test/compare_sgemm_bgemm.c @@ -0,0 +1,187 @@ +/*************************************************************************** +Copyright (c) 2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ +#include "../common.h" +#include +#include + +#include + +#define SGEMM BLASFUNC(sgemm) +#define BGEMM BLASFUNC(bgemm) +#define BGEMM_LARGEST 256 + +void *malloc_safe(size_t size) { + if (size == 0) + return malloc(1); + else + return malloc(size); +} + +bfloat16 convert_to_bf16(float x) { + bfloat16_t src = x; + bfloat16 dst = 0; + memcpy(&dst, &src, sizeof(src)); + return dst; +} + +int main(int argc, char *argv[]) { + blasint m, n, k; + int i, j, l; + blasint x, y; + int ret = 0; + int loop = BGEMM_LARGEST; + char transA = 'N', transB = 'N'; + + float alpha = 1.0, beta = 0.0; + + for (x = 1; x <= loop; x++) { + if ((x > 100) && (x != BGEMM_LARGEST)) + continue; + m = k = n = x; + + float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); + float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); + float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); + + bfloat16_t *AA = (bfloat16_t *)malloc_safe(m * k * sizeof(bfloat16)); + bfloat16_t *BB = (bfloat16_t *)malloc_safe(k * n * sizeof(bfloat16)); + bfloat16_t *CC = (bfloat16_t *)malloc_safe(m * n * sizeof(bfloat16)); + + if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || + (BB == NULL) || (CC == NULL)) + return 1; + + for (int i = 0; i < m; i++) { + for (int j = 0; j < k; j++) { + A[i * k + j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + AA[i * k + j] = A[i * k + j] ; + } + } + + for (int i = 0; i < n; i++) { + for (int j = 0; j < k; j++) { + // BB[i * k + j] = (i * k + j + 1) % 100; + B[i * k + j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + BB[i * k + j] = B[i * k + j] ; + } + } + + for (y = 0; y < 1; y++) { + if ((y == 0) || (y == 2)) { + transA = 'N'; + } else { + transA = 'T'; + } + if ((y == 0) || (y == 1)) { + transB = 'N'; + } else { + transB = 'T'; + } + // printf("******** x = %d, y = %d********\n", x, y); + // printf("Matrix AA (m x k):\n"); + // for (int i = 0; i < m; i++) { + // for (int j = 0; j < k; j++) { + // printf("%.2f ", (float)AA[i * k + j]); // or %4.1f if float + // } + // printf("\n"); + // } + + // printf("Matrix A (copy of AA):\n"); + // for (int i = 0; i < m; i++) { + // for (int j = 0; j < k; j++) { + // printf("%.2f ", A[i * k + j]); + // } + // printf("\n"); + // } + + // printf("Matrix BB (n x k):\n"); + // for (int i = 0; i < n; i++) { + // for (int j = 0; j < k; j++) { + // printf("%.2f ", (float)BB[i * k + j]); + // } + // printf("\n"); + // } + + // printf("Matrix B (copy of BB):\n"); + // for (int i = 0; i < n; i++) { + // for (int j = 0; j < k; j++) { + // printf("%.2f ", B[i * k + j]); + // } + // printf("\n"); + // } + + memset(C, 0, m * n * sizeof(FLOAT)); + memset(CC, 0, m * n * sizeof(bfloat16)); + SGEMM(&transA, &transB, &m, &n, &k, &alpha, A, &m, B, &k, &beta, + C, &m); + BGEMM(&transA, &transB, &m, &n, &k, &alpha, (bfloat16 *)AA, &m, + (bfloat16 *)BB, &k, &beta, (bfloat16 *)CC, &m); + + + // printf("Matrix CC (n x m):\n"); + // for (int i = 0; i < n; i++) { + // for (int j = 0; j < m; j++) { + // printf("%.2f ", (float)CC[i * m + j]); + // } + // printf("\n"); + // } + + // printf("Matrix C :\n"); + // for (int i = 0; i < n; i++) { + // for (int j = 0; j < k; j++) { + // printf("%.2f ", C[i * k + j]); + // } + // printf("\n"); + // } + + for (i = 0; i < n; i++) { + for (j = 0; j < m; j++) { + if (fabs((float)CC[i * m + j] - C[i * m + j]) > 1.0) { + ret ++; + } + } + } + + printf("x = %d, err = %d\n", x, ret); + ret = 0; + } + + free(A); + free(B); + free(C); + free(AA); + free(BB); + free(CC); + } + + if (ret != 0) { + fprintf(stderr, "FATAL ERROR BGEMM - Return code: %d\n", ret); + return ret; + } + + return 0; +} \ No newline at end of file