diff --git a/benches/benchmarks.rs b/benches/benchmarks.rs
index 3cb0ff9..7459d6c 100644
--- a/benches/benchmarks.rs
+++ b/benches/benchmarks.rs
@@ -1,6 +1,7 @@
 extern crate matrixmultiply;
 pub use matrixmultiply::sgemm;
 pub use matrixmultiply::dgemm;
+pub use matrixmultiply::igemm;
 
 #[macro_use]
 extern crate bencher;
@@ -10,7 +11,14 @@
 
 // by flop / s = 2 M N K / time
 
-benchmark_main!(mat_mul_f32, mat_mul_f64, layout_f32_032, layout_f64_032);
+benchmark_main!(
+    mat_mul_f32,
+    mat_mul_f64,
+    mat_mul_i32,
+    layout_f32_032,
+    layout_f64_032,
+    layout_i32_032
+);
 
 macro_rules! mat_mul {
     ($modname:ident, $gemm:ident, $(($name:ident, $m:expr, $n:expr, $k:expr))+) => {
@@ -20,17 +28,17 @@ macro_rules! mat_mul {
             $(
             pub fn $name(bench: &mut Bencher)
             {
-                let a = vec![0.; $m * $n];
-                let b = vec![0.; $n * $k];
-                let mut c = vec![0.; $m * $k];
+                let a = vec![0 as _; $m * $n];
+                let b = vec![0 as _; $n * $k];
+                let mut c = vec![0 as _; $m * $k];
                 bench.iter(|| {
                     unsafe {
                         $gemm(
                             $m, $n, $k,
-                            1.,
+                            1 as _,
                             a.as_ptr(), $n, 1,
                             b.as_ptr(), $k, 1,
-                            0.,
+                            0 as _,
                             c.as_mut_ptr(), $k, 1,
                             )
                     }
@@ -106,9 +114,9 @@ macro_rules! gemm_layout {
 
     fn base(bench: &mut Bencher, al: Layout, bl: Layout, cl: Layout)
     {
-        let a = vec![0.; $m * $m];
-        let b = vec![0.; $m * $m];
-        let mut c = vec![0.; $m * $m];
+        let a = vec![0 as _; $m * $m];
+        let b = vec![0 as _; $m * $m];
+        let mut c = vec![0 as _; $m * $m];
         let (rsa, csa) = al.strides($m, 1);
         let (rsb, csb) = bl.strides($m, 1);
         let (rsc, csc) = cl.strides($m, 1);
@@ -116,10 +124,10 @@ macro_rules! gemm_layout {
             unsafe {
                 $gemm(
                     $m, $m, $m,
-                    1.,
+                    1 as _,
                     a.as_ptr(), rsa, csa,
                     b.as_ptr(), rsb, csb,
-                    0.,
+                    0 as _,
                     c.as_mut_ptr(), rsc, csc,
                     )
             }
@@ -157,6 +165,10 @@ gemm_layout!{layout_f64_032, dgemm,
     (m032, 32)
 }
 
+gemm_layout!{layout_i32_032, igemm,
+    (m032, 32)
+}
+
 
 use std::ops::{Add, Mul};
 
@@ -219,3 +231,22 @@ ref_mat_mul!{ref_mat_mul_f32, f32,
     (m032, 32, 32, 32)
     (m064, 64, 64, 64)
 }
+
+mat_mul!{mat_mul_i32, igemm,
+    (m004, 4, 4, 4)
+    (m006, 6, 6, 6)
+    (m008, 8, 8, 8)
+    (m012, 12, 12, 12)
+    (m016, 16, 16, 16)
+    (m032, 32, 32, 32)
+    (m064, 64, 64, 64)
+    (m127, 127, 127, 127)
+    /*
+    (m256, 256, 256, 256)
+    (m512, 512, 512, 512)
+    (mix16x4, 32, 4, 32)
+    (mix32x2, 32, 2, 32)
+    (mix97, 97, 97, 125)
+    (mix128x10000x128, 128, 10000, 128)
+    */
+}
diff --git a/src/gemm.rs b/src/gemm.rs
index 01ec2d7..4dc11bd 100644
--- a/src/gemm.rs
+++ b/src/gemm.rs
@@ -19,6 +19,7 @@ use kernel::GemmKernel;
 use kernel::Element;
 use sgemm_kernel;
 use dgemm_kernel;
+use igemm_kernel;
 use rawpointer::PointerExt;
 
 /// General matrix multiplication (f32)
@@ -87,6 +88,23 @@ pub unsafe fn dgemm(
         c, rsc, csc)
 }
 
+pub unsafe fn igemm(
+    m: usize, k: usize, n: usize,
+    alpha: i32,
+    a: *const i32, rsa: isize, csa: isize,
+    b: *const i32, rsb: isize, csb: isize,
+    beta: i32,
+    c: *mut i32, rsc: isize, csc: isize)
+{
+    gemm_loop::<igemm_kernel::Gemm>(
+        m, k, n,
+        alpha,
+        a, rsa, csa,
+        b, rsb, csb,
+        beta,
+        c, rsc, csc)
+}
+
 /// Ensure that GemmKernel parameters are supported
 /// (alignment, microkernel size).
 ///
diff --git a/src/igemm_kernel.rs b/src/igemm_kernel.rs
new file mode 100644
index 0000000..c91bfec
--- /dev/null
+++ b/src/igemm_kernel.rs
@@ -0,0 +1,219 @@
+// Copyright 2016 - 2018 Ulrik Sverdrup "bluss"
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use kernel::GemmKernel;
+use kernel::Element;
+use archparam;
+
+
+#[cfg(target_arch="x86")]
+use std::arch::x86::*;
+#[cfg(target_arch="x86_64")]
+use std::arch::x86_64::*;
+
+pub enum Gemm { }
+
+pub type T = i32;
+
+const MR: usize = 8;
+const NR: usize = 4;
+
+macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; }
+macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }
+
+impl GemmKernel for Gemm {
+    type Elem = T;
+
+    #[inline(always)]
+    fn align_to() -> usize { 16 }
+
+    #[inline(always)]
+    fn mr() -> usize { MR }
+    #[inline(always)]
+    fn nr() -> usize { NR }
+
+    #[inline(always)]
+    fn always_masked() -> bool { true }
+
+    #[inline(always)]
+    fn nc() -> usize { archparam::S_NC }
+    #[inline(always)]
+    fn kc() -> usize { archparam::S_KC }
+    #[inline(always)]
+    fn mc() -> usize { archparam::S_MC }
+
+    #[inline(always)]
+    unsafe fn kernel(
+        k: usize,
+        alpha: T,
+        a: *const T,
+        b: *const T,
+        beta: T,
+        c: *mut T, rsc: isize, csc: isize) {
+        kernel(k, alpha, a, b, beta, c, rsc, csc)
+    }
+}
+
+/// matrix multiplication kernel
+///
+/// This does the matrix multiplication:
+///
+/// C ← α A B + β C
+///
+/// + k: length of data in a, b
+/// + a, b are packed
+/// + c has general strides
+/// + rsc: row stride of c
+/// + csc: col stride of c
+/// + if beta is 0, then c does not need to be initialized
+#[inline(never)]
+pub unsafe fn kernel(k: usize, alpha: T, a: *const T, b: *const T,
+                     beta: T, c: *mut T, rsc: isize, csc: isize)
+{
+    // dispatch to specific compiled versions
+    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
+    {
+        if is_x86_feature_detected_!("avx") {
+            return kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc);
+        } else if is_x86_feature_detected_!("sse2") {
+            return kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc);
+        }
+    }
+    kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc);
+}
+
+#[inline]
+#[target_feature(enable="avx")]
+#[cfg(any(target_arch="x86", target_arch="x86_64"))]
+unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T,
+                            beta: T, c: *mut T, rsc: isize, csc: isize)
+{
+    kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
+}
+
+#[inline]
+#[target_feature(enable="sse2")]
+#[cfg(any(target_arch="x86", target_arch="x86_64"))]
+unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T,
+                             beta: T, c: *mut T, rsc: isize, csc: isize)
+{
+    kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
+}
+
+
+#[inline(always)]
+unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
+                               beta: T, c: *mut T, rsc: isize, csc: isize)
+{
+    let mut ab: [[T; NR]; MR] = [[0; NR]; MR];
+    let mut a = a;
+    let mut b = b;
+    debug_assert_eq!(beta, 0);
+
+    // Compute A B into ab[i][j]
+    unroll_by!(4 => k, {
+        loop_m!(i, loop_n!(j, {
+            ab[i][j] = ab[i][j].wrapping_add(at(a, i).wrapping_mul(at(b, j)));
+        }));
+
+        a = a.offset(MR as isize);
+        b = b.offset(NR as isize);
+    });
+
+    macro_rules! c {
+        ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
+    }
+
+    // set C = α A B + β C
+    loop_n!(j, loop_m!(i, *c![i, j] = alpha.wrapping_mul(ab[i][j])));
+}
+
+#[inline(always)]
+unsafe fn at(ptr: *const T, i: usize) -> T {
+    *ptr.offset(i as isize)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use aligned_alloc::Alloc;
+
+    fn aligned_alloc<T>(elt: T, n: usize) -> Alloc<T> where T: Copy
+    {
+        unsafe {
+            Alloc::new(n, Gemm::align_to()).init_with(elt)
+        }
+    }
+
+    use super::T;
+    type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize);
+
+    fn test_a_kernel(_name: &str, kernel_fn: KernelFn) {
+        const K: usize = 4;
+        let mut a = aligned_alloc(1, MR * K);
+        let mut b = aligned_alloc(0, NR * K);
+        for (i, x) in a.iter_mut().enumerate() {
+            *x = i as _;
+        }
+
+        for i in 0..K {
+            b[i + i * NR] = 1;
+        }
+        let mut c = [0; MR * NR];
+        unsafe {
+            kernel_fn(K, 1, &a[0], &b[0], 0, &mut c[0], 1, MR as isize);
+            // col major C
+        }
+        assert_eq!(&a[..], &c[..a.len()]);
+    }
+
+    #[test]
+    fn test_native_kernel() {
+        test_a_kernel("kernel", kernel);
+    }
+
+    #[test]
+    fn test_kernel_fallback_impl() {
+        test_a_kernel("kernel", kernel_fallback_impl);
+    }
+
+    #[test]
+    fn test_loop_m_n() {
+        let mut m = [[0; NR]; MR];
+        loop_m!(i, loop_n!(j, m[i][j] += 1));
+        for arr in &m[..] {
+            for elt in &arr[..] {
+                assert_eq!(*elt, 1);
+            }
+        }
+    }
+
+    mod test_arch_kernels {
+        use super::test_a_kernel;
+        macro_rules! test_arch_kernels_x86 {
+            ($($feature_name:tt, $function_name:ident),*) => {
+                $(
+                #[test]
+                fn $function_name() {
+                    if is_x86_feature_detected_!($feature_name) {
+                        test_a_kernel(stringify!($function_name), super::super::$function_name);
+                    } else {
+                        println!("Skipping, host does not have feature: {:?}", $feature_name);
+                    }
+                }
+                )*
+            }
+        }
+
+        #[cfg(any(target_arch="x86", target_arch="x86_64"))]
+        test_arch_kernels_x86! {
+            "avx", kernel_target_avx,
+            "sse2", kernel_target_sse2
+        }
+    }
+}
diff --git a/src/kernel.rs b/src/kernel.rs
index 1d59dcd..801b3d9 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -79,3 +79,15 @@ impl Element for f64 {
         *self += alpha * a;
     }
 }
+
+impl Element for i32 {
+    fn zero() -> Self { 0 }
+    fn one() -> Self { 1 }
+    fn is_zero(&self) -> bool { *self == 0 }
+    fn scale_by(&mut self, x: Self) {
+        *self = self.wrapping_mul(x);
+    }
+    fn scaled_add(&mut self, alpha: Self, a: Self) {
+        *self = self.wrapping_add(alpha.wrapping_mul(a));
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index a7e9bd5..f263bd8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -62,8 +62,10 @@ mod kernel;
 mod gemm;
 mod sgemm_kernel;
 mod dgemm_kernel;
+mod igemm_kernel;
 mod util;
 mod aligned_alloc;
 
 pub use gemm::sgemm;
 pub use gemm::dgemm;
+pub use gemm::igemm;
diff --git a/tests/sgemm.rs b/tests/sgemm.rs
index 8b4e1ae..b4314a0 100644
--- a/tests/sgemm.rs
+++ b/tests/sgemm.rs
@@ -1,7 +1,7 @@
 extern crate itertools;
 extern crate matrixmultiply;
 
-use matrixmultiply::{sgemm, dgemm};
+use matrixmultiply::{sgemm, dgemm, igemm};
 
 use itertools::Itertools;
 use itertools::{
@@ -35,6 +35,13 @@ impl Float for f64 {
     fn is_nan(self) -> bool { self.is_nan() }
 }
 
+impl Float for i32 {
+    fn zero() -> Self { 0 }
+    fn one() -> Self { 1 }
+    fn from(x: i64) -> Self { x as Self }
+    fn nan() -> Self { i32::min_value() } // hack
+    fn is_nan(self) -> bool { self == i32::min_value() }
+}
 
 trait Gemm : Sized {
     unsafe fn gemm(
@@ -64,6 +71,24 @@ impl Gemm for f32 {
     }
 }
 
+impl Gemm for i32 {
+    unsafe fn gemm(
+        m: usize, k: usize, n: usize,
+        alpha: Self,
+        a: *const Self, rsa: isize, csa: isize,
+        b: *const Self, rsb: isize, csb: isize,
+        beta: Self,
+        c: *mut Self, rsc: isize, csc: isize) {
+        igemm(
+            m, k, n,
+            alpha,
+            a, rsa, csa,
+            b, rsb, csb,
+            beta,
+            c, rsc, csc)
+    }
+}
+
 impl Gemm for f64 {
     unsafe fn gemm(
         m: usize, k: usize, n: usize,
@@ -99,6 +124,11 @@ fn test_dgemm_strides() {
     test_gemm_strides::<f64>();
 }
 
+#[test]
+fn test_i32gemm_strides() {
+    test_gemm_strides::<i32>();
+}
+
 fn test_gemm_strides<F>() where F: Gemm + Float {
     for n in 0..10 {
         test_strides::<F>(n, n, n);
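
Not part of the patch — a minimal usage sketch of the new `igemm` entry point, following the signature added in `src/gemm.rs` and the call shape used in the benchmarks and tests above. The 2×2 row-major operands and the `main` wrapper are illustrative only:

```rust
// Example only: C = 1 * A B + 0 * C for row-major 2x2 i32 matrices.
extern crate matrixmultiply;
use matrixmultiply::igemm;

fn main() {
    let a: Vec<i32> = vec![1, 2,
                           3, 4];
    let b: Vec<i32> = vec![5, 6,
                           7, 8];
    let mut c: Vec<i32> = vec![0; 4];
    unsafe {
        igemm(
            2, 2, 2,           // m, k, n
            1,                 // alpha
            a.as_ptr(), 2, 1,  // A: row stride 2, col stride 1 (row-major)
            b.as_ptr(), 2, 1,  // B: row-major
            0,                 // beta; C need not be initialized when beta == 0
            c.as_mut_ptr(), 2, 1,
        );
    }
    assert_eq!(c, vec![19, 22, 43, 50]);
}
```

As with the i32 `Element` impl above, the kernel uses wrapping integer arithmetic, so overflow wraps rather than panicking.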