Skip to content

Commit

Permalink
TEST: Add i32 gemm tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bluss committed Dec 1, 2018
1 parent 2ea01de commit 2946204
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion tests/sgemm.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
extern crate itertools;
extern crate matrixmultiply;

use matrixmultiply::{sgemm, dgemm};
use matrixmultiply::{sgemm, dgemm, igemm};

use itertools::Itertools;
use itertools::{
Expand Down Expand Up @@ -35,6 +35,13 @@ impl Float for f64 {
fn is_nan(self) -> bool { self.is_nan() }
}

impl Float for i32 {
fn zero() -> Self { 0 }
fn one() -> Self { 1 }
fn from(x: i64) -> Self { x as Self }
fn nan() -> Self { i32::min_value() } // hack
fn is_nan(self) -> bool { self == i32::min_value() }
}

trait Gemm : Sized {
unsafe fn gemm(
Expand Down Expand Up @@ -64,6 +71,24 @@ impl Gemm for f32 {
}
}

impl Gemm for i32 {
unsafe fn gemm(
m: usize, k: usize, n: usize,
alpha: Self,
a: *const Self, rsa: isize, csa: isize,
b: *const Self, rsb: isize, csb: isize,
beta: Self,
c: *mut Self, rsc: isize, csc: isize) {
igemm(
m, k, n,
alpha,
a, rsa, csa,
b, rsb, csb,
beta,
c, rsc, csc)
}
}

impl Gemm for f64 {
unsafe fn gemm(
m: usize, k: usize, n: usize,
Expand Down Expand Up @@ -99,6 +124,11 @@ fn test_dgemm_strides() {
test_gemm_strides::<f64>();
}

#[test]
fn test_i32gemm_strides() {
test_gemm_strides::<i32>();
}

fn test_gemm_strides<F>() where F: Gemm + Float {
for n in 0..10 {
test_strides::<F>(n, n, n);
Expand Down

0 comments on commit 2946204

Please sign in to comment.