diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index c14a12f54a8d6..c93261295efe1 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -252,6 +252,10 @@ static bool fp16_mma_hardware_available(const int cc) { GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); } +static bool bf16_mma_hardware_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE; +} + // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. static bool new_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 898b24341471d..12a236c5c37ba 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1916,16 +1916,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - bool any_gpus_with_slow_fp16 = false; - bool any_gpus_without_fp16_mma = false; + bool any_gpus_with_slow_fp16 = false; if (split) { ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; @@ -1936,16 +1934,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor continue; } - const int cc = ggml_cuda_info().devices[id].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); - any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc); + const int cc = ggml_cuda_info().devices[id].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } } else { - const int cc = ggml_cuda_info().devices[ctx.device].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); - any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc); + const int cc = ggml_cuda_info().devices[ctx.device].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } // debug helpers @@ -1956,7 +1954,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { + if (!split && use_mul_mat_vec) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst); diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu index d8c385e2399ae..1502e9d942fbc 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmv.cu @@ -2,25 +2,26 @@ #include "common.cuh" #include "mmv.cuh" -template +template static __global__ void mul_mat_vec( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row, - const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, - const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) { - const int64_t row = blockIdx.x; - const int64_t channel_dst = blockIdx.y; - const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; - const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst; - const int64_t sample_dst = blockIdx.z; - const int64_t sample_x = sample_dst / sample_ratio; - const int64_t sample_y = sample_dst; - const int tid = threadIdx.x; + const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const int row = blockIdx.x; + const int channel_dst = blockIdx.y; + const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; + const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; + const int sample_dst = blockIdx.z; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + const int tid = threadIdx.x; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row; - y += sample_y *stride_sample_y + channel_y *stride_channel_y; - dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst; + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; + y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; + dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; const float2 * y2 = (const float2 *) y; @@ -34,81 +35,108 @@ static __global__ void mul_mat_vec( __syncthreads(); } - float sumf = 0.0f; + float sumf[ncols_dst] = {0.0f}; if constexpr (std::is_same::value) { const float2 * x2 = (const float2 *) x; - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + for (int col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = x2[col2]; - const float2 tmpy = y2[col2]; - sumf += tmpx.x*tmpy.x; - sumf += tmpx.y*tmpy.y; + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + sumf[j] += tmpx.x*tmpy.x; + sumf[j] += tmpx.y*tmpy.y; + } } } else if constexpr (std::is_same::value) { const half2 * x2 = (const half2 *) x; if (std::is_same::value) { - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + for (int col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = __half22float2(x2[col2]); - const float2 tmpy = y2[col2]; - sumf += tmpx.x * tmpy.x; - sumf += tmpx.y * tmpy.y; + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + sumf[j] += tmpx.x * tmpy.x; + sumf[j] += tmpx.y * tmpy.y; + } } } else { #ifdef FP16_AVAILABLE - half2 sumh2 = make_half2(0.0f, 0.0f); + half2 sumh2[ncols_dst] = {{0.0f, 0.0f}}; + + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const half2 tmpx = x2[col2]; - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmp = y2[col2]; - sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y); + } } - sumf = __low2float(sumh2) + __high2float(sumh2); +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]); + } #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE } } else if constexpr (std::is_same::value) { const int * x2 = (const int *) x; - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { - const int tmpx = x2[col2]; - const float2 tmpy = y2[col2]; - sumf += float(reinterpret_cast(&tmpx)[0]) * tmpy.x; - sumf += float(reinterpret_cast(&tmpx)[1]) * tmpy.y; + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const int tmpx = x2[col2]; +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + sumf[j] += float(reinterpret_cast(&tmpx)[0]) * tmpy.x; + sumf[j] += float(reinterpret_cast(&tmpx)[1]) * tmpy.y; + } } } else { static_assert(std::is_same::value, "unsupported type"); } - sumf = warp_reduce_sum(sumf); +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf[j] = warp_reduce_sum(sumf[j]); - if (block_size > warp_size) { - buf_iw[tid/warp_size] = sumf; - __syncthreads(); - if (tid >= warp_size) { - return; + if (block_size > warp_size) { + buf_iw[tid/warp_size] = sumf[j]; + __syncthreads(); + if (tid < warp_size) { + sumf[j] = buf_iw[tid]; + sumf[j] = warp_reduce_sum(sumf[j]); + } + if (j < ncols_dst) { + __syncthreads(); + } } - sumf = buf_iw[tid]; - sumf = warp_reduce_sum(sumf); } - if (tid != 0) { + if (tid >= ncols_dst) { return; } - dst[row] = sumf; + dst[tid*stride_col_dst + row] = sumf[tid]; } -template +template static void launch_mul_mat_vec_cuda( const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t ncols, const int64_t nrows, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, cudaStream_t stream) { - GGML_ASSERT(ncols % 2 == 0); - GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(ncols % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0); const int64_t channel_ratio = nchannels_dst / nchannels_x; @@ -138,44 +166,52 @@ static void launch_mul_mat_vec_cuda( const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 64: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 96: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 128: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 160: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 192: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 224: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 256: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; default: { GGML_ABORT("fatal error"); @@ -183,23 +219,91 @@ static void launch_mul_mat_vec_cuda( } } +template +static void mul_mat_vec_cuda_switch_ncols_dst( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + switch (ncols_dst) { + case 1: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 2: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 3: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 4: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 5: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 6: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 7: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 8: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + template static void mul_mat_vec_cuda( const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, enum ggml_prec prec, cudaStream_t stream) { if constexpr(std::is_same::value) { if (prec == GGML_PREC_DEFAULT) { - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + mul_mat_vec_cuda_switch_ncols_dst + (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); return; } } - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + mul_mat_vec_cuda_switch_ncols_dst + (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); } @@ -246,24 +350,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_y = ids ? s11 : s12; - GGML_ASSERT(ncols_dst == 1); + GGML_ASSERT(!ids || ncols_dst == 1); switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; @@ -282,16 +386,19 @@ void ggml_cuda_op_mul_mat_vec( GGML_ASSERT(dst->type == GGML_TYPE_F32); const int64_t ne00 = src0->ne[0]; + const int64_t ne10 = src1->ne[0]; + const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; - GGML_ASSERT(src1_ncols == 1); - - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; // ggml_cuda_op provides single, contiguous matrices const int64_t stride_row = ne00; + const int64_t stride_col_y = ne10; + const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer const int64_t nchannels_x = 1; const int64_t nchannels_y = 1; const int64_t nchannels_dst = 1; @@ -307,19 +414,19 @@ void ggml_cuda_op_mul_mat_vec( switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; @@ -334,3 +441,48 @@ void ggml_cuda_op_mul_mat_vec( GGML_UNUSED(src1_ncols); GGML_UNUSED(src1_padded_row_size); } + +bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { + if (src0_ne[0] % 2 != 0) { + return false; + } + switch (type) { + case GGML_TYPE_F32: + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return ne11 <= 8; + } + if (cc >= GGML_CUDA_CC_TURING) { + return ne11 <= 4; + } + return ne11 <= 3; + } + return ne11 <= 8; + case GGML_TYPE_F16: + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return src0_small && ne11 <= 4; + } + if (fp16_mma_hardware_available(cc)) { + return src0_small && ne11 <= 3; + } + return ne11 <= 8; + } + return ne11 <= 8; + case GGML_TYPE_BF16: + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return src0_small && ne11 <= 4; + } + if (bf16_mma_hardware_available(cc)) { + return src0_small && ne11 <= 3; + } + return ne11 <= 8; + } + return ne11 <= 8; + default: + return false; + } +} diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmv.cuh index 756e7e1cc7fc3..1330bcb6a8860 100644 --- a/ggml/src/ggml-cuda/mmv.cuh +++ b/ggml/src/ggml-cuda/mmv.cuh @@ -1,8 +1,5 @@ #include "common.cuh" -// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available -#define MMV_MAX_ROWS 512 - void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); void ggml_cuda_op_mul_mat_vec( @@ -10,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); + +bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);