diff --git a/CMakeLists.txt b/CMakeLists.txt index fe8acc803..7c069e420 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,7 +58,8 @@ else() set(GGML_BLAS_VENDOR_DEFAULT "Generic") endif() -if (CMAKE_CROSSCOMPILING) +if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH}) + message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF") set(GGML_NATIVE_DEFAULT OFF) else() set(GGML_NATIVE_DEFAULT ON) @@ -153,6 +154,8 @@ option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashA option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT}) option(GGML_HIP "ggml: use HIP" OFF) +option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) @@ -185,6 +188,9 @@ option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increas option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON) +# toolchain for vulkan-shaders-gen +set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen") + # extra artifacts option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE}) option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE}) @@ -261,3 +267,74 @@ if (GGML_STANDALONE) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc DESTINATION share/pkgconfig) endif() + +# +# Create CMake package +# + +# Generate version info based on git commit. + +find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH) +execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_NUMBER + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +if(GGML_BUILD_NUMBER EQUAL 1) + message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.") +endif() + +execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_COMMIT + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +# Capture variables prefixed with GGML_. + +set(variable_set_statements +" +####### Expanded from @GGML_VARIABLES_EXPANED@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run ####### + +") + +set(GGML_SHARED_LIB ${BUILD_SHARED_LIBS}) + +get_cmake_property(all_variables VARIABLES) +foreach(variable_name IN LISTS all_variables) + if(variable_name MATCHES "^GGML_") + string(REPLACE ";" "\\;" + variable_value "${${variable_name}}") + + set(variable_set_statements + "${variable_set_statements}set(${variable_name} \"${variable_value}\")\n") + endif() +endforeach() + +set(GGML_VARIABLES_EXPANDED ${variable_set_statements}) + +# Create the CMake package and set install location. + +set(GGML_INSTALL_VERSION 0.0.${GGML_BUILD_NUMBER}) +set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml + PATH_VARS GGML_INCLUDE_INSTALL_DIR + GGML_LIB_INSTALL_DIR + GGML_BIN_INSTALL_DIR) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake + VERSION ${GGML_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml) diff --git a/cmake/ggml-config.cmake.in b/cmake/ggml-config.cmake.in new file mode 100644 index 000000000..bf39f9c00 --- /dev/null +++ b/cmake/ggml-config.cmake.in @@ -0,0 +1,147 @@ + +@GGML_VARIABLES_EXPANDED@ + +@PACKAGE_INIT@ + +set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@") +set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@") +set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@") + +find_package(Threads REQUIRED) + +find_library(GGML_LIBRARY ggml + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + +add_library(ggml::ggml UNKNOWN IMPORTED) +set_target_properties(ggml::ggml + PROPERTIES + IMPORTED_LOCATION "${GGML_LIBRARY}") + +find_library(GGML_BASE_LIBRARY ggml-base + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + +add_library(ggml::ggml-base UNKNOWN IMPORTED) +set_target_properties(ggml::ggml-base + PROPERTIES + IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") + +if (NOT GGML_SHARED_LIB) + if (APPLE AND GGML_ACCELERATE) + find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK}) + endif() + + if (GGML_OPENMP) + find_package(OpenMP REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX) + endif() + + if (GGML_CPU_HBM) + find_library(memkind memkind REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind) + endif() + + if (GGML_BLAS) + find_package(BLAS REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES}) + list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS}) + endif() + + if (GGML_CUDA) + find_package(CUDAToolkit REQUIRED) + endif() + + if (GGML_METAL) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + + list(APPEND GGML_METAL_INTERFACE_LINK_LIBRARIES + ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK}) + endif() + + if (GGML_VULKAN) + find_package(Vulkan REQUIRED) + list(APPEND GGML_VULKAN_INTERFACE_LINK_LIBRARIES Vulkan::Vulkan) + endif() + + if (GGML_HIP) + find_package(hip REQUIRED) + find_package(hipblas REQUIRED) + find_package(rocblas REQUIRED) + list(APPEND GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas) + endif() + + if (GGML_SYCL) + find_package(DNNL) + if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL") + list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl) + endif() + if (WIN32) + find_package(IntelSYCL REQUIRED) + find_package(MKL REQUIRED) + list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL) + endif() + endif() +endif() + +set(_ggml_all_targets "") +foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS}) + string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}") + string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx) + + find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend} + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + + message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}") + + add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}" + INTERFACE_COMPILE_FEATURES c_std_90 + POSITION_INDEPENDENT_CODE ON) + + string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}") + if(is_cpu_variant) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base") + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}") + + if(GGML_CPU_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}") + endif() + + else() + list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base") + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}") + + if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}") + endif() + endif() + + list(APPEND _ggml_all_targets ggml::${_ggml_backend}) +endforeach() + +add_library(ggml::all INTERFACE IMPORTED) +set_target_properties(ggml::all + PROPERTIES + INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}") + +check_required_components(ggml) diff --git a/include/ggml-backend.h b/include/ggml-backend.h index 7221a0830..fc9571c82 100644 --- a/include/ggml-backend.h +++ b/include/ggml-backend.h @@ -203,6 +203,8 @@ extern "C" { // Backend registry // + GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); + // Backend (reg) enumeration GGML_API size_t ggml_backend_reg_count(void); GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index); diff --git a/include/ggml.h b/include/ggml.h index 8f8cb9e1a..1198dc1fd 100644 --- a/include/ggml.h +++ b/include/ggml.h @@ -1384,16 +1384,20 @@ extern "C" { float scale, float max_bias); - GGML_API struct ggml_tensor * ggml_soft_max_back( + GGML_API struct ggml_tensor * ggml_soft_max_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + float scale, + float max_bias); // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_soft_max_back_inplace( + GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + float scale, + float max_bias); // rotary position embedding // if (mode & 1) - skip n_past elements (NOT SUPPORTED) @@ -1500,7 +1504,7 @@ extern "C" { // rotary position embedding backward, i.e compute dx from dy // a - dy - GGML_API struct ggml_tensor * ggml_rope_back( + GGML_API struct ggml_tensor * ggml_rope_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, // gradients of ggml_rope result struct ggml_tensor * b, // positions @@ -1515,6 +1519,23 @@ extern "C" { float beta_fast, float beta_slow); + GGML_API struct ggml_tensor * ggml_rope_multi_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + // clamp // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_clamp( diff --git a/scripts/sync-llama.last b/scripts/sync-llama.last index 9fea66d40..55dae02a6 100644 --- a/scripts/sync-llama.last +++ b/scripts/sync-llama.last @@ -1 +1 @@ -504af20ee4eae72080a56d59d744f6774f7901ce +815857791d3639a4d544d0a8cf25a49b0325c08c diff --git a/scripts/sync-llama.sh b/scripts/sync-llama.sh index dba29c90b..78672142e 100755 --- a/scripts/sync-llama.sh +++ b/scripts/sync-llama.sh @@ -2,7 +2,7 @@ cp -rpv ../llama.cpp/ggml/CMakeLists.txt CMakeLists.txt cp -rpv ../llama.cpp/ggml/src/CMakeLists.txt src/CMakeLists.txt -cp -rpv ../llama.cpp/ggml/cmake/FindSIMD.cmake cmake/FindSIMD.cmake +cp -rpv ../llama.cpp/ggml/cmake/* cmake/ cp -rpv ../llama.cpp/ggml/src/ggml*.c src/ cp -rpv ../llama.cpp/ggml/src/ggml*.cpp src/ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ae1cd2337..566709135 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -250,6 +250,17 @@ function(ggml_add_backend_library backend) target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD) target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED) endif() + + if(NOT GGML_AVAILABLE_BACKENDS) + set(GGML_AVAILABLE_BACKENDS "${backend}" + CACHE INTERNAL "List of backends for cmake package") + else() + list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend) + if(has_backend EQUAL -1) + set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}" + CACHE INTERNAL "List of backends for cmake package") + endif() + endif() endfunction() function(ggml_add_backend backend) @@ -297,7 +308,7 @@ if (GGML_CPU_ALL_VARIANTS) # MSVC doesn't support AMX ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) endif() -else () +elseif (GGML_CPU) ggml_add_cpu_backend_variant_impl("") endif() diff --git a/src/ggml-alloc.c b/src/ggml-alloc.c index 8dc8226ac..9a3bf9f29 100644 --- a/src/ggml-alloc.c +++ b/src/ggml-alloc.c @@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml return true; } +// ops that return true for this function must not use restrict pointers for their backend implementations static bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { case GGML_OP_SCALE: @@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_LOG: case GGML_OP_UNARY: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: return true; default: diff --git a/src/ggml-backend-impl.h b/src/ggml-backend-impl.h index 36d72e95f..d1c2d76d8 100644 --- a/src/ggml-backend-impl.h +++ b/src/ggml-backend-impl.h @@ -208,7 +208,6 @@ extern "C" { // Internal backend registry API GGML_API void ggml_backend_register(ggml_backend_reg_t reg); - GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); // Add backend dynamic loading support to the backend diff --git a/src/ggml-cpu/ggml-cpu-quants.c b/src/ggml-cpu/ggml-cpu-quants.c index 8e1472266..88303ff0e 100644 --- a/src/ggml-cpu/ggml-cpu-quants.c +++ b/src/ggml-cpu/ggml-cpu-quants.c @@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r uint32_t utmp[4]; -#ifdef __ARM_NEON +#ifdef __ARM_FEATURE_SVE + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, K_SCALE_SIZE); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const int vector_length = ggml_cpu_get_sve_cnt()*8; + const svuint8_t m4b = svdup_n_u8(0xf); + const svint32_t mzero = svdup_n_s32(0); + svint32_t sumi1 = svdup_n_s32(0); + svint32_t sumi1_1 = svdup_n_s32(0); + svint32_t sumi1_2 = svdup_n_s32(0); + svint32_t sumi2 = svdup_n_s32(0); + svint32_t sumi2_1 = svdup_n_s32(0); + svint32_t sumi2_2 = svdup_n_s32(0); + switch (vector_length) { + case 128: + { + for (int j = 0; j < QK_K/64; ++j) { + svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b)); + svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + q4 += 32; + } + sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2); + sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2); + sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2))); + } break; + case 256: + case 512: + { + for (int j = 0; j < QK_K/64; ++j) { + const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32; + svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b)); + svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; + sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4)); + q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; + sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + } + sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2))); + } break; + default: + assert(false && "Unsupported vector length"); + break; + } + } + *s = sumf; +#elif __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); diff --git a/src/ggml-cpu/ggml-cpu.c b/src/ggml-cpu/ggml-cpu.c index 424c9781f..e809f05d2 100644 --- a/src/ggml-cpu/ggml-cpu.c +++ b/src/ggml-cpu/ggml-cpu.c @@ -3967,6 +3967,57 @@ static void ggml_compute_forward_dup_bytes( } } +static void ggml_compute_forward_dup_q( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float; + + size_t qk = ggml_blck_size(type); + const int64_t nr = ggml_nelements(src1) / qk; + + // destination must be contiguous in the first dimension + GGML_ASSERT(nb10 == ggml_type_size(dst->type)); + // must either have first dimension large enough to hold a row, or fully contiguous + GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + + uint32_t i = ir * qk; + + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + + dequantize_row_q( + (const void *) ((char *) src0->data + x_offset), + (float *) ((char *) dst->data + dst_offset), qk); + } +} + static void ggml_compute_forward_dup( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -3993,6 +4044,10 @@ static void ggml_compute_forward_dup( } break; default: { + if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) { + ggml_compute_forward_dup_q(params, dst); + break; + } GGML_ABORT("fatal error"); } } @@ -6691,20 +6746,20 @@ static void ggml_compute_forward_silu_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * grad = dst->src[1]; + const struct ggml_tensor * grad = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; assert(ggml_is_contiguous_1(grad)); - assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(src1)); assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - assert(ggml_are_same_shape(src0, grad)); + assert(ggml_are_same_shape(src1, dst)); + assert(ggml_are_same_shape(src1, grad)); const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); + const int nc = src1->ne[0]; + const int nr = ggml_nrows(src1); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -6716,7 +6771,7 @@ static void ggml_compute_forward_silu_back_f32( for (int i1 = ir0; i1 < ir1; i1++) { ggml_vec_silu_backward_f32(nc, (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])), + (float *) ((char *) src1->data + i1*(src1->nb[1])), (float *) ((char *) grad->data + i1*(grad->nb[1]))); #ifndef NDEBUG @@ -6895,7 +6950,7 @@ static void ggml_compute_forward_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - GGML_ASSERT(eps > 0.0f); + GGML_ASSERT(eps >= 0.0f); // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -6966,7 +7021,7 @@ static void ggml_compute_forward_rms_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - GGML_ASSERT(eps > 0.0f); + GGML_ASSERT(eps >= 0.0f); // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -7018,12 +7073,13 @@ static void ggml_compute_forward_rms_norm_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output + const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1)); GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -7042,8 +7098,8 @@ static void ggml_compute_forward_rms_norm_back_f32( const int64_t i12 = i02; const int64_t i13 = i03; - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); ggml_float sum_xx = 0.0; ggml_float sum_xdz = 0.0; @@ -7066,9 +7122,9 @@ static void ggml_compute_forward_rms_norm_back_f32( { // z = rms_norm(x) // - // rms_norm(src0) = + // rms_norm(src1) = // scale( - // src0, + // src1, // div( // 1, // sqrt( @@ -7076,13 +7132,13 @@ static void ggml_compute_forward_rms_norm_back_f32( // scale( // sum( // sqr( - // src0)), + // src1)), // (1.0/N)), // eps)))); // postorder: // ## op args grad - // 00 param src0 grad[#00] + // 00 param src1 grad[#00] // 01 const 1 // 02 sqr (#00) grad[#02] // 03 sum (#02) grad[#03] @@ -7159,6 +7215,7 @@ static void ggml_compute_forward_rms_norm_back_f32( // dx := scale(dx, rrms) float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps) ggml_vec_cpy_f32 (ne00, dx, x); // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); @@ -7750,12 +7807,13 @@ static void ggml_compute_forward_out_prod_f32( const int ith = params->ith; const int nth = params->nth; - GGML_ASSERT(ne0 == ne00); - GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne3 == ne13); - GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + GGML_ASSERT(ne2 % ne02 == 0); + GGML_ASSERT(ne3 % ne03 == 0); // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == sizeof(float)); @@ -7797,6 +7855,10 @@ static void ggml_compute_forward_out_prod_f32( const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); const int64_t blck_1 = 16; + // dps == dst per src0, used for group query attention + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + for (int64_t bir = ir0; bir < ir1; bir += blck_1) { const int64_t bir1 = MIN(bir + blck_1, ir1); for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { @@ -7807,8 +7869,8 @@ static void ggml_compute_forward_out_prod_f32( const int64_t i2 = (ir - i3*ne2*ne1)/ne1; const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); - const int64_t i02 = i2; - const int64_t i03 = i3; + const int64_t i02 = i2 / dps2; + const int64_t i03 = i3 / dps3; //const int64_t i10 = i1; const int64_t i12 = i2; @@ -7821,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32( float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); } @@ -7830,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32( float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); ggml_vec_mad_f32(ne0, d, s0, *s1); } @@ -8906,9 +8968,9 @@ static void ggml_compute_forward_soft_max( } -// ggml_compute_forward_soft_max_back +// ggml_compute_forward_soft_max_ext_back -static void ggml_compute_forward_soft_max_back_f32( +static void ggml_compute_forward_soft_max_ext_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8921,6 +8983,14 @@ static void ggml_compute_forward_soft_max_back_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src1, dst)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + // TODO: handle transposed/permuted matrices const int ith = params->ith; @@ -8969,10 +9039,11 @@ static void ggml_compute_forward_soft_max_back_f32( // linear runtime, no additional memory float dot_y_dy = 0; - ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1); - ggml_vec_cpy_f32 (nc, dx, dy); - ggml_vec_acc1_f32(nc, dx, -dot_y_dy); - ggml_vec_mul_f32 (nc, dx, dx, y); + ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1); + ggml_vec_cpy_f32 (nc, dx, dy); + ggml_vec_acc1_f32 (nc, dx, -dot_y_dy); + ggml_vec_mul_f32 (nc, dx, dx, y); + ggml_vec_scale_f32(nc, dx, scale); #ifndef NDEBUG for (int i = 0; i < nc; ++i) { @@ -8983,7 +9054,7 @@ static void ggml_compute_forward_soft_max_back_f32( } } -static void ggml_compute_forward_soft_max_back( +static void ggml_compute_forward_soft_max_ext_back( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8992,7 +9063,7 @@ static void ggml_compute_forward_soft_max_back( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_soft_max_back_f32(params, dst); + ggml_compute_forward_soft_max_ext_back_f32(params, dst); } break; default: { @@ -9985,9 +10056,10 @@ static void ggml_compute_forward_im2col_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output + const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -10009,11 +10081,11 @@ static void ggml_compute_forward_im2col_back_f32( const int64_t IH = is_2D ? ne1 : 1; const int64_t IW = ne0; - const int64_t KH = is_2D ? ne01 : 1; - const int64_t KW = ne00; + const int64_t KH = is_2D ? ne11 : 1; + const int64_t KW = ne10; - const int64_t OH = is_2D ? ne12 : 1; - const int64_t OW = ne11; + const int64_t OH = is_2D ? ne02 : 1; + const int64_t OW = ne01; int ofs0 = is_2D ? nb3 : nb2; int ofs1 = is_2D ? nb2 : nb1; @@ -10059,9 +10131,9 @@ static void ggml_compute_forward_im2col_back_f32( continue; } - const float * const src_data = (const float *) src1->data + const float * const grad_in = (const float *) src0->data + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - grad += src_data[iic*(KH*KW) + ikh*KW + ikw]; + grad += grad_in[iic*(KH*KW) + ikh*KW + ikw]; } } float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW] @@ -12484,22 +12556,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * opt0 = dst->src[2]; + const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output + const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass + const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(opt0)); - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(src0f)); + GGML_ASSERT(ggml_is_contiguous(src1f)); + GGML_ASSERT(ggml_is_contiguous(grad)); + GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst)); const int64_t ith = params->ith; const int64_t nth = params->nth; // TODO: handle transposed/permuted matrices - const int64_t nc = src0->ne[0]; - const int64_t nr = ggml_nrows(src0); + const int64_t nc = src0f->ne[0]; + const int64_t nr = ggml_nrows(src0f); // rows per thread const int64_t dr = (nr + nth - 1)/nth; @@ -12508,12 +12580,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ir0 = dr*ith; const int64_t ir1 = MIN(ir0 + dr, nr); - const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr; + const float d_by_nr = ((const float *) grad->data)[0] / (float) nr; for (int64_t i1 = ir0; i1 < ir1; i1++) { - float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); - float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); - float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); + const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]); + const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]); #ifndef NDEBUG for (int64_t i = 0; i < nc; ++i) { @@ -12526,11 +12598,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( // soft_max float max = -INFINITY; ggml_vec_max_f32(nc, &max, s0); - ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); + const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); assert(sum > 0.0); ggml_vec_scale_f32(nc, ds0, 1.0/sum); - // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr + // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr ggml_vec_sub_f32(nc, ds0, ds0, s1); ggml_vec_scale_f32(nc, ds0, d_by_nr); @@ -12827,7 +12899,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_SOFT_MAX_BACK: { - ggml_compute_forward_soft_max_back(params, tensor); + ggml_compute_forward_soft_max_ext_back(params, tensor); } break; case GGML_OP_ROPE: { @@ -13668,6 +13740,7 @@ struct ggml_cplan ggml_graph_plan( } break; case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: { cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; } break; diff --git a/src/ggml-cpu/ggml-cpu.cpp b/src/ggml-cpu/ggml-cpu.cpp index f11399cc6..2ccb4b472 100644 --- a/src/ggml-cpu/ggml-cpu.cpp +++ b/src/ggml-cpu/ggml-cpu.cpp @@ -403,12 +403,21 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float case GGML_OP_MUL_MAT: return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type; - case GGML_OP_ROPE_BACK: - return op->src[2] == NULL && (op->op_params[2] & 4) == 0; + case GGML_OP_SOFT_MAX_BACK: { + if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) { + return false; + } + float max_bias = 0.0f; + + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + + return max_bias == 0.0f; + } case GGML_OP_IM2COL_BACK: return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; case GGML_OP_OUT_PROD: - return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32; + return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && + src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; default: return true; } diff --git a/src/ggml-cuda/binbcast.cu b/src/ggml-cuda/binbcast.cu index c7b6be4e2..ce4b9cfb5 100644 --- a/src/ggml-cuda/binbcast.cu +++ b/src/ggml-cuda/binbcast.cu @@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s template static __global__ void k_repeat_back( - const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, - const int64_t ne0, const int64_t ne1, const int64_t ne2) { + const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) { - const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; - const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y; - const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z; + const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x; + const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y; + const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z; + const int64_t tid2 = tid23 % ne2; + const int64_t tid3 = tid23 / ne2; if (tid0 >= ne0) { return; } T sum = 0; - for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { - for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { - for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) { - sum += src[i2*ne01*ne00 + i1*ne00 + i0]; + for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) { + for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { + for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { + for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) { + sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00]; + } } } } - dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; + dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; } template @@ -274,12 +279,14 @@ struct bin_bcast_cuda { template static void repeat_back_cuda( - const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, - const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) { + const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); - const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2); - k_repeat_back<<>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2); + const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3); + k_repeat_back<<>> + (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3); } template @@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(src0->type == dst->type); - GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_can_repeat(dst, src0)); cudaStream_t stream = ctx.stream(); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - GGML_ASSERT(src0->ne[3] == 1); + GGML_TENSOR_UNARY_OP_LOCALS; + + GGML_ASSERT(ne2*ne3 <= (1 << 15)); - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - GGML_ASSERT(dst->ne[3] == 1); + const size_t ts = ggml_type_size(src0->type); + const size_t s00 = nb00 / ts; + const size_t s01 = nb01 / ts; + const size_t s02 = nb02 / ts; + const size_t s03 = nb03 / ts; switch (dst->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; float * dst_d = (float *) dst->data; - repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream); + repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream); } break; default: { GGML_ASSERT(false); diff --git a/src/ggml-cuda/common.cuh b/src/ggml-cuda/common.cuh index 2c0a56226..a66322da0 100644 --- a/src/ggml-cuda/common.cuh +++ b/src/ggml-cuda/common.cuh @@ -46,20 +46,20 @@ #define GGML_CUDA_CC_VOLTA 700 #define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_AMPERE 800 -#define GGML_CUDA_CC_OFFSET_AMD 1000000 +#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 // GCN/CNDA, wave size is 64 -#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16 -#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue -#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a -#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers -#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing -#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300 +#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 +#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue +#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a +#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 // RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32 -#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000 -#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a -#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 +#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a +#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA #define GGML_CUDA_CC_QY1 210 #define GGML_CUDA_CC_QY2 220 @@ -131,6 +131,10 @@ typedef float dfloat; // dequantize float typedef float2 dfloat2; #endif // GGML_CUDA_F16 +#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) +#define GGML_USE_VMM +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) + #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #define FP16_AVAILABLE #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL @@ -588,7 +592,7 @@ struct ggml_tensor_extra_gpu { }; -#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS) +#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS) #define USE_CUDA_GRAPH #endif diff --git a/src/ggml-cuda/cross-entropy-loss.cu b/src/ggml-cuda/cross-entropy-loss.cu index ed09406a8..27599a2b0 100644 --- a/src/ggml-cuda/cross-entropy-loss.cu +++ b/src/ggml-cuda/cross-entropy-loss.cu @@ -5,95 +5,89 @@ #include #include -static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) { - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE; - - const int ne_tmp = WARP_SIZE*nclasses; - - extern __shared__ float tmp_all[]; - float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp; - float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp; - - // Each warp first loads ne_tmp logits/labels into shared memory: - for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) { - const int ig = i0*nclasses + i; // ig == i global - - tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f; - tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f; - } +template +static __global__ void cross_entropy_loss_f32( + const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) { + extern __shared__ float tmp[]; - // Each thread in the warp then calculates the cross entropy loss for a single row. - // TODO: pad in order to avoid shared memory bank conflicts. + logits += int64_t(blockIdx.x)*nclasses; + labels += int64_t(blockIdx.x)*nclasses; // Find maximum for softmax: - float max = -INFINITY; - for (int i = 0; i < nclasses; ++i) { - max = fmaxf(max, tmp_logits[lane_id*nclasses + i]); + float max_logit = -INFINITY; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = logits[i]; + max_logit = fmaxf(max_logit, val); + + if (use_shared) { + tmp[i] = val; + } } + max_logit = warp_reduce_max(max_logit); // Calculate log(softmax(logits)) which is just logits - max: float sum = 0.0f; - for (int i = 0; i < nclasses; ++i) { - float val = tmp_logits[lane_id*nclasses + i] - max; - sum += expf(val); - tmp_logits[lane_id*nclasses + i] = val; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float logit_i = use_shared ? tmp[i] : logits[i]; + sum += expf(logit_i - max_logit); } + sum = warp_reduce_sum(sum); sum = logf(sum); // log(exp(logits - max) / sum) = (logits - max) - log(sum) float loss = 0.0f; - for (int i = 0; i < nclasses; ++i) { - loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i]; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float logit_i = use_shared ? tmp[i] : logits[i]; + loss += (logit_i - max_logit - sum) * labels[i]; } loss = -warp_reduce_sum(loss) / (float)k; - __syncthreads(); - - if (lane_id == 0) { - tmp_all[warp_id] = loss; - } - - __syncthreads(); - - if (warp_id != 0) { - return; - } - - loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f; - loss = warp_reduce_sum(loss); - - if (lane_id != 0) { + if (threadIdx.x != 0) { return; } dst[blockIdx.x] = loss; } -static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) { +template +static __global__ void cross_entropy_loss_back_f32( + const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels, + float * __restrict__ dst, const int nclasses) { extern __shared__ float tmp[]; + logits += int64_t(blockIdx.x)*nclasses; + labels += int64_t(blockIdx.x)*nclasses; + dst += int64_t(blockIdx.x)*nclasses; + float maxval = -INFINITY; for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { - const float val = logits[blockIdx.x*nclasses + i]; + const float val = logits[i]; maxval = fmaxf(maxval, val); - tmp[i] = val; + + if (use_shared) { + tmp[i] = val; + } } maxval = warp_reduce_max(maxval); float sum = 0.0f; for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { - const float val = expf(tmp[i] - maxval); + const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval); sum += val; - tmp[i] = val; + + if (use_shared) { + tmp[i] = val; + } else { + dst[i] = val; + } } sum = warp_reduce_sum(sum); const float sm_scale = 1.0f/sum; - const float d_by_nrows = *loss/gridDim.x; + const float d_by_nrows = *grad/gridDim.x; for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { - dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows; + const float val = use_shared ? tmp[i] : dst[i]; + dst[i] = (val*sm_scale - labels[i])*d_by_nrows; } } @@ -119,48 +113,77 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_pool & pool = ctx.pool(); cudaStream_t stream = ctx.stream(); - const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); - const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); - const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float); + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(nrows, 1, 1); + const size_t nbytes_shared = ne00*sizeof(float); + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); - cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + if (nbytes_shared <= smpbo) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + } else { + cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + } + CUDA_CHECK(cudaGetLastError()); // Combine results from individual blocks: sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream); } void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - const ggml_tensor * opt0 = dst->src[2]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(opt0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(opt0)); + const ggml_tensor * grad = dst->src[0]; + const ggml_tensor * src0f = dst->src[1]; + const ggml_tensor * src1f = dst->src[2]; + + GGML_ASSERT(src0f->type == GGML_TYPE_F32); + GGML_ASSERT(src1f->type == GGML_TYPE_F32); + GGML_ASSERT( grad->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_scalar(grad)); + GGML_ASSERT(ggml_is_contiguous(src0f)); + GGML_ASSERT(ggml_is_contiguous(src1f)); GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0f, src1f)); + GGML_ASSERT(ggml_are_same_shape(src0f, dst)); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + const int64_t ne00 = src0f->ne[0]; + const int64_t nrows = ggml_nrows(src0f); - const float * src0_d = (const float *) src0->data; - const float * src1_d = (const float *) src1->data; - const float * opt0_d = (const float *) opt0->data; - float * dst_d = (float *) dst->data; + const float * grad_d = (const float *) grad->data; + const float * src0f_d = (const float *) src0f->data; + const float * src1f_d = (const float *) src1f->data; + float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); const dim3 blocks_dim(WARP_SIZE, 1, 1); const dim3 blocks_num(nrows, 1, 1); - const int shmem = ne00*sizeof(float); - - cross_entropy_loss_back_f32<<>>(src0_d, src1_d, opt0_d, dst_d, ne00); + const size_t nbytes_shared = ne00*sizeof(float); + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + if (nbytes_shared <= smpbo) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); + } else { + cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); + } } diff --git a/src/ggml-cuda/getrows.cu b/src/ggml-cuda/getrows.cu index 4c3703238..4cef53a98 100644 --- a/src/ggml-cuda/getrows.cu +++ b/src/ggml-cuda/getrows.cu @@ -3,15 +3,15 @@ template static __global__ void k_get_rows( - const void * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12/*, size_t s13*/) { + const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ + /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, + /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; @@ -22,10 +22,10 @@ static __global__ void k_get_rows( const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; + const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index const int iybs = i00 - i00%qk; // dst block start index const int y_offset = qr == 1 ? 1 : qk/2; @@ -39,15 +39,15 @@ static __global__ void k_get_rows( template static __global__ void k_get_rows_float( - const src0_t * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12/*, size_t s13*/) { - - const int i00 = blockIdx.x*blockDim.x + threadIdx.x; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ + /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, + /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + + const int i00 = blockIdx.x*blockDim.x + threadIdx.x; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; @@ -58,14 +58,38 @@ static __global__ void k_get_rows_float( const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); + const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); dst_row[i00] = src0_row[i00]; } +template +static __global__ void k_get_rows_back_float( + const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) { + const int col = blockIdx.x*blockDim.x + threadIdx.x; + + if (col >= ncols) { + return; + } + + const int dst_row = blockIdx.y*blockDim.y + threadIdx.y; + + float sum = 0.0f; + + for (int64_t i = 0; i < nrows_grad; ++i) { + if (rows[i] != dst_row) { + continue; + } + sum += grad[i*ncols + col]; + } + + dst[dst_row*ncols + col] = sum; +} + template -static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { +static void get_rows_cuda( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { GGML_TENSOR_BINARY_OP_LOCALS @@ -87,22 +111,25 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg GGML_ASSERT(ne00 % 2 == 0); k_get_rows<<>>( - src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ - /* s0,*/ s1, s2, s3, - /* nb00,*/ nb01, nb02, nb03, - s10, s11, s12/*, s13*/); + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); GGML_UNUSED(dst); } template -static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { +static void get_rows_cuda_float( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(ne13 == 1); + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; const dim3 block_nums(block_num_x, ne10, ne11*ne12); @@ -119,12 +146,12 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr //const size_t s13 = nb13 / ggml_element_size(src1); k_get_rows_float<<>>( - src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ - /* s0,*/ s1, s2, s3, - /* nb00,*/ nb01, nb02, nb03, - s10, s11, s12/*, s13*/); + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); GGML_UNUSED(dst); } @@ -132,42 +159,41 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); + const void * src0_d = (const void *) src0->data; + const int32_t * src1_d = (const int32_t *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); - GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); - - const int32_t * src1_i32 = (const int32_t *) src1_d; + GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); switch (src0->type) { case GGML_TYPE_F16: - get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream); + get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_F32: - get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q4_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q4_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q5_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q5_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q8_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; default: // TODO: k-quants @@ -175,3 +201,34 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { break; } } + +void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output + const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass + + GGML_TENSOR_BINARY_OP_LOCALS + + const float * src0_d = (const float *) src0->data; + const int32_t * src1_d = (const int32_t *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(ne02*ne03 == 1); + GGML_ASSERT(ne12*ne13 == 1); + GGML_ASSERT(ne2*ne3 == 1); + + const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1); + const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE; + const dim3 block_nums(block_num_x, ne1, 1); + + k_get_rows_back_float<<>>(src0_d, src1_d, dst_d, ne00, ne10); +} diff --git a/src/ggml-cuda/getrows.cuh b/src/ggml-cuda/getrows.cuh index bbf130232..a1ca643f1 100644 --- a/src/ggml-cuda/getrows.cuh +++ b/src/ggml-cuda/getrows.cuh @@ -1,5 +1,8 @@ #include "common.cuh" #define CUDA_GET_ROWS_BLOCK_SIZE 256 +#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/ggml-cuda/ggml-cuda.cu b/src/ggml-cuda/ggml-cuda.cu index 1dac397c4..de3f9c2ca 100644 --- a/src/ggml-cuda/ggml-cuda.cu +++ b/src/ggml-cuda/ggml-cuda.cu @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -62,7 +63,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); [[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails - cudaGetDevice(&id); + (void)cudaGetDevice(&id); GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg); GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); @@ -119,12 +120,78 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) #endif } +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +static int ggml_cuda_parse_id(char devName[]) { + // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp + // these values are not stable so this is susceptible to breakage + // https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp + int archMajor = 0x0; + int archMinor = 0x0; + int archNum = GGML_CUDA_CC_OFFSET_AMD; + int archLen = strlen(devName); + char archName[archLen + 1]; + + // strip leading 'gfx' while copying into our buffer + if (archLen > 3) { + strcpy(archName, &devName[3]); + archLen -= 3; + } + + // trim trailing :xnack- or :sramecc- statuses + archLen = strcspn(archName, ":"); + archName[archLen] = '\0'; + + // tease out the version information + if (archLen > 8) { + // versions labeled generic use '-' as delimiter + // strip the trailing "-generic" then iterate through what remains + if ((strstr(archName, "-generic"))) { + archName[archLen - 8] = '\0'; + char * pch; + if ((pch = strtok(archName, "-"))) { + archMajor = (int)strtoul(pch, 0, 16); + if ((pch = strtok(NULL, "-"))) { + archMinor = 0x10 * (int)strtoul(pch, 0, 16); + } + } + } + } else if (archLen >= 3) { + // last two digits should be the minor * 0x10 + stepping + archMinor = (int)strtoul(&archName[archLen - 2], 0, 16); + archName[archLen - 2] = '\0'; + + // only the major version remains + archMajor = (int)strtoul(archName, 0, 16); + } + archNum += archMajor * 0x100; + archNum += archMinor; + return archNum; +} +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + static ggml_cuda_device_info ggml_cuda_init() { #ifdef __HIP_PLATFORM_AMD__ // Workaround for a rocBLAS bug when using multiple graphics cards: // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346 - rocblas_initialize(); - CUDA_CHECK(cudaDeviceSynchronize()); + { + int major_version = 0; + size_t version_length = 0; + if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) { + std::string version(version_length, '\0'); + if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) { + version.resize(::strlen(version.c_str())); + int parsed_value = 0; + if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) { + major_version = parsed_value; + } + } + } + if (major_version < 4) { + GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n"); + rocblas_initialize(); + CUDA_CHECK(cudaDeviceSynchronize()); + } + } #endif ggml_cuda_device_info info = {}; @@ -152,7 +219,7 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if defined(GGML_USE_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -164,12 +231,11 @@ static ggml_cuda_device_info ggml_cuda_init() { alloc_prop.location.id = id; CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // defined(GGML_USE_VMM) info.devices[id].vmm = !!device_vmm; cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); - GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; @@ -178,10 +244,25 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].smpb = prop.sharedMemPerBlock; #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpbo = prop.sharedMemPerBlock; - info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD; + + info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName); + if ((info.devices[id].cc & 0xff00) == 0x0) { + GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s cc %d.%d\n", + id, prop.name, prop.gcnArchName, prop.major, prop.minor); + + // Fallback to prop.major and prop.minor + if (prop.major > 0) { + info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100; + info.devices[id].cc += prop.minor * 0x10; + } + } + GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s\n", + id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no"); #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } @@ -300,7 +381,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if defined(GGML_USE_VMM) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -309,6 +390,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { size_t pool_used = 0; size_t pool_size = 0; size_t granularity; +#if defined(GGML_USE_HIP) + std::vector> mappings; +#endif explicit ggml_cuda_pool_vmm(int device) : device(device), @@ -317,7 +401,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { ~ggml_cuda_pool_vmm() { if (pool_addr != 0) { +#if defined(GGML_USE_HIP) + // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 + for (std::pair & mapping : mappings) { + CU_CHECK(cuMemUnmap(mapping.first, mapping.second)); + } +#else CU_CHECK(cuMemUnmap(pool_addr, pool_size)); +#endif CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE)); } } @@ -350,7 +441,11 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { } // map at the end of the pool - CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0)); + CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); + CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0)); +#if defined(GGML_USE_HIP) + mappings.push_back({start_ptr, reserve_size}); +#endif // the memory allocation handle is no longer needed after mapping CU_CHECK(cuMemRelease(handle)); @@ -360,7 +455,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; access.location.id = device; access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1)); + CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); // add to the pool pool_size += reserve_size; @@ -372,7 +467,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(pool_addr != 0); - void * ptr = (void *) (pool_addr + pool_used); + void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used)); *actual_size = size; pool_used += size; @@ -391,17 +486,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { pool_used -= size; // all deallocations must be in reverse order of the allocations - GGML_ASSERT(ptr == (void *) (pool_addr + pool_used)); + GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); } }; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // defined(GGML_USE_VMM) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if defined(GGML_USE_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // defined(GGML_USE_VMM) return std::unique_ptr(new ggml_cuda_pool_leg(device)); } @@ -547,7 +642,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err)); return nullptr; } @@ -962,7 +1057,7 @@ static void * ggml_cuda_host_malloc(size_t size) { cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0, cudaGetErrorString(err)); return nullptr; @@ -1082,7 +1177,9 @@ static void ggml_cuda_op_mul_mat_cublas( const int compute_capability = ggml_cuda_info().devices[id].cc; - if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { + const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + + if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); if (src0->type != GGML_TYPE_F16) { @@ -1103,28 +1200,38 @@ static void ggml_cuda_op_mul_mat_cublas( to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get(); - ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; - if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { - cu_compute_type = CUBLAS_COMPUTE_32F; - } + if (compute_capability == GGML_CUDA_CC_CDNA) { + const float alpha = 1.0f; + const float beta = 0.0f; + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta, dst_dd_i, CUDA_R_32F, ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); - CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - CUBLAS_CHECK( - cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, - row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, - &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, - cu_compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } } else { ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id)); ggml_cuda_pool_alloc src1_ddq_as_f32(ctx.pool(id)); @@ -1197,7 +1304,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { CUDA_CHECK(err); } else { // reset the error - cudaGetLastError(); + (void)cudaGetLastError(); } } else { cudaError_t err = cudaDeviceDisablePeerAccess(id_other); @@ -1205,7 +1312,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { CUDA_CHECK(err); } else { // reset the error - cudaGetLastError(); + (void)cudaGetLastError(); } } } @@ -1613,10 +1720,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; cudaDataType_t cu_data_type = CUDA_R_16F; - if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { - cu_compute_type = CUBLAS_COMPUTE_32F; - } - // dst strides size_t nbd2 = dst->nb[2]; size_t nbd3 = dst->nb[3]; @@ -1645,6 +1748,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co beta = &beta_f32; } + if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { + cu_compute_type = CUBLAS_COMPUTE_32F; + alpha = &alpha_f32; + beta = &beta_f32; + } + GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -2003,6 +2112,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GET_ROWS: ggml_cuda_op_get_rows(ctx, dst); break; + case GGML_OP_GET_ROWS_BACK: + ggml_cuda_op_get_rows_back(ctx, dst); + break; case GGML_OP_DUP: ggml_cuda_dup(ctx, dst); break; @@ -2091,9 +2203,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_LEAKY_RELU: ggml_cuda_op_leaky_relu(ctx, dst); break; + case GGML_OP_SILU_BACK: + ggml_cuda_op_silu_back(ctx, dst); + break; case GGML_OP_RMS_NORM: ggml_cuda_op_rms_norm(ctx, dst); break; + case GGML_OP_RMS_NORM_BACK: + ggml_cuda_op_rms_norm_back(ctx, dst); + break; case GGML_OP_MUL_MAT: if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]); @@ -2138,9 +2256,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SOFT_MAX: ggml_cuda_op_soft_max(ctx, dst); break; + case GGML_OP_SOFT_MAX_BACK: + ggml_cuda_op_soft_max_back(ctx, dst); + break; case GGML_OP_ROPE: ggml_cuda_op_rope(ctx, dst); break; + case GGML_OP_ROPE_BACK: + ggml_cuda_op_rope_back(ctx, dst); + break; case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; @@ -2423,7 +2547,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto if (stat == cudaErrorInvalidDeviceFunction) { // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. // We don't need to update blas nodes, so clear error and move on. - cudaGetLastError(); + (void)cudaGetLastError(); } else { GGML_ASSERT(stat == cudaSuccess); } @@ -2478,14 +2602,20 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { cudaGraphExecUpdateResultInfo result_info; +#ifdef __HIP_PLATFORM_AMD__ + hipGraphNode_t errorNode; + hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); +#else cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); +#endif if (stat == cudaErrorGraphExecUpdateFailure) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); #endif + // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate - cudaGetLastError(); + (void)cudaGetLastError(); CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); cuda_ctx->cuda_graph->instance = nullptr; CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); @@ -2713,7 +2843,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0, cudaGetErrorString(err)); @@ -2733,7 +2863,7 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) { cudaError_t err = cudaHostUnregister(buffer); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); } } @@ -2909,7 +3039,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } } break; case GGML_OP_OUT_PROD: - return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { @@ -2925,6 +3055,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; } } break; + case GGML_OP_GET_ROWS_BACK: + { + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; @@ -2983,7 +3117,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } break; case GGML_OP_REPEAT_BACK: - return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1; + return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15); case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; @@ -2998,8 +3132,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } return false; } break; + case GGML_OP_SILU_BACK: + return ggml_is_contiguous(op->src[0]); + break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; break; case GGML_OP_NONE: @@ -3024,8 +3162,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: return true; + case GGML_OP_SOFT_MAX_BACK: { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + return max_bias == 0.0f; + } case GGML_OP_ROPE: - return ggml_is_contiguous(op->src[0]); + case GGML_OP_ROPE_BACK: { + const size_t ts = ggml_type_size(op->src[0]->type); + const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2]; + return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts; + } case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM: @@ -3081,6 +3228,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) { return op->ne[1]; case GGML_OP_MUL_MAT_ID: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: return op->ne[2]; default: return ggml_nrows(op); @@ -3183,7 +3331,7 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t features.push_back({ "FORCE_CUBLAS", "1" }); #endif - #ifdef GGML_CUDA_NO_VMM + #ifndef GGML_USE_VMM features.push_back({ "NO_VMM", "1" }); #endif diff --git a/src/ggml-cuda/mmvq.cu b/src/ggml-cuda/mmvq.cu index e3b912d87..4fb466ca0 100644 --- a/src/ggml-cuda/mmvq.cu +++ b/src/ggml-cuda/mmvq.cu @@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda( int64_t nwarps = 1; int64_t rows_per_cuda_block = 1; - if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA + if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2 switch(ncols_y) { case 1: nwarps = 4; @@ -166,6 +166,7 @@ static void mul_mat_vec_q_cuda( break; } } + const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; const dim3 block_nums(nblocks, 1, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); diff --git a/src/ggml-cuda/norm.cu b/src/ggml-cuda/norm.cu index 133e219f0..d991ec972 100644 --- a/src/ggml-cuda/norm.cu +++ b/src/ggml-cuda/norm.cu @@ -5,20 +5,24 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - float2 mean_var = make_float2(0.f, 0.f); + x += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + + float2 mean_var = make_float2(0.0f, 0.0f); for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row*ncols + col]; + const float xi = x[col]; mean_var.x += xi; mean_var.y += xi * xi; } // sum up partial sums mean_var = warp_reduce_sum(mean_var); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float2 s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = mean_var; } @@ -32,7 +36,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c const float inv_std = rsqrtf(var + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std; + dst[col] = (x[col] - mean) * inv_std; } } @@ -40,14 +44,8 @@ template static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) { // blockIdx.x: num_groups idx // threadIdx.x: block_size idx - int start = blockIdx.x * group_size; - int end = start + group_size; - - start += threadIdx.x; - - if (end >= ne_elements) { - end = ne_elements; - } + const int start = blockIdx.x*group_size + threadIdx.x; + const int end = min(blockIdx.x*group_size + group_size, ne_elements); float tmp = 0.0f; // partial sum for thread in warp @@ -56,10 +54,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -68,11 +67,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); } - float mean = tmp / group_size; + const float mean = tmp / group_size; tmp = 0.0f; for (int j = start; j < end; j += block_size) { - float xi = x[j] - mean; + const float xi = x[j] - mean; dst[j] = xi; tmp += xi * xi; } @@ -80,8 +79,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); if (block_size > WARP_SIZE) { __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -90,8 +89,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); } - float variance = tmp / group_size; - float scale = rsqrtf(variance + eps); + const float variance = tmp / group_size; + const float scale = rsqrtf(variance + eps); for (int j = start; j < end; j += block_size) { dst[j] *= scale; } @@ -102,19 +101,23 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; + x += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row*ncols + col]; + const float xi = x[col]; tmp += xi * xi; } // sum up partial sums tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -127,12 +130,63 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol const float scale = rsqrtf(mean + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row*ncols + col] = scale * x[row*ncols + col]; + dst[col] = scale * x[col]; + } +} + +template +static __global__ void rms_norm_back_f32( + const float * grad, const float * xf, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + grad += int64_t(row)*ncols; + xf += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + + float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass + float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs + + for (int col = tid; col < ncols; col += block_size) { + const float xfi = xf[col]; + sum_xx += xfi * xfi; + sum_xg += xfi * grad[col]; + } + + // sum up partial sums + sum_xx = warp_reduce_sum(sum_xx); + sum_xg = warp_reduce_sum(sum_xg); + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); + __shared__ float s_sum_xx[32]; + __shared__ float s_sum_xg[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum_xx[warp_id] = sum_xx; + s_sum_xg[warp_id] = sum_xg; + } + __syncthreads(); + + sum_xx = s_sum_xx[lane_id]; + sum_xx = warp_reduce_sum(sum_xx); + + sum_xg = s_sum_xg[lane_id]; + sum_xg = warp_reduce_sum(sum_xg); + } + + const float mean_eps = sum_xx / ncols + eps; + const float sum_eps = sum_xx + ncols*eps; + + const float scale_grad = rsqrtf(mean_eps); + const float scale_x = -scale_grad * sum_xg/sum_eps; + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale_grad*grad[col] + scale_x*xf[col]; } } static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); norm_f32<<>>(x, dst, ncols, eps); @@ -142,7 +196,8 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i } } -static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) { +static void group_norm_f32_cuda( + const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) { if (group_size < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); group_norm_f32<<>>(x, dst, group_size, ne_elements, eps); @@ -153,7 +208,6 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou } static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); rms_norm_f32<<>>(x, dst, ncols, eps); @@ -163,6 +217,16 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con } } +static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_back_f32<<>>(grad, xf, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_back_f32<1024><<>>(grad, xf, dst, ncols, eps); + } +} + void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -179,6 +243,7 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); } @@ -198,6 +263,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) float eps; memcpy(&eps, dst->op_params + 1, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream); @@ -219,6 +285,33 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); } + +void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * grad = dst->src[0]; // gradients + const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass + + const float * grad_d = (const float *) grad->data; + const float * src0f_d = (const float *) src0f->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(grad)); + + GGML_ASSERT( grad->type == GGML_TYPE_F32); + GGML_ASSERT(src0f->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0f->ne[0]; + const int64_t nrows = ggml_nrows(src0f); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + + rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream); +} diff --git a/src/ggml-cuda/norm.cuh b/src/ggml-cuda/norm.cuh index 431a8f74d..d63d34380 100644 --- a/src/ggml-cuda/norm.cuh +++ b/src/ggml-cuda/norm.cuh @@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/ggml-cuda/out-prod.cu b/src/ggml-cuda/out-prod.cu index 619cfdcb5..c9b2b699c 100644 --- a/src/ggml-cuda/out-prod.cu +++ b/src/ggml-cuda/out-prod.cu @@ -11,16 +11,15 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ne01 == ne11); GGML_ASSERT(ne0 == ne00); GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == src0->ne[2]); + GGML_ASSERT(ne2 % src0->ne[2] == 0); + GGML_ASSERT(ne3 % src0->ne[3] == 0); + GGML_ASSERT(ne2 == src1->ne[2]); - GGML_ASSERT(ne3 == src0->ne[3]); GGML_ASSERT(ne3 == src1->ne[3]); const float * src0_d = (const float *) src0->data; @@ -33,19 +32,37 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const float alpha = 1.0f; const float beta = 0.0f; - GGML_ASSERT(ne2 == 1); - GGML_ASSERT(ne3 == 1); CUBLAS_CHECK(cublasSetStream(handle, stream)); + const int64_t lda = nb01 / sizeof(float); + const int64_t ldc = nb1 / sizeof(float); + const bool src1_T = ggml_is_transposed(src1); const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T; const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float)); - CUBLAS_CHECK( - cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, - ne0, ne1, ne01, - &alpha, src0_d, ne00, - src1_d, ldb, - &beta, dst_d, ne0)); + // data strides in dimensions 2/3 + const size_t s02 = nb02 / sizeof(float); + const size_t s03 = nb03 / sizeof(float); + const size_t s12 = nb12 / sizeof(float); + const size_t s13 = nb13 / sizeof(float); + const size_t s2 = nb2 / sizeof(float); + const size_t s3 = nb3 / sizeof(float); + + // dps == dst per src0, used for group query attention + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + + // TODO batched matrix multiplication + for (int64_t i3 = 0; i3 < ne3; ++i3) { + for (int64_t i2 = 0; i2 < ne2; ++i2) { + CUBLAS_CHECK( + cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, + src1_d + i3 *s13 + i2 *s12, ldb, + &beta, dst_d + i3 *s3 + i2 *s2, ldc)); + } + } } diff --git a/src/ggml-cuda/rope.cu b/src/ggml-cuda/rope.cu index 2c84778d2..18f691b2d 100644 --- a/src/ggml-cuda/rope.cu +++ b/src/ggml-cuda/rope.cu @@ -16,9 +16,10 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +template static __device__ void rope_yarn( - float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta) { + const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor, + float mscale, float & cos_theta, float & sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -29,24 +30,28 @@ static __device__ void rope_yarn( // Get n-d magnitude scaling corrected for interpolation mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); } - *cos_theta = cosf(theta) * mscale; - *sin_theta = sinf(theta) * mscale; + cos_theta = cosf(theta) * mscale; + sin_theta = sinf(theta) * mscale; + if (!forward) { + sin_theta *= -1.0f; + } } -template +template static __global__ void rope_norm( - const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { return; } - const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= n_dims) { - const int i = row*ne0 + i0; + const int i = row_dst*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -54,39 +59,43 @@ static __global__ void rope_norm( return; } - const int i = row*ne0 + i0; - const int i2 = row/p_delta_rows; + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + + const int idst = row_dst*ne0 + i0; + const int ix = channel_x*s2 + row_x*s1 + i0; - const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + 1]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + 1]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + 1] = x0*sin_theta + x1*cos_theta; + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + 1] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_neox( - const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { return; } - const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= n_dims) { - const int i = row*ne0 + i0; + const int i = row_dst*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -94,39 +103,43 @@ static __global__ void rope_neox( return; } - const int i = row*ne0 + i0/2; - const int i2 = row/p_delta_rows; + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + + const int idst = row_dst*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; - const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + n_dims/2]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims/2]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_multi( - const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, + const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { return; } - const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= n_dims) { - const int i = row*ne0 + i0; + const int i = row_dst*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -134,25 +147,28 @@ static __global__ void rope_multi( return; } - const int i = row*ne0 + i0/2; - const int i2 = row/p_delta_rows; + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; - int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; - int sec_w = sections.v[1] + sections.v[0]; - int sector = (i0 / 2) % sect_dims; + const int idst = row_dst*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; + + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0; if (sector < sections.v[0]) { - theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); } else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f); + theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); } else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f); + theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); } else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -160,42 +176,46 @@ static __global__ void rope_multi( float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + n_dims/2]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims/2]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_vision( - const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { return; } - const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; - const int i = row*ne0 + i0/2; - const int i2 = row/p_delta_rows; // i2-th tokens + const int idst = row_dst*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; - int sect_dims = sections.v[0] + sections.v[1]; - int sec_w = sections.v[1] + sections.v[0]; - int sector = (i0 / 2) % sect_dims; + const int sect_dims = sections.v[0] + sections.v[1]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0; if (sector < sections.v[0]) { const int p = sector; - theta_base = pos[i2]*powf(theta_scale, p); + theta_base = pos[channel_x]*powf(theta_scale, p); } else if (sector >= sections.v[0] && sector < sec_w) { const int p = sector - sections.v[0]; - theta_base = pos[i2 + ne2]*powf(theta_scale, p); + theta_base = pos[channel_x + ne2]*powf(theta_scale, p); } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -203,19 +223,20 @@ static __global__ void rope_vision( float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + n_dims]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + n_dims] = x0*sin_theta + x1*cos_theta; + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + n_dims] = x0*sin_theta + x1*cos_theta; } -template +template static void rope_norm_cuda( - const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -224,22 +245,21 @@ static void rope_norm_cuda( const float theta_scale = powf(freq_base, -2.0f/n_dims); if (freq_factors == nullptr) { - rope_norm<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_norm<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } else { - rope_norm<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_norm<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } } -template +template static void rope_neox_cuda( - const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -248,22 +268,21 @@ static void rope_neox_cuda( const float theta_scale = powf(freq_base, -2.0f/n_dims); if (freq_factors == nullptr) { - rope_neox<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_neox<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } else { - rope_neox<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_neox<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } } -template +template static void rope_multi_cuda( - const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -272,22 +291,21 @@ static void rope_multi_cuda( const float theta_scale = powf(freq_base, -2.0f/n_dims); if (freq_factors == nullptr) { - rope_multi<<>>( - x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, sections - ); + rope_multi<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); } else { - rope_multi<<>>( - x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, sections - ); + rope_multi<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); } } -template +template static void rope_vision_cuda( - const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -298,80 +316,18 @@ static void rope_vision_cuda( const float theta_scale = powf(freq_base, -2.0f/n_dims); if (freq_factors == nullptr) { - rope_vision<<>>( - x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, sections - ); + rope_vision<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); } else { - rope_vision<<>>( - x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, sections - ); + rope_vision<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); } } -static void rope_norm_cuda_f16( - const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - - rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); -} - -static void rope_norm_cuda_f32( - const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - - rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); -} - -static void rope_neox_cuda_f16( - const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - - rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); -} - -static void rope_neox_cuda_f32( - const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream -) { - - rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); -} - -static void rope_multi_cuda_f16( - const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream -) { - - rope_multi_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); -} - -static void rope_multi_cuda_f32( - const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream -) { - - rope_multi_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); -} - -static void rope_vision_cuda_f16( - const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream -) { - - rope_vision_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); -} - -static void rope_vision_cuda_f32( - const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream -) { - - rope_vision_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); -} - -void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +template +void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * src2 = dst->src[2]; @@ -382,7 +338,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == dst->type); @@ -392,6 +347,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ne02 = src0->ne[2]; // num heads const int64_t nr = ggml_nrows(src0); + const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); + const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -440,59 +398,59 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // compute if (is_neox) { if (src0->type == GGML_TYPE_F32) { - rope_neox_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_neox_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_neox_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_neox_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else { GGML_ABORT("fatal error"); } } else if (is_mrope && !is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_multi_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, sections, stream - ); + rope_multi_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_multi_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, sections, stream - ); + rope_multi_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } } else if (is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_vision_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, sections, stream - ); + rope_vision_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_vision_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, sections, stream - ); + rope_vision_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } } else { if (src0->type == GGML_TYPE_F32) { - rope_norm_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_norm_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_norm_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_norm_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else { GGML_ABORT("fatal error"); } } } + +void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_rope_impl(ctx, dst); +} + +void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_rope_impl(ctx, dst); +} diff --git a/src/ggml-cuda/rope.cuh b/src/ggml-cuda/rope.cuh index 0f787a0b2..9139f3b22 100644 --- a/src/ggml-cuda/rope.cuh +++ b/src/ggml-cuda/rope.cuh @@ -3,3 +3,5 @@ #define CUDA_ROPE_BLOCK_SIZE 256 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/ggml-cuda/softmax.cu b/src/ggml-cuda/softmax.cu index c24abae1f..da377200e 100644 --- a/src/ggml-cuda/softmax.cu +++ b/src/ggml-cuda/softmax.cu @@ -1,5 +1,7 @@ #include "common.cuh" +#include "ggml.h" #include "softmax.cuh" +#include template static __device__ __forceinline__ float t2f32(T val) { @@ -11,14 +13,26 @@ __device__ float __forceinline__ t2f32(half val) { return __half2float(val); } -template -static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. +// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif +template +static __global__ void soft_max_f32( + const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, + const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; const int rowx = blockIdx.x; const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension + x += int64_t(rowx)*ncols; + mask += int64_t(rowy)*ncols * (mask != nullptr); + dst += int64_t(rowx)*ncols; + const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; const int warp_id = threadIdx.x / WARP_SIZE; @@ -29,7 +43,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication // shared memory buffer to cache values between iterations: - float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols; + float * vals = use_shared ? buf_iw + WARP_SIZE : dst; float max_val = -INFINITY; @@ -41,10 +55,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst break; } - const int64_t ix = (int64_t)rowx*ncols + col; - const int64_t iy = (int64_t)rowy*ncols + col; - - const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); + const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -110,8 +121,32 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst return; } - const int64_t idst = (int64_t)rowx*ncols + col; - dst[idst] = vals[col] * inv_sum; + dst[col] = vals[col] * inv_sum; + } +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +static __global__ void soft_max_back_f32( + const float * grad, const float * dstf, float * dst, const int ncols, const float scale) { + const int tid = threadIdx.x; + const int rowx = blockIdx.x; + + grad += int64_t(rowx)*ncols; + dstf += int64_t(rowx)*ncols; + dst += int64_t(rowx)*ncols; + + float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dgf_dot += dstf[col]*grad[col]; + } + + dgf_dot = warp_reduce_sum(dgf_dot); + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]; } } @@ -121,7 +156,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); const dim3 block_nums(nrows_x, 1, 1); - const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); + const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); const uint32_t n_head = nrows_x/nrows_y; @@ -131,50 +166,68 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // FIXME: this limit could be raised by ~2-4x on Ampere or newer - if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { + if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 64: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 128: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 256: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 512: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 1024: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 2048: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 4096: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; default: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; } } else { - const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); } } +static void soft_max_back_f32_cuda( + const float * grad, const float * dstf, float * dst, + const int ncols, const int nrows, const float scale, cudaStream_t stream) { + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + + soft_max_back_f32<<>>(grad, dstf, dst, ncols, scale); +} + void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const void * src1_d = src1 ? (const void *)src1->data : nullptr; + const float * src0_d = (const float *) src0->data; + const void * src1_d = src1 ? (const void *) src1->data : nullptr; + float * dst_d = (float *) dst->data; - float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); @@ -189,18 +242,42 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float scale = 1.0f; float max_bias = 0.0f; - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (use_f16) { - const half * src1_dd = (const half *)src1_d; - - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } else { - const float * src1_dd = (const float *)src1_d; - - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } } + +void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // grad + const ggml_tensor * src1 = dst->src[1]; // forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + + soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream); +} diff --git a/src/ggml-cuda/softmax.cuh b/src/ggml-cuda/softmax.cuh index 4ef4ff86c..93dfee835 100644 --- a/src/ggml-cuda/softmax.cuh +++ b/src/ggml-cuda/softmax.cuh @@ -3,3 +3,5 @@ #define CUDA_SOFT_MAX_BLOCK_SIZE 1024 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/ggml-cuda/unary.cu b/src/ggml-cuda/unary.cu index 81fc92202..6b21f407d 100644 --- a/src/ggml-cuda/unary.cu +++ b/src/ggml-cuda/unary.cu @@ -51,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) { dst[i] = x[i] / (1.0f + expf(-x[i])); } +static __global__ void silu_back_f32( + const float * grad, const float * xf, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float xfi = xf[i]; + const float s = 1.0f / (1.0f + expf(-xfi)); + dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s)); +} + static __global__ void tanh_f32(const float * x, float * dst, int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -173,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ silu_f32<<>>(x, dst, k); } +static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + silu_back_f32<<>>(grad, x, dst, k); +} + static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; tanh_f32<<>>(x, dst, k); @@ -284,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } +void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // input from forward pass + const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream); +} + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; diff --git a/src/ggml-cuda/unary.cuh b/src/ggml-cuda/unary.cuh index c91936728..e7f62643a 100644 --- a/src/ggml-cuda/unary.cuh +++ b/src/ggml-cuda/unary.cuh @@ -4,6 +4,7 @@ #define CUDA_STEP_BLOCK_SIZE 256 #define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_SILU_BACK_BLOCK_SIZE 256 #define CUDA_TANH_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256 #define CUDA_SIGMOID_BLOCK_SIZE 256 @@ -23,6 +24,8 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/ggml-cuda/vendors/hip.h b/src/ggml-cuda/vendors/hip.h index c905b15d7..8594093f0 100644 --- a/src/ggml-cuda/vendors/hip.h +++ b/src/ggml-cuda/vendors/hip.h @@ -19,6 +19,12 @@ #define CUBLAS_TF32_TENSOR_OP_MATH 0 #define CUDA_R_16F HIPBLAS_R_16F #define CUDA_R_32F HIPBLAS_R_32F +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended +#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned +#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite +#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 #define cublasCreate hipblasCreate @@ -74,6 +80,21 @@ #define cudaMemGetInfo hipMemGetInfo #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize #define cudaSetDevice hipSetDevice +#define cuDeviceGet hipDeviceGet +#define CUdevice hipDevice_t +#define CUdeviceptr hipDeviceptr_t +#define cuMemUnmap hipMemUnmap +#define CUmemAccessDesc hipMemAccessDesc +#define cuMemAddressFree hipMemAddressFree +#define cuMemRelease hipMemRelease +#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t +#define cuMemCreate hipMemCreate +#define cuMemAddressReserve hipMemAddressReserve +#define cuMemMap hipMemMap +#define cuMemSetAccess hipMemSetAccess +#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity +#define CUmemAllocationProp hipMemAllocationProp +#define cuDeviceGetAttribute hipDeviceGetAttribute #define cudaStreamCreateWithFlags hipStreamCreateWithFlags #define cudaStreamDestroy hipStreamDestroy #define cudaStreamFireAndForget hipStreamFireAndForget @@ -81,6 +102,28 @@ #define cudaStreamPerThread hipStreamPerThread #define cudaStreamSynchronize hipStreamSynchronize #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaGraphExec_t hipGraphExec_t +#define cudaGraphNode_t hipGraphNode_t +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaGraphExecDestroy hipGraphExecDestroy +#define cudaGraphLaunch hipGraphLaunch +#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure +#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult +#define cudaGraphNodeType hipGraphNodeType +#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel +#define cudaGraphInstantiate hipGraphInstantiate +#define cudaStreamEndCapture hipStreamEndCapture +#define cudaGraphDestroy hipGraphDestroy +#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams +#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction +#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams +#define cudaGraphNodeGetType hipGraphNodeGetType +#define cudaGraphGetNodes hipGraphGetNodes +#define cudaGraphExecUpdate hipGraphExecUpdate +#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed +#define cudaStreamBeginCapture hipStreamBeginCapture +#define cudaGraph_t hipGraph_t #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess #define __trap() do { abort(); __builtin_unreachable(); } while(0) diff --git a/src/ggml-hip/CMakeLists.txt b/src/ggml-hip/CMakeLists.txt index d090ba9bd..ecc3bc66d 100644 --- a/src/ggml-hip/CMakeLists.txt +++ b/src/ggml-hip/CMakeLists.txt @@ -92,6 +92,14 @@ if (GGML_CUDA_NO_PEER_COPY) add_compile_definitions(GGML_CUDA_NO_PEER_COPY) endif() +if (GGML_HIP_GRAPHS) + add_compile_definitions(GGML_HIP_GRAPHS) +endif() + +if (GGML_HIP_NO_VMM) + add_compile_definitions(GGML_HIP_NO_VMM) +endif() + if (CXX_IS_HIPCC) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) target_link_libraries(ggml-hip PRIVATE hip::device) diff --git a/src/ggml-metal/ggml-metal.m b/src/ggml-metal/ggml-metal.m index a85502ee0..76f8e4291 100644 --- a/src/ggml-metal/ggml-metal.m +++ b/src/ggml-metal/ggml-metal.m @@ -19,7 +19,10 @@ // max number of MTLCommandBuffer used to submit a graph for processing #define GGML_METAL_MAX_COMMAND_BUFFERS 8 -#define UNUSED(x) (void)(x) +// create residency sets only on macOS >= 15.0 +#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 +#define GGML_METAL_HAS_RESIDENCY_SETS 1 +#endif // globals @@ -39,6 +42,7 @@ bool has_simdgroup_reduction; bool has_simdgroup_mm; + bool has_residency_sets; bool has_bfloat; bool use_bfloat; @@ -48,6 +52,7 @@ /*.mtl_device_ref_count =*/ 0, /*.has_simdgroup_reduction =*/ false, /*.has_simdgroup_mm =*/ false, + /*.has_residency_sets =*/ false, /*.has_bfloat =*/ false, /*.use_bfloat =*/ false, /*.name =*/ "", @@ -59,12 +64,18 @@ if (ctx->mtl_device == nil) { ctx->mtl_device = MTLCreateSystemDefaultDevice(); + } + if (ctx->mtl_device) { ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL; +#endif + ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6]; @@ -90,8 +101,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte ctx->mtl_device_ref_count--; if (ctx->mtl_device_ref_count == 0) { - [ctx->mtl_device release]; - ctx->mtl_device = nil; + if (ctx->mtl_device) { + [ctx->mtl_device release]; + ctx->mtl_device = nil; + } } } @@ -483,6 +496,11 @@ @implementation GGMLMetalClass GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); ctx->queue = [device newCommandQueue]; + if (ctx->queue == nil) { + GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); + return NULL; + } + ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); id metal_library; @@ -649,6 +667,7 @@ @implementation GGMLMetalClass GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false"); GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false"); + GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false"); GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false"); GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false"); GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); @@ -1035,8 +1054,70 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap int n_buffers; struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; + + // optional MTLResidencySet + id rset; }; +// rset init +static bool ggml_backend_metal_buffer_rset_init( + struct ggml_backend_metal_buffer_context * ctx, + struct ggml_backend_metal_device_context * ctx_dev, + id device) { + ctx->rset = nil; + + if (!ctx_dev->has_residency_sets) { + return true; + } + +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, *)) { + MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init]; + desc.label = @"ggml_backend_metal"; + desc.initialCapacity = ctx->n_buffers; + + NSError * error; + ctx->rset = [device newResidencySetWithDescriptor:desc error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + [desc release]; + return false; + } + + [desc release]; + + for (int i = 0; i < ctx->n_buffers; i++) { + [ctx->rset addAllocation:ctx->buffers[i].metal]; + } + + [ctx->rset commit]; + [ctx->rset requestResidency]; + + return true; + } +#else + GGML_UNUSED(ctx_dev); + GGML_UNUSED(device); +#endif + + return true; +} + +// rset free +static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) { +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, *)) { + if (ctx->rset) { + [ctx->rset endResidency]; + [ctx->rset removeAllAllocations]; + [ctx->rset release]; + } + } +#else + GGML_UNUSED(ctx); +#endif +} + // finds the Metal buffer that contains the tensor data on the GPU device // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the // Metal buffer based on the host memory pointer @@ -4176,6 +4257,8 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) for (int i = 0; i < ctx->n_buffers; i++) { [ctx->buffers[i].metal release]; } + + ggml_backend_metal_buffer_rset_free(ctx); ggml_backend_metal_device_rel(buffer->buft->device->context); if (ctx->owned) { @@ -4198,19 +4281,19 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { memset((char *)tensor->data + offset, value, size); - UNUSED(buffer); + GGML_UNUSED(buffer); } static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); - UNUSED(buffer); + GGML_UNUSED(buffer); } static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { memcpy(data, (const char *)tensor->data + offset, size); - UNUSED(buffer); + GGML_UNUSED(buffer); } static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { @@ -4220,7 +4303,7 @@ static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, c } return false; - UNUSED(buffer); + GGML_UNUSED(buffer); } static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { @@ -4246,7 +4329,7 @@ static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_ static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) { return "Metal"; - UNUSED(buft); + GGML_UNUSED(buft); } static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { @@ -4270,8 +4353,8 @@ static void ggml_backend_metal_log_allocated_size(id device, size_t s } #endif #endif - UNUSED(device); - UNUSED(size_aligned); + GGML_UNUSED(device); + GGML_UNUSED(size_aligned); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -4284,7 +4367,8 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba size_aligned += (size_page - (size_aligned % size_page)); } - id device = ggml_backend_metal_device_acq(buft->device->context); + struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context; + id device = ggml_backend_metal_device_acq(ctx_dev); ctx->all_data = ggml_metal_host_malloc(size_aligned); ctx->all_size = size_aligned; @@ -4307,7 +4391,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) { GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); free(ctx); - ggml_backend_metal_device_rel(buft->device->context); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); return NULL; } @@ -4318,7 +4409,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { return 32; - UNUSED(buft); + GGML_UNUSED(buft); } static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { @@ -4328,13 +4419,13 @@ static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_ty return max_size; - UNUSED(buft); + GGML_UNUSED(buft); } static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { return true; - UNUSED(buft); + GGML_UNUSED(buft); } ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { @@ -4357,7 +4448,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) { return "Metal_Mapped"; - UNUSED(buft); + GGML_UNUSED(buft); } static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) { @@ -4400,7 +4491,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz size_aligned += (size_page - (size_aligned % size_page)); } - id device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main); + struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main; + id device = ggml_backend_metal_device_acq(ctx_dev); // the buffer fits into the max buffer size allowed by the device if (size_aligned <= device.maxBufferLength) { @@ -4453,6 +4545,13 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz } } + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); } @@ -4461,7 +4560,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz static const char * ggml_backend_metal_name(ggml_backend_t backend) { return "Metal"; - UNUSED(backend); + GGML_UNUSED(backend); } static void ggml_backend_metal_free(ggml_backend_t backend) { @@ -4766,6 +4865,13 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back } } + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); } @@ -4779,7 +4885,7 @@ static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name || buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name; - UNUSED(dev); + GGML_UNUSED(dev); } static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { diff --git a/src/ggml-metal/ggml-metal.metal b/src/ggml-metal/ggml-metal.metal index 8ba43904d..44f04c909 100644 --- a/src/ggml-metal/ggml-metal.metal +++ b/src/ggml-metal/ggml-metal.metal @@ -4416,7 +4416,6 @@ void kernel_mul_mv_q2_K_f32_impl( device const half * dh = &x[ib].d; for (int row = 0; row < N_DST; row++) { - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; i += 2) { @@ -4447,7 +4446,7 @@ void kernel_mul_mv_q2_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -4613,7 +4612,7 @@ void kernel_mul_mv_q3_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; if (tiisg == 0) { - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { dst_f32[first_row + row] = sumf1[row]; } } @@ -4729,7 +4728,7 @@ void kernel_mul_mv_q4_K_f32_impl( device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -4861,7 +4860,7 @@ void kernel_mul_mv_q5_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = tot; @@ -4906,6 +4905,10 @@ void kernel_mul_mv_q6_K_f32_impl( const int row = 2*r0 + sgitg; + if (row >= args.ne0) { + return; + } + const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5061,7 +5064,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.25f; @@ -5179,7 +5182,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.25f; @@ -5289,7 +5292,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.5f; @@ -5401,7 +5404,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5514,7 +5517,7 @@ void kernel_mul_mv_iq2_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.25f; @@ -5614,7 +5617,7 @@ void kernel_mul_mv_iq1_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5709,7 +5712,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5799,7 +5802,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5888,7 +5891,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; diff --git a/src/ggml-rpc/ggml-rpc.cpp b/src/ggml-rpc/ggml-rpc.cpp index 63da2b86b..3d0c46578 100644 --- a/src/ggml-rpc/ggml-rpc.cpp +++ b/src/ggml-rpc/ggml-rpc.cpp @@ -181,7 +181,7 @@ struct ggml_backend_rpc_context { struct ggml_backend_rpc_buffer_context { std::shared_ptr sock; - std::unordered_map base_cache; + void * base_ptr; uint64_t remote_ptr; }; @@ -423,16 +423,15 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) { - return ctx->base_cache[buffer]; + if (ctx->base_ptr != nullptr) { + return ctx->base_ptr; } rpc_msg_buffer_get_base_req request = {ctx->remote_ptr}; rpc_msg_buffer_get_base_rsp response; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response)); GGML_ASSERT(status); - void * base_ptr = reinterpret_cast(response.base_ptr); - ctx->base_cache[buffer] = base_ptr; - return base_ptr; + ctx->base_ptr = reinterpret_cast(response.base_ptr); + return ctx->base_ptr; } static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { @@ -557,7 +556,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back if (response.remote_ptr != 0) { ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_rpc_buffer_interface, - new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr}, + new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr}, response.remote_size); return buffer; } else { diff --git a/src/ggml-sycl/backend.hpp b/src/ggml-sycl/backend.hpp index 85748a5b4..b1df4e5db 100644 --- a/src/ggml-sycl/backend.hpp +++ b/src/ggml-sycl/backend.hpp @@ -29,5 +29,6 @@ #include "wkv6.hpp" #include "outprod.hpp" #include "element_wise.hpp" +#include "gla.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/src/ggml-sycl/common.hpp b/src/ggml-sycl/common.hpp index e9500f3a1..abad847ca 100644 --- a/src/ggml-sycl/common.hpp +++ b/src/ggml-sycl/common.hpp @@ -333,8 +333,12 @@ struct ggml_backend_sycl_context { // pool std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; + std::unique_ptr host_pools[GGML_SYCL_MAX_DEVICES]; + static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device); + static std::unique_ptr new_pool_for_host(queue_ptr qptr, int device); + ggml_sycl_pool & pool(int device) { if (pools[device] == nullptr) { pools[device] = new_pool_for_device(stream(device,0), device); @@ -345,6 +349,15 @@ struct ggml_backend_sycl_context { ggml_sycl_pool & pool() { return pool(device); } + + ggml_sycl_pool & host_pool(int device) { + if (host_pools[device] == nullptr) { + host_pools[device] = new_pool_for_host(stream(device, 0), device); + } + return *host_pools[device]; + } + + ggml_sycl_pool & host_pool() { return host_pool(device); } }; // common device functions diff --git a/src/ggml-sycl/dpct/helper.hpp b/src/ggml-sycl/dpct/helper.hpp index e167948e7..c96395be6 100644 --- a/src/ggml-sycl/dpct/helper.hpp +++ b/src/ggml-sycl/dpct/helper.hpp @@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { return device_type.str(); } +template struct matrix_info_t { + oneapi::mkl::transpose transpose_info[2]; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; +}; + namespace dpct { typedef sycl::queue *queue_ptr; @@ -1727,26 +1735,13 @@ namespace dpct }; template - inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void **a, int lda, - const void **b, int ldb, const void *beta, void **c, - int ldc, int batch_size) - { - struct matrix_info_t - { - oneapi::mkl::transpose transpose_info[2]; - Ts value_info[2]; - std::int64_t size_info[3]; - std::int64_t ld_info[3]; - std::int64_t groupsize_info; - }; - + inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, + int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, + int ldb, const void * beta, void ** c, int ldc, int batch_size, + matrix_info_t * matrix_info) { Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); matrix_info->transpose_info[0] = a_trans; matrix_info->transpose_info[1] = b_trans; matrix_info->value_info[0] = alpha_value; @@ -1763,23 +1758,18 @@ namespace dpct sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( oneapi::mkl::backend_selector{ q }, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast(a), - matrix_info->ld_info, reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), matrix_info->ld_info + 2, 1, - &(matrix_info->groupsize_info)); + matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), + reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), + reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #else sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, - matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, + matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), - matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast(c), - matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), + reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #endif - - q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(e); - cgh.host_task([=] { std::free(matrix_info); }); }); } template @@ -2422,25 +2412,11 @@ namespace dpct /// \param [in] ldc Leading dimension of C. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a[], - library_data_t a_type, int lda, const void *b[], - library_data_t b_type, int ldb, const void *beta, - void *c[], library_data_t c_type, int ldc, - int batch_size, library_data_t scaling_type) - { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } - + inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, + int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, + const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], + library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, + matrix_info_t * matrix_info) { std::uint64_t key = detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); switch (key) @@ -2449,48 +2425,24 @@ namespace dpct library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_double, library_data_t::real_double, library_data_t::real_double, library_data_t::real_double): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_half): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #ifdef __INTEL_MKL__ @@ -2498,19 +2450,16 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #endif @@ -2522,10 +2471,9 @@ namespace dpct dpct::get_value(reinterpret_cast(alpha), q); float beta_float = dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, - a, lda, b, ldb, &beta_float, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size, + matrix_info); break; } case detail::get_type_combination_id( @@ -2533,8 +2481,7 @@ namespace dpct library_data_t::real_float, library_data_t::real_float): { detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2542,8 +2489,7 @@ namespace dpct library_data_t::real_float, library_data_t::real_float): { detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2557,8 +2503,7 @@ namespace dpct sycl::half alpha_half(alpha_value); sycl::half beta_half(beta_value); detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info); break; } default: diff --git a/src/ggml-sycl/ggml-sycl.cpp b/src/ggml-sycl/ggml-sycl.cpp index 037c8093e..2984ed82e 100644 --- a/src/ggml-sycl/ggml-sycl.cpp +++ b/src/ggml-sycl/ggml-sycl.cpp @@ -1173,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } }; +struct ggml_sycl_pool_host : public ggml_sycl_pool { + queue_ptr qptr; + int device; + + inline static int counter{ 0 }; + + struct ggml_sycl_buffer { + void * ptr = nullptr; + size_t size = 0; + }; + + // Set arbitrarly to 64 + static constexpr int MAX_POOL_SIZE{ 64 }; + std::vector buffer_pool = std::vector(MAX_POOL_SIZE); + size_t pool_size = 0; + + explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} + + ~ggml_sycl_pool_host() { + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + b.ptr = nullptr; + pool_size -= b.size; + b.size = 0; + } + } + counter = 0; + } + + void * alloc(size_t size, size_t * actual_size) override { + if (counter == MAX_POOL_SIZE) { + ggml_sycl_buffer b = buffer_pool[0]; + void * ptr = b.ptr; + *actual_size = b.size; + counter = 1; + return ptr; + } + ggml_sycl_buffer & b = buffer_pool[counter]; + + if (b.ptr == nullptr) { + void * ptr; + + SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr))); + if (!ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); + return nullptr; + } + pool_size += size; + *actual_size = size; + counter = counter + 1; + return ptr; + } else { + ++counter; + b.size = size; + return b.ptr; + } + } + + void free(void * ptr, size_t size) override { + // if the pool is not completed add the pointer to it in place of the first nullptr found. + // Otherwise do nothing, pointers will be freed once the pool is deallocated. + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + } +}; + +std::unique_ptr ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) { + // return pool for the host to speed up memory management + return std::unique_ptr(new ggml_sycl_pool_host(qptr, device)); +} + std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { // TBD: NO VMM support // if (ggml_sycl_info().devices[device].vmm) { @@ -3363,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2*ne23); ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); + ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); sycl::range<3> block_dims(1, ne12, ne13); /* @@ -3391,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, }); } SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *main_stream, oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, - (const void **)(ptrs_src.get() + 0 * ne23), - dpct::library_data_t::real_half, nb01 / nb00, - (const void **)(ptrs_src.get() + 1 * ne23), - dpct::library_data_t::real_half, nb11 / nb10, beta, - (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, - cu_compute_type))); + *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, + (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, + (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, + (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get()))); } } catch (sycl::exception const &exc) { @@ -3802,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf); } -static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max); -} - static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope); @@ -4014,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens ggml_sycl_diag_mask_inf(ctx, dst); break; case GGML_OP_SOFT_MAX: - ggml_sycl_soft_max(ctx, dst); + ggml_sycl_op_soft_max(ctx, dst); break; case GGML_OP_ROPE: ggml_sycl_rope(ctx, dst); @@ -4040,6 +4112,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens case GGML_OP_RWKV_WKV6: ggml_sycl_op_rwkv_wkv6(ctx, dst); break; + case GGML_OP_GATED_LINEAR_ATTN: + ggml_sycl_op_gated_linear_attn(ctx, dst); + break; default: return false; } @@ -4507,6 +4582,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: return true; default: return false; diff --git a/src/ggml-sycl/gla.cpp b/src/ggml-sycl/gla.cpp new file mode 100644 index 000000000..eedb47486 --- /dev/null +++ b/src/ggml-sycl/gla.cpp @@ -0,0 +1,105 @@ +#include + +#include "common.hpp" + +template +static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, u_int T, u_int C, u_int H, float scale, + const float * k, const float * v, const float * r, const float * td, + const float * s, float * dst) { + const u_int head_size = HEAD_SIZE; + const u_int state_size = C * head_size; + const u_int n_seq_tokens = T / B; + sycl::range<1> block_dims((C / H)); + sycl::range<1> grid_dims((B * H)); + stream->submit([&](sycl::handler & cgh) { + /* local memory accessors*/ + auto _k = sycl::local_accessor(sycl::range<1>(head_size), cgh); + auto _r = sycl::local_accessor(sycl::range<1>(head_size), cgh); + auto _td = sycl::local_accessor(sycl::range<1>(head_size), cgh); + + cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) { + u_int tid = item.get_local_id(0); + u_int bid = item.get_group(0); + + u_int batch_i = bid / H; + u_int head_i = bid % H; + + float state[head_size]; + +#pragma unroll + for (u_int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; + t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { + + item.barrier(sycl::access::fence_space::local_space); //sync threads + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + item.barrier(sycl::access::fence_space::local_space); //sync threads + + const float _v = v[t]; + float y = 0; + + for (u_int j = 0; j < head_size; j += 4) { + const sycl::float4 & k = (sycl::float4 &) (_k[j]); + const sycl::float4 & r = (sycl::float4 &) (_r[j]); + const sycl::float4 & td = (sycl::float4 &) (_td[j]); + sycl::float4 & s = (sycl::float4 &) (state[j]); + sycl::float4 kv; + + kv.x() = k.x() * _v; + kv.y() = k.y() * _v; + kv.z() = k.z() * _v; + kv.w() = k.w() * _v; + + s.x() = s.x() * td.x() + kv.x(); + s.y() = s.y() * td.y() + kv.y(); + s.z() = s.z() * td.z() + kv.z(); + s.w() = s.w() * td.w() + kv.w(); + + y += r.x() * s.x(); + y += r.y() * s.y(); + y += r.z() * s.z(); + y += r.w() * s.w(); + } + dst[t] = y * scale; + } +#pragma unroll + for (u_int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } + }); + }); +} + +void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const float * k_d = static_cast(dst->src[0]->data); + const float * v_d = static_cast(dst->src[1]->data); + const float * r_d = static_cast(dst->src[2]->data); + const float * td_d = static_cast(dst->src[3]->data); + const float * s_d = static_cast(dst->src[4]->data); + + const int64_t B = dst->src[4]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + dpct::queue_ptr stream = ctx.stream(); + GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64 || C / H == 128); + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + float * dst_d = (float *) dst->data; + + if (C / H == 64) { + gated_linear_attn_f32_kernel<64>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } else { + gated_linear_attn_f32_kernel<128>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } +} diff --git a/src/ggml-sycl/gla.hpp b/src/ggml-sycl/gla.hpp new file mode 100644 index 000000000..607cf3a7f --- /dev/null +++ b/src/ggml-sycl/gla.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_GLA_HPP +#define GGML_SYCL_GLA_HPP + +#include "common.hpp" + +void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_GLA_HPP diff --git a/src/ggml-sycl/softmax.cpp b/src/ggml-sycl/softmax.cpp index a9b3fce0d..563e0655f 100644 --- a/src/ggml-sycl/softmax.cpp +++ b/src/ggml-sycl/softmax.cpp @@ -1,7 +1,7 @@ -#include "norm.hpp" +#include "softmax.hpp" -template -static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par, +template +static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; @@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const slope = sycl::pow(base, float(exp)); } - float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols; + float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols; float max_val = -INFINITY; for (int col0 = 0; col0 < ncols; col0 += block_size) { @@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f); + const float val = x[ix]*scale + (mask ? slope*static_cast(mask[iy]) : 0.0f); vals[col] = val; max_val = sycl::max(max_val, val); @@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const item_ct1.barrier(sycl::access::fence_space::local_space); max_val = buf[lane_id]; for (size_t i = 1; i < nreduce; i += 1) { - max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]); + max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]); } max_val = warp_reduce_max(max_val, item_ct1); } @@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const } } -template -static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par, +template +static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, const size_t n_local_scratch, queue_ptr stream) { @@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float * }); } -static void soft_max_f32_sycl(const float * x, const float * mask, +template +static void soft_max_f32_sycl(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, queue_ptr stream, int device) { @@ -223,22 +224,16 @@ static void soft_max_f32_sycl(const float * x, const float * mask, } } -void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { +void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); -#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support") -#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional - const int64_t ne00 = src0->ne[0]; - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t nrows_x = ggml_nrows(dst->src[0]); + const int64_t nrows_y = dst->src[0]->ne[1]; float scale = 1.0f; float max_bias = 0.0f; @@ -246,6 +241,21 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s memcpy(&scale, dst->op_params + 0, sizeof(float)); memcpy(&max_bias, dst->op_params + 1, sizeof(float)); - soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, - nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + ggml_sycl_set_device(ctx.device); + dpct::queue_ptr main_stream = ctx.stream(); + + if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) { + const sycl::half * src1_dd = static_cast(dst->src[1]->data); + soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, + main_stream, ctx.device); + } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) { + const float * src1_dd = static_cast(dst->src[1]->data); + soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + } else { + /* mask unavailable */ + soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + } } diff --git a/src/ggml-sycl/softmax.hpp b/src/ggml-sycl/softmax.hpp index bdb8f712e..2cf8582ec 100644 --- a/src/ggml-sycl/softmax.hpp +++ b/src/ggml-sycl/softmax.hpp @@ -15,10 +15,6 @@ #include "common.hpp" -void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream); +void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst); #endif // GGML_SYCL_SOFTMAX_HPP diff --git a/src/ggml-vulkan/CMakeLists.txt b/src/ggml-vulkan/CMakeLists.txt index c0ddaac82..d970f7e20 100644 --- a/src/ggml-vulkan/CMakeLists.txt +++ b/src/ggml-vulkan/CMakeLists.txt @@ -1,5 +1,20 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) + find_package(Vulkan COMPONENTS glslc REQUIRED) +function(detect_host_compiler) + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + else() + find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + endif() + set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE) + set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) +endfunction() + if (Vulkan_FOUND) message(STATUS "Vulkan found") @@ -73,19 +88,56 @@ if (Vulkan_FOUND) add_compile_definitions(GGML_VULKAN_RUN_TESTS) endif() - add_subdirectory(vulkan-shaders) - - set (_ggml_vk_genshaders_cmd vulkan-shaders-gen) + if (NOT CMAKE_CROSSCOMPILING) + add_subdirectory(vulkan-shaders) + if (MSVC) + foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) + string(TOUPPER ${CONFIG} CONFIG) + set_target_properties(vulkan-shaders-gen PROPERTIES + RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() + endif() + else() + if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) + set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) + else() + detect_host_compiler() + if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER) + message(FATAL_ERROR "Host compiler not found") + else() + message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}") + endif() + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) + set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake) + endif() + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") + + include(ExternalProject) + # Native build through ExternalProject_Add + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} + -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} + BUILD_COMMAND ${CMAKE_COMMAND} --build . + INSTALL_COMMAND ${CMAKE_COMMAND} --install . + INSTALL_DIR ${CMAKE_BINARY_DIR} + ) + ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + endif() + set (_ggml_vk_host_suffix $,.exe,>) + set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix}) set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") + set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen) - if (NOT CMAKE_CROSSCOMPILING) - set(_ggml_vk_genshaders_cmd "$/${_ggml_vk_genshaders_cmd}") - endif () + if (CMAKE_CROSSCOMPILING) + set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) + endif() add_custom_command( OUTPUT ${_ggml_vk_header} @@ -99,7 +151,7 @@ if (Vulkan_FOUND) --target-cpp ${_ggml_vk_source} --no-clean - DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd} + DEPENDS ${_ggml_vk_shader_deps} COMMENT "Generate vulkan shaders" ) diff --git a/src/ggml-vulkan/cmake/host-toolchain.cmake.in b/src/ggml-vulkan/cmake/host-toolchain.cmake.in new file mode 100644 index 000000000..b6af747a5 --- /dev/null +++ b/src/ggml-vulkan/cmake/host-toolchain.cmake.in @@ -0,0 +1,15 @@ +set(CMAKE_BUILD_TYPE Release) +set(CMAKE_C_FLAGS -O2) +set(CMAKE_CXX_FLAGS -O2) +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) +set(CMAKE_C_COMPILER @HOST_C_COMPILER@) +set(CMAKE_CXX_COMPILER @HOST_CXX_COMPILER@) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@) + +if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC") + foreach(CONFIG IN ITEMS DEBUG RELEASE MINSIZEREL RELWITHDEBINFO) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() +endif() diff --git a/src/ggml-vulkan/ggml-vulkan.cpp b/src/ggml-vulkan/ggml-vulkan.cpp index 649146d7b..a9d6b923c 100644 --- a/src/ggml-vulkan/ggml-vulkan.cpp +++ b/src/ggml-vulkan/ggml-vulkan.cpp @@ -29,8 +29,6 @@ #include "ggml-vulkan-shaders.hpp" -#define VK_API_VERSION VK_API_VERSION_1_2 - #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) #define VK_VENDOR_ID_AMD 0x1002 @@ -87,6 +85,10 @@ struct vk_pipeline_struct { uint32_t parameter_count; std::array wg_denoms; uint32_t align; + // set to true to request the pipeline is compiled after the dryrun + bool needed {}; + // set to true when the shader has been compiled + bool compiled {}; }; typedef std::shared_ptr vk_pipeline; @@ -188,8 +190,11 @@ struct vk_device_struct { bool mul_mat_id_m; bool mul_mat_id_s; - vk_matmul_pipeline pipeline_matmul_f32; - vk_matmul_pipeline pipeline_matmul_f32_f16; + // set to true to indicate that some shaders need to be compiled after the dryrun + bool need_compiles {}; + + vk_matmul_pipeline pipeline_matmul_f32 {}; + vk_matmul_pipeline pipeline_matmul_f32_f16 {}; vk_matmul_pipeline2 pipeline_matmul_f16; vk_matmul_pipeline2 pipeline_matmul_f16_f32; vk_pipeline pipeline_matmul_split_k_reduce; @@ -197,7 +202,7 @@ struct vk_device_struct { vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; - vk_matmul_pipeline pipeline_matmul_id_f32; + vk_matmul_pipeline pipeline_matmul_id_f32 {}; vk_matmul_pipeline2 pipeline_matmul_id_f16; vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; @@ -228,6 +233,8 @@ struct vk_device_struct { vk_pipeline pipeline_repeat_f32; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; + vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; @@ -384,10 +391,13 @@ struct vk_flash_attn_push_constants { uint32_t nev3; uint32_t nem1; + uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t nb21; uint32_t nb22; uint32_t nb23; uint32_t nb31; @@ -773,13 +783,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin GGML_ASSERT(parameter_count > 0); GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT - pipeline = std::make_shared(); - pipeline->name = name; - pipeline->parameter_count = parameter_count; - pipeline->push_constant_size = push_constant_size; - pipeline->wg_denoms = wg_denoms; - pipeline->align = align; - vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); @@ -862,6 +865,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + pipeline->compiled = true; { std::lock_guard guard(device->mutex); @@ -872,12 +876,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin std::lock_guard guard(compile_count_mutex); assert(compile_count > 0); compile_count--; - - // "Progress bar" for shader compiles - static uint32_t total_compile_count = 0; - if ((total_compile_count++ % 10) == 0) { - std::cerr << "."; - } } compile_count_cond.notify_all(); } @@ -903,6 +901,10 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); device->pipeline_descriptor_set_requirements[pipeline->name] += n; + if (!pipeline->compiled) { + pipeline->needed = true; + device->need_compiles = true; + } } static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { @@ -1385,8 +1387,6 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); - std::cerr << "ggml_vulkan: Compiling shaders"; - // some shaders have a minimum subgroup size const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); @@ -1524,15 +1524,33 @@ static void ggml_vk_load_shaders(vk_device& device) { } } - device->pipeline_matmul_f32 = std::make_shared(); - device->pipeline_matmul_f32_f16 = std::make_shared(); - - device->pipeline_matmul_id_f32 = std::make_shared(); + if (!device->pipeline_matmul_f32) { + device->pipeline_matmul_f32 = std::make_shared(); + } + if (!device->pipeline_matmul_f32_f16) { + device->pipeline_matmul_f32_f16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_f32) { + device->pipeline_matmul_id_f32 = std::make_shared(); + } std::vector> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + + if (!pipeline) { + pipeline = std::make_shared(); + pipeline->name = name; + pipeline->parameter_count = parameter_count; + pipeline->push_constant_size = push_constant_size; + pipeline->wg_denoms = wg_denoms; + pipeline->align = align; + } + + if (!pipeline->needed || pipeline->compiled) { + return; + } { // wait until fewer than N compiles are in progress uint32_t N = std::max(1u, std::thread::hardware_concurrency()); @@ -1609,11 +1627,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ - CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -1626,21 +1640,18 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) #undef CREATE_MM #undef CREATE_MM2 } else @@ -1677,31 +1688,31 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); if (device->coopmat_acc_f16_support) { - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. @@ -1711,31 +1722,31 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); if (device->coopmat_acc_f16_support) { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } else { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } } #undef CREATE_MM2 @@ -1965,6 +1976,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); @@ -2002,7 +2027,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); @@ -2040,7 +2065,7 @@ static void ggml_vk_load_shaders(vk_device& device) { for (auto &c : compiles) { c.wait(); } - std::cerr << "Done!" << std::endl; + device->need_compiles = false; } static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); @@ -2268,6 +2293,14 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif + VkPhysicalDeviceMaintenance4Features maint4_features {}; + maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&maint4_features; + last_struct = (VkBaseOutStructure *)&maint4_features; + device_extensions.push_back("VK_KHR_maintenance4"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->fp16 = device->fp16 && vk12_features.shaderFloat16; @@ -2643,7 +2676,14 @@ void ggml_vk_instance_init() { vk_instance_initialized = true; - vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; + uint32_t api_version = vk::enumerateInstanceVersion(); + + if (api_version < VK_API_VERSION_1_2) { + std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl; + GGML_ABORT("fatal error"); + } + + vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); @@ -2953,7 +2993,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } } - GGML_ASSERT(src1_type == GGML_TYPE_F32); + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { case GGML_TYPE_Q4_0: @@ -3689,6 +3729,33 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f16_f16; } } + if (src->type == GGML_TYPE_F32) { + switch (to) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_f32_quant[to]; + default: + break; + } + } + + if (to == GGML_TYPE_F32) { + switch (src->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_quant_f32[src->type]; + default: + break; + } + } std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; GGML_ABORT("fatal error"); @@ -3766,8 +3833,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub src1_uma = d_Qy != nullptr; } - const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); - // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src1); @@ -4347,8 +4415,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ids_uma = d_ids != nullptr; } - const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); - const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src1); const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; @@ -4358,7 +4429,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; if (qx_needs_dequant) { - GGML_ABORT("fatal error"); + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); } // Not implemented @@ -4766,7 +4838,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } assert(pipelines); - bool aligned = (KV % pipelines[1]->align) == 0; + const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); + const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); + const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + + bool aligned = (KV % pipelines[1]->align) == 0 && + // the "aligned" shader variant will forcibly align strides, for performance + (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + vk_pipeline pipeline = pipelines[aligned]; assert(pipeline); @@ -4802,15 +4881,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (ctx->device->uma) { ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); - ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset); - ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset); - ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset); + ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset); Q_uma = d_Q != nullptr; K_uma = d_K != nullptr; V_uma = d_V != nullptr; D_uma = d_D != nullptr; if (mask) { - ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset); + ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset); M_uma = d_M != nullptr; } } @@ -4848,7 +4927,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } - const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 }; + const vk_flash_attn_push_constants pc = { N, KV, + (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, + (uint32_t)neq2, (uint32_t)neq3, + (uint32_t)nek2, (uint32_t)nek3, + (uint32_t)nev2, (uint32_t)nev3, + nem1, + q_stride, (uint32_t)nbq2, (uint32_t)nbq3, + k_stride, (uint32_t)nbk2, (uint32_t)nbk3, + v_stride, (uint32_t)nbv2, (uint32_t)nbv3, + nbm1, + scale, max_bias, logit_softcap, + mask != nullptr, n_head_log2, m0, m1 }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, @@ -5160,7 +5250,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); - GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT + GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT GGML_ASSERT(dst->buffer != nullptr); const uint64_t ne00 = src0->ne[0]; @@ -7581,6 +7671,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg for (int i = 0; i < cgraph->n_nodes; i++) { ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); } + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } ggml_vk_preallocate_buffers(ctx); ggml_pipeline_allocate_descriptor_sets(ctx->device); @@ -7905,12 +7998,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm { ggml_type src0_type = op->src[0]->type; ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type; - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { - return true; + + if (src0_type == GGML_TYPE_F32) { + switch (src1_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } } - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { - return true; + if (src1_type == GGML_TYPE_F32) { + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { return true; } @@ -8601,6 +8718,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { ggml_tensor * src0 = tensor->src[0]; ggml_tensor * src1 = tensor->src[1]; ggml_tensor * src2 = tensor->src[2]; + ggml_tensor * src3 = tensor->src[3]; void * tensor_data = tensor->data; @@ -8663,6 +8781,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); @@ -8707,6 +8828,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); @@ -8729,6 +8853,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); diff --git a/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index bd0c74cb1..074031087 100644 --- a/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -1,9 +1,11 @@ find_package (Threads REQUIRED) -find_package(Vulkan COMPONENTS glslc REQUIRED) +find_program(GLSLC_EXECUTABLE glslc) +if(NOT GLSLC_EXECUTABLE) + message(FATAL_ERROR "glslc not found.") +endif() set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) target_compile_features(${TARGET} PRIVATE cxx_std_17) target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) -target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan) diff --git a/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp new file mode 100644 index 000000000..c09bf496b --- /dev/null +++ b/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" +#include "dequant_funcs.comp" + +#if defined(DATA_A_IQ4_NL) +// 16 invocations needed for init_iq4nl_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +void main() { +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = get_doffset() + dst_idx(idx); + uint src_idx = src0_idx_quant(idx, QUANT_K); + + const uint a_offset = 0; + const uint ib = src_idx; + const vec2 dm = get_dm(ib, a_offset); + + [[unroll]] for (int j = 0; j < QUANT_K; j += 4) { + vec4 v = dequantize4(ib, j / QUANT_R, a_offset); + v = v * dm.x + vec4(dm.y); + +#if QUANT_R == 2 + data_d[dst_idx + j/2 + 0] = v[0]; + data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1]; + data_d[dst_idx + j/2 + 1] = v[2]; + data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3]; +#else + data_d[dst_idx + j + 0] = v[0]; + data_d[dst_idx + j + 1] = v[1]; + data_d[dst_idx + j + 2] = v[2]; + data_d[dst_idx + j + 3] = v[3]; +#endif + } +} diff --git a/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp new file mode 100644 index 000000000..ccf5b980a --- /dev/null +++ b/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -0,0 +1,237 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +#if defined(DATA_A_IQ4_NL) +// 16 invocations needed for init_iq4nl_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +layout (binding = 0) readonly buffer S {float data_s[];}; +layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; + +#if defined(DATA_A_Q4_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id; + + const uint xi0 = min(15, int(x0 + 8.5)); + const uint xi1 = min(15, int(x1 + 8.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q4_1) +void quantize(uint dst_idx, uint src_idx) +{ + float vmin = 1.0/0.0; + float vmax = -vmin; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) { + const float v = data_s[src_idx + j]; + + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(vmin); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - vmin)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id; + + const uint xi0 = min(15, int(x0 + 0.5)); + const uint xi1 = min(15, int(x1 + 0.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q5_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id; + + const uint xi0 = min(31, int(x0 + 16.5)); + const uint xi1 = min(31, int(x1 + 16.5)); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2); + } + data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF); + data_q[dst_idx].qh[1] = uint16_t(qh >> 16); +} +#endif + +#if defined(DATA_A_Q5_1) +void quantize(uint dst_idx, uint src_idx) +{ + float min = data_s[src_idx + 0]; + float max = min; + + [[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) { + const float v = data_s[src_idx + j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = (d != 0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(min); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - min)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id; + + const uint xi0 = uint(x0 + 0.5); + const uint xi1 = uint(x1 + 0.5); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2); + } + data_q[dst_idx].qh = qh; +} +#endif + +#if defined(DATA_A_Q8_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; // absolute max + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) { + const float v = data_s[src_idx + j]; + amax = max(amax, abs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) { + const float x0 = data_s[src_idx + j]*id; + + data_q[dst_idx].qs[j] = int8_t(round(x0)); + } +} +#endif + +#if defined(DATA_A_IQ4_NL) +uint best_index(float x) { + if (x <= kvalues_iq4nl[0]) return 0; + if (x >= kvalues_iq4nl[15]) return 15; + int ml = 0, mu = 15; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav; + } + return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu; +} + +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + float sumqx = 0, sumq2 = 0; + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id; + const uint xi0 = best_index(x0); + const uint xi1 = best_index(x1); + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j]; + const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + } + + data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d); + +} +#endif + +void main() { +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = dst_idx_quant(idx, QUANT_K); + uint src_idx = get_aoffset() + src0_idx(idx); + + quantize(dst_idx, src_idx); +} diff --git a/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 94b78598e..175e31fa7 100644 --- a/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -101,19 +101,25 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_ block_q2_K block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 { + block_q2_K_packed16 block; +}; + float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { + decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); const f16vec2 d = bl.block.d; const uint idx = coordInBlock[1]; - const uint iqs = idx; - const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31 - const uint scalesi = iqs / 16; // 0..15 - const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6 + const uint scalesi = (idx & 0xF0) >> 4; // 0..15 + const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6 + + uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qs = (qs >> qsshift) & 0x0303; + qs = unpack8(qs)[idx & 1]; - uint32_t qs = bl.block.qs[qsi]; const uint scales = bl.block.scales[scalesi]; - float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4); + float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4); return ret; } @@ -157,39 +163,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4 block_q4_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 { + block_q4_K_packed128 block; +}; + float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); + decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl); const uint idx = coordInBlock[1]; const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 - const f16vec2 loadd = bl.block.d; + uvec4 v = bl128.block.q4k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); uint32_t sc; uint32_t mbyte; - uint32_t scidx0 = (is < 4) ? is : (is + 4); - uint32_t scidx1 = (is < 4) ? is : (is - 4); - uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t scidxshift1 = (is < 4) ? 0 : 2; - uint32_t mbidx0 = is + 4; - uint32_t mbidx1 = (is < 4) ? is + 4 : is; - uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; - uint32_t mbidxshift0 = (is < 4) ? 0 : 4; - uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t mbidxshift1 = (is < 4) ? 0 : 2; + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; - sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); - mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; const float16_t d = loadd.x * float16_t(sc); const float16_t m = loadd.y * float16_t(mbyte); uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); - qs = (qs >> (b * 4)) & 0x0F0F; - qs = unpack8(qs)[idx & 1]; + qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; float16_t ret = d * float16_t(qs) - m; @@ -204,47 +218,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5 block_q5_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 { + block_q5_K_packed128 block; +}; + float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); + decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); const uint idx = coordInBlock[1]; const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 - const uint32_t hm = 0x0101 << is; + uvec4 v = bl128.block.q5k[0]; - const f16vec2 loadd = bl.block.d; + const f16vec2 loadd = unpackFloat2x16(v.x); uint32_t sc; uint32_t mbyte; - uint32_t scidx0 = (is < 4) ? is : (is + 4); - uint32_t scidx1 = (is < 4) ? is : (is - 4); - uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t scidxshift1 = (is < 4) ? 0 : 2; - uint32_t mbidx0 = is + 4; - uint32_t mbidx1 = (is < 4) ? is + 4 : is; - uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; - uint32_t mbidxshift0 = (is < 4) ? 0 : 4; - uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t mbidxshift1 = (is < 4) ? 0 : 2; + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); - sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); - mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; const float16_t d = loadd.x * float16_t(sc); const float16_t m = loadd.y * float16_t(mbyte); uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); - qh = qh & hm; - qh = unpack8(qh)[idx & 1]; + qh = ((qh >> is) & 0x101) << 4; uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); qs = (qs >> (b * 4)) & 0x0F0F; - qs = unpack8(qs)[idx & 1]; + qs = unpack8(qs | qh)[idx & 1]; - float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m; + float16_t ret = d * (float16_t(qs)) - m; return ret; } diff --git a/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp index 4e68742b5..26d8bc22a 100644 --- a/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +++ b/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -12,7 +12,7 @@ layout (push_constant) uniform parameter #include "types.comp" -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; diff --git a/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index c5be8131b..3735d0dbb 100644 --- a/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -42,10 +42,13 @@ layout (push_constant) uniform parameter { uint32_t nev3; uint32_t nem1; + uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t nb21; uint32_t nb22; uint32_t nb23; uint32_t nb31; @@ -146,7 +149,24 @@ void main() { tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); - coopmat Q; + // nb?1 are already divided by the type size and are in units of elements + uint32_t q_stride = p.nb01; + uint32_t k_stride = p.nb11; + uint32_t v_stride = p.nb21; + // hint to the compiler that strides are aligned for the aligned variant of the shader + if (Clamp != gl_CooperativeMatrixClampModeConstantNV) + { + q_stride &= ~7; +#if !defined(BLOCK_SIZE) + k_stride &= ~7; + v_stride &= ~7; +#endif + } + tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); + tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); + tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); + + coopmat Q; coopmat Qf16; uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; diff --git a/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp index 68d1bc9f1..8dc9d360d 100644 --- a/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +++ b/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp @@ -54,3 +54,23 @@ uint dst_idx(uint idx) { const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; } + +uint src0_idx_quant(uint idx, uint qk) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00; +} + +uint dst_idx_quant(uint idx, uint qk) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10; +} diff --git a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 6a9b9b2d1..8cdc640e8 100644 --- a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -5,6 +5,80 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + barrier(); + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + } + barrier(); + + if (i >= num_blocks_per_row) + continue; + } else { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + barrier(); + } + + const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im], + fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im], + fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im], + fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im], + fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im], + fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im], + fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im], + fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -14,88 +88,28 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; - - const uint step = 8; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = itid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; - const uint s_offset = 8*v_im; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - - uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0]; - uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1]; - - uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F; - uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F; - uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F; - uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F; - - uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32)); - uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32)); - uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32)); - uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32)); - - uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0]; - uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; - uvec2 qs0 = uvec2(unpack8(qs0_u16)); - uvec2 qs16 = uvec2(unpack8(qs16_u16)); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); - vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); - vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); - vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); - vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); - vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); - vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); - vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); - - FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); - FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), - fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), - fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), - fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), - fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3), - fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), - fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3), - fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); - sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]), - fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]), - fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]), - fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]), - fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]), - fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]), - fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]), - fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); - } - temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); - } - } - } + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 96ef50fdd..3116fad16 100644 --- a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -5,6 +5,74 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + barrier(); + if (i < num_blocks_per_row) + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); + const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); + const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); + const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); + const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); + + // 0, 1, 16, 17 + uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); + qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + if (all_threads) { + barrier(); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -14,76 +82,37 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; - - const uint step = 8; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + const uint itid8 = itid%8; - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = itid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_im4 = v_im*4; + const uint v_in = itid - 8*v_im; // 0...7 - const uint8_t m = uint8_t(1 << (4 * v_im)); + const uint32_t m = 0x01010101 << (4 * v_im); + uint32_t hm_m[4]; + [[unroll]] for (uint j = 0; j < 4; ++j) + hm_m[j] = m << j; const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - const uint s_shift = 4 * v_im; - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0]; - uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1]; - uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2]; - uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3]; - uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4]; - uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5]; - u8vec2 s0 = unpack8(s0_16); - u8vec2 s2 = unpack8(s2_16); - u8vec2 s4 = unpack8(s4_16); - u8vec2 s6 = unpack8(s6_16); - u8vec2 s8 = unpack8(s8_16); - u8vec2 s10 = unpack8(s10_16); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - - vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); - vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); - vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); - vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); - vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); - vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); - vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); - vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 2; ++l) { - sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum)))))))); - } - temp[j][n] = fma(d, sum, temp[j][n]); - } - } - } + const uint s_shift = v_im4 + 2*(itid8/4); + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index f97eb8744..f9cde0648 100644 --- a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -6,6 +6,86 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; + const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; + + const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4)); + const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4)); + const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4)); + const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_lo4.x; + const FLOAT_TYPE q4_1 = qs0_lo4.y; + const FLOAT_TYPE q4_2 = qs0_lo4.z; + const FLOAT_TYPE q4_3 = qs0_lo4.w; + const FLOAT_TYPE q4_4 = qs0_hi4.x; + const FLOAT_TYPE q4_5 = qs0_hi4.y; + const FLOAT_TYPE q4_6 = qs0_hi4.z; + const FLOAT_TYPE q4_7 = qs0_hi4.w; + const FLOAT_TYPE q4_8 = qs64_lo4.x; + const FLOAT_TYPE q4_9 = qs64_lo4.y; + const FLOAT_TYPE q4_10 = qs64_lo4.z; + const FLOAT_TYPE q4_11 = qs64_lo4.w; + const FLOAT_TYPE q4_12 = qs64_hi4.x; + const FLOAT_TYPE q4_13 = qs64_hi4.y; + const FLOAT_TYPE q4_14 = qs64_hi4.z; + const FLOAT_TYPE q4_15 = qs64_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]); + vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); + vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]); + vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -15,13 +95,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint step = 4; - - const uint il = itid/step; // 0...3 - const uint ir = itid - step*il; // 0...7 or 0...3 + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...3 const uint n = 4; const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 @@ -31,89 +109,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - - uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec4 scale0 = uvec4(unpack8(scale0_u32)); - uvec4 scale4 = uvec4(unpack8(scale4_u32)); - uvec4 scale8 = uvec4(unpack8(scale8_u32)); - - const uint32_t sc0 = ( scale0.x & 0x3f); - const uint32_t sc1 = ( scale0.y & 0x3f); - const uint32_t sc2 = ( scale4.x & 0x3f); - const uint32_t sc3 = ( scale4.y & 0x3f); - const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); - const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); - const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); - const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); - - uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; - uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; - - uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; - uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; - uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; - uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; - - uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4)); - uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4)); - uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4)); - uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4)); - - const uint32_t q4_0 = qs0_lo4.x; - const uint32_t q4_1 = qs0_lo4.y; - const uint32_t q4_2 = qs0_lo4.z; - const uint32_t q4_3 = qs0_lo4.w; - const uint32_t q4_4 = qs0_hi4.x; - const uint32_t q4_5 = qs0_hi4.y; - const uint32_t q4_6 = qs0_hi4.z; - const uint32_t q4_7 = qs0_hi4.w; - const uint32_t q4_8 = qs64_lo4.x; - const uint32_t q4_9 = qs64_lo4.y; - const uint32_t q4_10 = qs64_lo4.z; - const uint32_t q4_11 = qs64_lo4.w; - const uint32_t q4_12 = qs64_hi4.x; - const uint32_t q4_13 = qs64_hi4.y; - const uint32_t q4_14 = qs64_hi4.z; - const uint32_t q4_15 = qs64_hi4.w; - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]); - vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); - vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]); - vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); - - const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); - const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); - const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); - const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, - fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, - fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, - fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); - } - } - } + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 79d7db0e3..6c84ef3cd 100644 --- a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -6,6 +6,118 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + + uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; + uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; + uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + + const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); + const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + + qs0_16_u32_lo4 += qs0_16_lo4_offset16; + qs0_16_u32_hi4 += qs0_16_hi4_offset16; + qs64_80_u32_lo4 += qs64_80_lo4_offset16; + qs64_80_u32_hi4 += qs64_80_hi4_offset16; + + const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); + const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); + const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); + const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_16_lo4.x; + const FLOAT_TYPE q4_1 = qs0_16_lo4.y; + const FLOAT_TYPE q4_2 = qs0_16_lo4.z; + const FLOAT_TYPE q4_3 = qs0_16_lo4.w; + const FLOAT_TYPE q4_4 = qs0_16_hi4.x; + const FLOAT_TYPE q4_5 = qs0_16_hi4.y; + const FLOAT_TYPE q4_6 = qs0_16_hi4.z; + const FLOAT_TYPE q4_7 = qs0_16_hi4.w; + const FLOAT_TYPE q4_8 = qs64_80_lo4.x; + const FLOAT_TYPE q4_9 = qs64_80_lo4.y; + const FLOAT_TYPE q4_10 = qs64_80_lo4.z; + const FLOAT_TYPE q4_11 = qs64_80_lo4.w; + const FLOAT_TYPE q4_12 = qs64_80_hi4.x; + const FLOAT_TYPE q4_13 = qs64_80_hi4.y; + const FLOAT_TYPE q4_14 = qs64_80_hi4.z; + const FLOAT_TYPE q4_15 = qs64_80_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]); + vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]); + vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); + vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); + vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]); + vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]); + vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); + vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); + + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -15,11 +127,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; const uint il = itid/4; // 0...3 - const uint ir = itid - 4*il; // 0...7 or 0...3 + const uint ir = itid - 4*il; // 0...3 const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_in = il % 2; @@ -28,121 +140,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - - uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec4 scale0 = uvec4(unpack8(scale0_u32)); - uvec4 scale4 = uvec4(unpack8(scale4_u32)); - uvec4 scale8 = uvec4(unpack8(scale8_u32)); - - const uint32_t sc0 = ( scale0.x & 0x3f); - const uint32_t sc1 = ( scale0.y & 0x3f); - const uint32_t sc2 = ( scale4.x & 0x3f); - const uint32_t sc3 = ( scale4.y & 0x3f); - const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); - const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); - const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); - const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); - - uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); - uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); - - uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; - uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; - uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; - uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; - - uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); - - uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; - uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; - uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0; - uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; - - qs0_16_u32_lo4 += qs0_16_lo4_offset16; - qs0_16_u32_hi4 += qs0_16_hi4_offset16; - qs64_80_u32_lo4 += qs64_80_lo4_offset16; - qs64_80_u32_hi4 += qs64_80_hi4_offset16; - - uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); - uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4)); - uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4)); - uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4)); - - const uint32_t q4_0 = qs0_16_lo4.x; - const uint32_t q4_1 = qs0_16_lo4.y; - const uint32_t q4_2 = qs0_16_lo4.z; - const uint32_t q4_3 = qs0_16_lo4.w; - const uint32_t q4_4 = qs0_16_hi4.x; - const uint32_t q4_5 = qs0_16_hi4.y; - const uint32_t q4_6 = qs0_16_hi4.z; - const uint32_t q4_7 = qs0_16_hi4.w; - const uint32_t q4_8 = qs64_80_lo4.x; - const uint32_t q4_9 = qs64_80_lo4.y; - const uint32_t q4_10 = qs64_80_lo4.z; - const uint32_t q4_11 = qs64_80_lo4.w; - const uint32_t q4_12 = qs64_80_hi4.x; - const uint32_t q4_13 = qs64_80_hi4.y; - const uint32_t q4_14 = qs64_80_hi4.z; - const uint32_t q4_15 = qs64_80_hi4.w; - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]); - vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]); - vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); - vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); - vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]); - vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]); - vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); - vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); - - const FLOAT_TYPE sx = - fma(FLOAT_TYPE(by10.x), q4_0, - fma(FLOAT_TYPE(by10.y), q4_1, - fma(FLOAT_TYPE(by116.x), q4_2, - FLOAT_TYPE(by116.y) * q4_3))); - const FLOAT_TYPE sy = - fma(FLOAT_TYPE(by132.x), q4_4, - fma(FLOAT_TYPE(by132.y), q4_5, - fma(FLOAT_TYPE(by148.x), q4_6, - FLOAT_TYPE(by148.y) * q4_7))); - const FLOAT_TYPE sz = - fma(FLOAT_TYPE(by20.x), q4_8, - fma(FLOAT_TYPE(by20.y), q4_9, - fma(FLOAT_TYPE(by216.x), q4_10, - FLOAT_TYPE(by216.y) * q4_11))); - const FLOAT_TYPE sw = - fma(FLOAT_TYPE(by232.x), q4_12, - fma(FLOAT_TYPE(by232.y), q4_13, - fma(FLOAT_TYPE(by248.x), q4_14, - FLOAT_TYPE(by248.y) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, - fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, - fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, - (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); - } - } - } + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index 041fd27c1..f05f96b5e 100644 --- a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -6,7 +6,77 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + barrier(); + if (i < num_blocks_per_row) + sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + const uint32_t qh4_u32 = (qh_u32 & 0x30303030); + const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + const vec4 q0 = vec4(unpack8(q0_u32)) - 32; + const vec4 q1 = vec4(unpack8(q1_u32)) - 32; + const vec4 q2 = vec4(unpack8(q2_u32)) - 32; + const vec4 q3 = vec4(unpack8(q3_u32)) - 32; + + if (all_threads) { + barrier(); + sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]); + vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]); + vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); + vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); + + FLOAT_TYPE sum[4] = {0, 0, 0, 0}; + [[unroll]] for (uint l = 0; l < 4; ++l) { + sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]); + sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]); + sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); + sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); + } + temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]); + } + } +} + +void compute_outputs(const uint first_row, const uint num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -15,13 +85,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint step = 8; - - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = itid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 const uint is = v_in / 4; @@ -31,68 +99,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint s_offset = 8*v_im + is; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - FLOAT_TYPE scales[4]; - scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); - scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); - scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); - scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); - - uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); - uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); - - uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; - uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; - uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; - uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; - - uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); - uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; - uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; - uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; - uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; - - uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; - uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; - uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; - uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; - - uvec4 q0 = uvec4(unpack8(q0_u32)); - uvec4 q1 = uvec4(unpack8(q1_u32)); - uvec4 q2 = uvec4(unpack8(q2_u32)); - uvec4 q3 = uvec4(unpack8(q3_u32)); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]); - vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]); - vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); - vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 4; ++l) { - sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), - fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), - fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), - fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); - } - temp[j][n] += sum * d; - } - } - } + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index cbfa5dce1..57f9e7245 100644 --- a/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -57,17 +57,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #if QUANT_K > 1 #define DECODEFUNCA , dequantFuncA -#define MAT_A_TYPE float16_t #include "dequant_funcs_cm2.comp" #else #define DECODEFUNCA -#define MAT_A_TYPE A_TYPE #endif -#define MAT_B_TYPE B_TYPE - #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; @@ -236,16 +232,13 @@ void main() { for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { - coopmat mat_a; - coopmat mat_b; + coopmat mat_a; + coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopmat mat_a_ft = coopmat(mat_a); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); - coopmat mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } } else #endif // !defined(MUL_MAT_ID) @@ -261,10 +254,8 @@ void main() { [[dont_unroll]] for (uint block_k = start_k; block_k < end_k; block_k += BK) { - coopmat mat_a; - coopmat mat_b; - coopmat mat_a_ft; - coopmat mat_b_ft; + coopmat mat_a; + coopmat mat_b; // Clamping is expensive, so detect different code paths for each combination // of A and B needing clamping. @@ -281,16 +272,12 @@ void main() { #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); #endif - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (unclampedA && !unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (!unclampedA && unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID @@ -298,16 +285,12 @@ void main() { #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); #endif - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (!unclampedA && !unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } } } diff --git a/src/ggml-vulkan/vulkan-shaders/types.comp b/src/ggml-vulkan/vulkan-shaders/types.comp index f12e61bbe..1e35b6652 100644 --- a/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/src/ggml-vulkan/vulkan-shaders/types.comp @@ -227,6 +227,11 @@ struct block_q4_K_packed32 uint32_t qs[QUANT_K_Q4_K/2/4]; }; +struct block_q4_K_packed128 +{ + uvec4 q4k[9]; +}; + #if defined(DATA_A_Q4_K) #define QUANT_K QUANT_K_Q4_K #define A_TYPE block_q4_K @@ -252,6 +257,11 @@ struct block_q5_K_packed16 uint16_t qs[QUANT_K_Q5_K/2/2]; }; +struct block_q5_K_packed128 +{ + uvec4 q5k[11]; +}; + #if defined(DATA_A_Q5_K) #define QUANT_K QUANT_K_Q5_K #define A_TYPE block_q5_K diff --git a/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 7b5044798..e9c6cb9d4 100644 --- a/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -17,21 +17,19 @@ #include #include #include +#include #include #include #ifdef _WIN32 #include #include // For _mkdir on Windows - #include // For std::replace on w64devkit #else #include #include #include #endif -#include - #define ASYNCIO_CONCURRENCY 64 std::mutex lock; @@ -318,8 +316,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool // For aligned matmul loads std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + // don't generate f32 variants for coopmat2 + if (!coopmat2) { + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } if (tname != "f16" && tname != "f32") { string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); @@ -419,6 +420,11 @@ void process_shaders() { string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + } + string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); @@ -496,6 +502,7 @@ void write_output_files() { fprintf(hdr, "#include \n\n"); fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); + std::sort(shader_fnames.begin(), shader_fnames.end()); for (const auto& pair : shader_fnames) { const std::string& name = pair.first; #ifdef _WIN32 diff --git a/src/ggml.c b/src/ggml.c index ae8147dd7..3b4861542 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -3454,12 +3454,14 @@ struct ggml_tensor * ggml_soft_max_ext( return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } -// ggml_soft_max_back +// ggml_soft_max_ext_back -static struct ggml_tensor * ggml_soft_max_back_impl( +static struct ggml_tensor * ggml_soft_max_ext_back_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + float scale, + float max_bias, bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -3467,21 +3469,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl( result->src[0] = a; result->src[1] = b; + memcpy((float *) result->op_params + 0, &scale, sizeof(float)); + memcpy((float *) result->op_params + 1, &max_bias, sizeof(float)); + return result; } -struct ggml_tensor * ggml_soft_max_back( +struct ggml_tensor * ggml_soft_max_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_soft_max_back_impl(ctx, a, b, false); + struct ggml_tensor * b, + float scale, + float max_bias) { + return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false); } -struct ggml_tensor * ggml_soft_max_back_inplace( +struct ggml_tensor * ggml_soft_max_ext_back_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_soft_max_back_impl(ctx, a, b, true); + struct ggml_tensor * b, + float scale, + float max_bias) { + return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true); } // ggml_rope @@ -3699,7 +3708,7 @@ void ggml_rope_yarn_corr_dims( // ggml_rope_back -struct ggml_tensor * ggml_rope_back( +struct ggml_tensor * ggml_rope_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -3713,29 +3722,32 @@ struct ggml_tensor * ggml_rope_back( float attn_factor, float beta_fast, float beta_slow) { - GGML_ASSERT(ggml_is_vector(b)); - GGML_ASSERT(b->type == GGML_TYPE_I32); - GGML_ASSERT(a->ne[2] == b->ne[0]); - - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - - int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; - memcpy(params + 5, &freq_base, sizeof(float)); - memcpy(params + 6, &freq_scale, sizeof(float)); - memcpy(params + 7, &ext_factor, sizeof(float)); - memcpy(params + 8, &attn_factor, sizeof(float)); - memcpy(params + 9, &beta_fast, sizeof(float)); - memcpy(params + 10, &beta_slow, sizeof(float)); - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_ROPE_BACK; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - + struct ggml_tensor * result = ggml_rope_ext( + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + result->op = GGML_OP_ROPE_BACK; return result; } +struct ggml_tensor * ggml_rope_multi_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + struct ggml_tensor * result = ggml_rope_multi( + ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + result->op = GGML_OP_ROPE_BACK; + return result; +} // ggml_clamp struct ggml_tensor * ggml_clamp( @@ -5077,10 +5089,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back( struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_tensor * c) { - GGML_ASSERT(ggml_are_same_shape(a, b)); - GGML_ASSERT(ggml_is_scalar(c)); + GGML_ASSERT(ggml_is_scalar(a)); + GGML_ASSERT(ggml_are_same_shape(b, c)); - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_dup_tensor(ctx, b); result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; result->src[0] = a; @@ -5259,7 +5271,7 @@ static void ggml_sub_or_set( } static void ggml_compute_backward( - struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) { + struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) { struct ggml_tensor * tensor = cgraph->nodes[i]; struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor); @@ -5331,7 +5343,7 @@ static void ggml_compute_backward( } break; case GGML_OP_MUL: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1)); } if (src1_needs_grads) { struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad); @@ -5403,7 +5415,7 @@ static void ggml_compute_backward( if (src0_needs_grads) { float eps; memcpy(&eps, tensor->op_params, sizeof(float)); - ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps)); } } break; case GGML_OP_MUL_MAT: { @@ -5423,21 +5435,25 @@ static void ggml_compute_backward( // src1.shape [n,p,qq,rr] if (src0_needs_grads) { - struct ggml_tensor * s1_tg = + GGML_ASSERT(grad->ne[2] == src1->ne[2]); + GGML_ASSERT(grad->ne[3] == src1->ne[3]); + struct ggml_tensor * tmp = ggml_out_prod(ctx, // [n,m,qq,rr] src1, // [n,p,qq,rr] grad); // [m,p,qq,rr] - const int64_t qq = s1_tg->ne[2]; - const int64_t rr = s1_tg->ne[3]; - const int64_t q1 = src0->ne[2]; - const int64_t r1 = src0->ne[3]; - const bool ne2_broadcasted = qq > q1; - const bool ne3_broadcasted = rr > r1; - if (ne2_broadcasted || ne3_broadcasted) { - // sum broadcast repetitions of s1_tg into shape of src0 - s1_tg = ggml_repeat_back(ctx, s1_tg, src0); + if (!ggml_are_same_shape(tmp, src0)) { + GGML_ASSERT(tmp->ne[0] == src0->ne[0]); + GGML_ASSERT(tmp->ne[1] == src0->ne[1]); + GGML_ASSERT(tmp->ne[3] == 1); + + const int64_t nr2 = tmp->ne[2] / src0->ne[2]; + const size_t nb2 = tmp->nb[2] * nr2; + const size_t nb3 = tmp->nb[2]; + + tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0); + tmp = ggml_repeat_back(ctx, tmp, src0); } - ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/); + ggml_add_or_set(ctx, cgraph, isrc0, tmp); } if (src1_needs_grads) { ggml_add_or_set(ctx, cgraph, isrc1, @@ -5506,7 +5522,9 @@ static void ggml_compute_backward( if (src0_needs_grads) { GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0])); GGML_ASSERT(ggml_is_contiguous(grad)); - ggml_add_or_set(ctx, cgraph, isrc0, grad); + GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0)); + ggml_add_or_set(ctx, cgraph, isrc0, + ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0)); } } break; case GGML_OP_RESHAPE: { @@ -5586,7 +5604,13 @@ static void ggml_compute_backward( } break; case GGML_OP_SOFT_MAX: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float)); + + ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias)); } GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented"); } break; @@ -5598,6 +5622,7 @@ static void ggml_compute_backward( //const int n_ctx = ((int32_t *) tensor->op_params)[3]; const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4]; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4] = {0, 0, 0, 0}; memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float)); @@ -5605,10 +5630,14 @@ static void ggml_compute_backward( memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float)); memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float)); - - ggml_add_or_set(ctx, cgraph, isrc0, - ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base, - freq_scale, ext_factor, attn_factor, beta_fast, beta_slow)); + memcpy(§ions, tensor->op_params + 11, sizeof(sections)); + + struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ? + ggml_rope_ext_back(ctx, grad, src1, src2, n_dims, + mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) : + ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections, + mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + ggml_add_or_set(ctx, cgraph, isrc0, rope_back); } GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented"); } break; @@ -5622,7 +5651,7 @@ static void ggml_compute_backward( const int32_t d1 = ggml_get_op_params_i32(tensor, 5); const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; - ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); + ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); } } break; case GGML_OP_POOL_2D: { @@ -5665,7 +5694,7 @@ static void ggml_compute_backward( } break; case GGML_UNARY_OP_SILU: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0)); } } break; case GGML_UNARY_OP_EXP: { @@ -5682,7 +5711,7 @@ static void ggml_compute_backward( } break; case GGML_OP_CROSS_ENTROPY_LOSS: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1)); } GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); } break; diff --git a/src/gguf.cpp b/src/gguf.cpp index 655ed600a..ab13669c5 100644 --- a/src/gguf.cpp +++ b/src/gguf.cpp @@ -648,6 +648,10 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par ok = ok && data != nullptr; + if (ok) { + ggml_set_name(data, "GGUF tensor data binary blob"); + } + // read the binary blob with the tensor data ok = ok && gr.read(data->data, ctx->size); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3834e0f84..4c5c4dd9c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -780,7 +780,7 @@ struct test_case { } } if (!any_params) { - printf("not supported [%s] \n", op_name); + printf("not supported [%s] \n", op_desc(out).c_str()); supported = false; } if (!supported) { @@ -1130,6 +1130,59 @@ struct test_get_rows : public test_case { } }; +// GGML_OP_GET_ROWS_BACK +struct test_get_rows_back : public test_case { + const ggml_type type; + const int n; // cols + const int m; // rows + const int r; // rows to get + const int b; // batch size + const bool v; // view (non-contiguous src1) + + std::string vars() override { + return VARS_TO_STR6(type, n, m, r, b, v); + } + + test_get_rows_back(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false) + : type(type), n(n), m(m), r(r), b(b), v(v) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * in_forward = ggml_new_tensor_3d(ctx, type, n, m, b); + ggml_set_name(in_forward, "in_forward"); + + ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b); + ggml_set_name(rows, "rows"); + if (v) { + rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0); + ggml_set_name(rows, "view_of_rows"); + } + + ggml_tensor * grad = ggml_new_tensor_3d(ctx, type, n, r, b); + ggml_set_name(grad, "grad"); + + ggml_tensor * out = ggml_get_rows_back(ctx, grad, rows, in_forward); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // rows + std::vector data(r*b); + for (int i = 0; i < r*b; i++) { + data[i] = rand() % m; + } + ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int)); + } else { + init_tensor_uniform(t); + } + } + } +}; + // GGML_OP_ARGMAX struct test_argmax : public test_case { const ggml_type type; @@ -1249,6 +1302,59 @@ struct test_repeat : public test_case { } }; +// GGML_OP_REPEAT_BACK +struct test_repeat_back : public test_case { + const ggml_type type; + const std::array ne; + const std::array nr; + const bool v; // whether src is a noncontiguous view + + std::string vars() override { + return VARS_TO_STR4(type, ne, nr, v); + } + + size_t op_size(ggml_tensor * t) override { + return ggml_nbytes(t) * 2; + } + + test_repeat_back(ggml_type type = GGML_TYPE_F32, + std::array ne = {8, 6, 4, 2}, + std::array nr = {2, 2, 2, 2}, + bool v = false) + : type(type), ne(ne), nr(nr), v(v) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]); + ggml_set_name(src, "src"); + + if (v) { + GGML_ASSERT(ne[0] % 2 == 0); + GGML_ASSERT(ne[1] % 2 == 0); + GGML_ASSERT(ne[2] % 2 == 0); + GGML_ASSERT(ne[3] % 2 == 0); + GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1); + GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1); + GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1); + GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1); + + const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2; + const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2; + const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2; + const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2; + + src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0); + } + + ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(target, "target"); + + ggml_tensor * out = ggml_repeat_back(ctx, src, target); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_DUP struct test_dup : public test_case { const ggml_type type; @@ -1531,6 +1637,39 @@ struct test_scale : public test_case { } }; +// GGML_OP_SILU_BACK +struct test_silu_back : public test_case { + const ggml_type type; + const std::array ne; + float eps; + + std::string vars() override { + return VARS_TO_STR3(type, ne, eps); + } + + test_silu_back(ggml_type type = GGML_TYPE_F32, + std::array ne = {64, 5, 4, 3}, + float eps = 1e-6f) + : type(type), ne(ne), eps(eps) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * grad = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(grad, "grad"); + + ggml_tensor * out = ggml_silu_back(ctx, a, grad); + ggml_set_name(out, "out"); + + return out; + } + + bool grad_precise() override { + return true; + } +}; + // GGML_OP_NORM struct test_norm : public test_case { const ggml_type type; @@ -1583,11 +1722,56 @@ struct test_rms_norm : public test_case { return out; } + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -10.f, 10.f); + } + } + + float grad_eps() override { + return 1.0f; + } + bool grad_precise() override { return true; } }; +// GGML_OP_RMS_NORM_BACK +struct test_rms_norm_back : public test_case { + const ggml_type type; + const std::array ne; + float eps; + + std::string vars() override { + return VARS_TO_STR3(type, ne, eps); + } + + test_rms_norm_back(ggml_type type = GGML_TYPE_F32, + std::array ne = {64, 5, 4, 3}, + float eps = 1e-6f) + : type(type), ne(ne), eps(eps) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(b, "b"); + + ggml_tensor * out = ggml_rms_norm_back(ctx, a, b, eps); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -10.f, 10.f); + } + } +}; + // GGML_OP_SSM_CONV struct test_ssm_conv : public test_case { const ggml_type type; @@ -1718,6 +1902,10 @@ struct test_mul_mat : public test_case { return 5e-4; } + int64_t grad_nmax() override { + return 20000; + } + uint64_t op_flops(ggml_tensor * t) override { GGML_UNUSED(t); return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1]; @@ -1747,8 +1935,12 @@ struct test_mul_mat : public test_case { a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]); b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]); - ggml_set_param(ctx, a); - ggml_set_param(ctx, b); + if (!ggml_is_quantized(type_a)) { + if (bs[1] == 1 && nr[1] == 1) { + ggml_set_param(ctx, a); + } + ggml_set_param(ctx, b); + } ggml_set_name(a, "a"); ggml_set_name(b, "b"); @@ -1759,8 +1951,12 @@ struct test_mul_mat : public test_case { } else { a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); - ggml_set_param(ctx, a); - ggml_set_param(ctx, b); + if (!ggml_is_quantized(type_a)) { + if (bs[1] == 1 && nr[1] == 1) { + ggml_set_param(ctx, a); + } + ggml_set_param(ctx, b); + } ggml_set_name(a, "a"); ggml_set_name(b, "b"); } @@ -1855,10 +2051,11 @@ struct test_out_prod : public test_case { const int64_t n; const int64_t k; const std::array bs; // dims 3 and 4 + const std::array nr; // repeat in dims 3 and 4 const bool trans_b; std::string vars() override { - return VARS_TO_STR7(type_a, type_b, m, n, k, bs, trans_b); + return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, trans_b); } double max_nmse_err() override { @@ -1868,8 +2065,9 @@ struct test_out_prod : public test_case { test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, int64_t m = 32, int64_t n = 32, int64_t k = 32, std::array bs = {10, 10}, + std::array nr = {2, 2}, bool trans_b = false) - : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), trans_b(trans_b) {} + : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), trans_b(trans_b) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]); @@ -1877,10 +2075,10 @@ struct test_out_prod : public test_case { ggml_tensor * b; if (trans_b) { - b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0], bs[1]); + b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); b = ggml_transpose(ctx, b); } else { - b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0], bs[1]); + b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0]*nr[0], bs[1]*nr[1]); } ggml_set_name(b, "b"); @@ -2149,11 +2347,12 @@ struct test_soft_max : public test_case { const ggml_type type; const std::array ne; const bool mask; + const ggml_type m_prec; const float scale; const float max_bias; std::string vars() override { - return VARS_TO_STR5(type, ne, mask, scale, max_bias); + return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias); } // the 1024 test with bias occasionally fails: @@ -2165,9 +2364,10 @@ struct test_soft_max : public test_case { test_soft_max(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 5, 4, 3}, bool mask = false, + ggml_type m_prec = GGML_TYPE_F32, float scale = 1.0f, float max_bias = 0.0f) - : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {} + : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); @@ -2176,7 +2376,7 @@ struct test_soft_max : public test_case { ggml_tensor * mask = nullptr; if (this->mask) { - mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]); + mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]); ggml_set_name(mask, "mask"); } @@ -2191,8 +2391,38 @@ struct test_soft_max : public test_case { } }; +// GGML_OP_SOFT_MAX_BACK +struct test_soft_max_back : public test_case { + const ggml_type type; + const std::array ne; + const float scale; + const float max_bias; + + std::string vars() override { + return VARS_TO_STR4(type, ne, scale, max_bias); + } + + test_soft_max_back(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 5, 4, 3}, + float scale = 1.0f, + float max_bias = 0.0f) + : type(type), ne(ne), scale(scale), max_bias(max_bias) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_soft_max_ext_back(ctx, a, b, scale, max_bias); + ggml_set_name(out, "out"); -// GGML_OP_ROPE + return out; + } +}; + +// GGML_OP_ROPE + GGML_OP_ROPE_BACK struct test_rope : public test_case { const ggml_type type; const std::array ne_a; @@ -2204,29 +2434,36 @@ struct test_rope : public test_case { float af; // attn_factor bool ff; int v; // view (1 : non-contiguous a) + bool forward; std::string vars() override { + // forward can be inferred from the op, does not need to be printed return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v); } test_rope(ggml_type type = GGML_TYPE_F32, std::array ne_a = {10, 5, 3, 1}, - int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0) - : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v) {} + int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, + float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true) + : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a; if (v & 1) { auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3; a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + if (forward) { + ggml_set_param(ctx, a); + } ggml_set_name(a, "a"); a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0); ggml_set_name(a, "view_of_a"); } else { a = ggml_new_tensor(ctx, type, 4, ne_a.data()); - ggml_set_param(ctx, a); + if (forward) { + ggml_set_param(ctx, a); + } ggml_set_name(a, "a"); } @@ -2252,14 +2489,26 @@ struct test_rope : public test_case { if (is_vision) { GGML_ASSERT(n_dims/4 > 0); int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate - out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + if (forward) { + out = ggml_rope_multi (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } else { + out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } } else { GGML_ASSERT(n_dims/3 > 0); int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0}; - out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + if (forward) { + out = ggml_rope_multi (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } else { + out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } } } else { - out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + if (forward) { + out = ggml_rope_ext (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } else { + out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } } ggml_set_name(out, "out"); @@ -2864,9 +3113,10 @@ struct test_flash_attn_ext : public test_case { const float logit_softcap; // Gemma 2 const ggml_type type_KV; + std::array permute; std::string vars() override { - return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV); + return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute); } double max_nmse_err() override { @@ -2881,19 +3131,33 @@ struct test_flash_attn_ext : public test_case { } test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, - bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16) - : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {} + bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16, + std::array permute = {0, 1, 2, 3}) + : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {} ggml_tensor * build_graph(ggml_context * ctx) override { const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV)); - ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1); + auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * { + int64_t ne[4] = {ne0, ne1, ne2, ne3}; + int64_t ne_perm[4]; + for (int i = 0; i < 4; ++i) { + ne_perm[permute[i]] = ne[i]; + } + ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]); + if (permute != std::array{0, 1, 2, 3}) { + t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]); + } + return t; + }; + + ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1); ggml_set_name(q, "q"); - ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); + ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1); ggml_set_name(k, "k"); - ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); + ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1); ggml_set_name(v, "v"); ggml_tensor * m = nullptr; @@ -2961,6 +3225,40 @@ struct test_cross_entropy_loss : public test_case { } }; +// GGML_OP_CROSS_ENTROPY_LOSS_BACK +struct test_cross_entropy_loss_back : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_cross_entropy_loss_back(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 5, 4, 3}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * grad = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + ggml_set_name(grad, "grad"); + + ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(logits, "logits"); + + ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(labels, "labels"); + + // Ensure labels add up to 1: + labels = ggml_soft_max(ctx, labels); + ggml_set_name(labels, "labels_normalized"); + + ggml_tensor * out = ggml_cross_entropy_loss_back(ctx, grad, logits, labels); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_OPT_STEP_ADAMW struct test_opt_step_adamw : public test_case { const ggml_type type; @@ -3460,6 +3758,16 @@ static std::vector> make_test_cases_eval() { } } + test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 8, 2, 1, false)); + for (ggml_type type : all_types) { + for (bool v : {false, true}) { + test_cases.emplace_back(new test_get_rows_back(type, 256, 5, 4, 1, v)); + } + } + for (bool v : {false, true}) { + test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v)); + } + for (ggml_type type_input : {GGML_TYPE_F32}) { for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) { for (int k0 : {1, 3}) { @@ -3557,6 +3865,16 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2})); } + for (bool view : {false, true}) { + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view)); + test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view)); + } + test_cases.emplace_back(new test_dup(GGML_TYPE_F32)); test_cases.emplace_back(new test_dup(GGML_TYPE_F16)); test_cases.emplace_back(new test_dup(GGML_TYPE_I32)); @@ -3582,6 +3900,12 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows } } + for (ggml_type type_dst : {GGML_TYPE_F32}) { + for (ggml_type type_src : all_types) { + test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4})); + test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows + } + } for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) { for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) { test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous @@ -3638,10 +3962,12 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_add1()); test_cases.emplace_back(new test_scale()); + test_cases.emplace_back(new test_silu_back()); - for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) { - test_cases.emplace_back(new test_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); - test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); + for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) { + test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps)); + test_cases.emplace_back(new test_rms_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps)); + test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1})); @@ -3660,38 +3986,35 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4)); test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4)); - for (int i = 1; i < 9; ++i) { - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + for (ggml_type type_a : all_types) { + for (int i = 1; i < 10; ++i) { + test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); + } } #if 1 for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { // test cases without permutation - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2})); - - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2})); - test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2})); + + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2})); // test cases with permutation test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3})); @@ -3781,22 +4104,19 @@ static std::vector> make_test_cases_eval() { for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, { 1, 1})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); - - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1}, true)); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10})); - test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10})); + for (int n : {1, 16}) { + for (int k : {1, 16}) { + for (int bs2 : {1, 3}) { + for (int bs3 : {1, 3}) { + for (int nr2 : {1, 2}) { + for (int nr3 : {1, 2}) { + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, n, k, {bs2, bs3}, {nr2, nr3})); + } + } + } + } + } + } } } @@ -3832,19 +4152,41 @@ static std::vector> make_test_cases_eval() { for (float scale : {1.0f, 0.1f}) { for (int64_t ne0 : {16, 1024}) { for (int64_t ne1 : {16, 1024}) { - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias)); + if (mask) { + for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) { + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias)); + } + } else { + /* The precision of mask here doesn't matter as boolean mask is false */ + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias)); + } } } } } } - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 8.0f)); + + for (float max_bias : {0.0f, 8.0f}) { + for (float scale : {1.0f, 0.1f}) { + for (int64_t ne0 : {16, 1024}) { + for (int64_t ne1 : {16, 1024}) { + test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0, ne1, 1, 1}, scale, max_bias)); + test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, scale, max_bias)); + } + } + } + } - { + for (bool fw : {true, false}) { // fw == forward bool all = true; for (float v : { 0, 1 }) { @@ -3853,29 +4195,29 @@ static std::vector> make_test_cases_eval() { for (float af : { 1.0f, 1.4245f }) { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (bool ff : {false, true}) { // freq_factors - test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B + test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B if (all) { - test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B - test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B - test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B + test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B + test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B + test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B } if (all) { - test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) - test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) - test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B) - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm) - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) + test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) + test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) } if (all) { - test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 2B) - test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 7B) - test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl ViT) + test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B) + test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B) + test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) } - test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B) + test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) } } @@ -3925,6 +4267,10 @@ static std::vector> make_test_cases_eval() { for (int nb : { 1, 3, 32, 35, }) { for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV)); + // run fewer test cases permuted + if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) { + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3})); + } } } } @@ -3934,7 +4280,11 @@ static std::vector> make_test_cases_eval() { } } - test_cases.emplace_back(new test_cross_entropy_loss()); + test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3})); + test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 1, 1, 1})); + test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, { 10, 5, 4, 3})); + test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1})); + test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3})); // these tests are disabled to save execution time, but they can be handy for debugging @@ -3959,13 +4309,13 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));