diff --git a/src/SugarBLAS.jl b/src/SugarBLAS.jl index 93c03e2..e70fe1e 100644 --- a/src/SugarBLAS.jl +++ b/src/SugarBLAS.jl @@ -10,14 +10,11 @@ export @blas! include("Match/Match.jl") using .Match -import Base: copy, - - -copy(s::Symbol) = s #Generic fallback deprecated in 0.5 - """ -Negate a symbol or expression +Negate a number, symbol or expression """ -function -(ast) +neg(ast::Number) = -ast +function neg(ast::Union{Symbol, Expr}) if @match(ast, -ast) | (ast == 0) ast else @@ -127,7 +124,7 @@ macro blas!(expr::Expr) @case begin @match(expr, X *= a) => @call scale!(a,X) @match(expr, X = a*X) => @call scale!(a,X) - @match(expr, Y = Y - a*X) => @call Base.LinAlg.axpy!(-a,X,Y) + @match(expr, Y = Y - a*X) => @call Base.LinAlg.axpy!(neg(a),X,Y) @match(expr, Y = Y - X) => @call Base.LinAlg.axpy!(-1.0,X,Y) @match(expr, Y = a*X + Y) => @call Base.LinAlg.axpy!(a,X,Y) @match(expr, Y = X + Y) => @call Base.LinAlg.axpy!(1.0,X,Y) @@ -186,7 +183,7 @@ macro axpy!(expr::Expr) unkeyword!(expr) expr = expand(expr) @case begin - @match(expr, Y = Y - a*X) => @call(Base.LinAlg.axpy!(-a,X,Y)) + @match(expr, Y = Y - a*X) => @call(Base.LinAlg.axpy!(neg(a),X,Y)) @match(expr, Y = Y - X) => @call Base.LinAlg.axpy!(-1.0,X,Y) @match(expr, Y = a*X + Y) => @call Base.LinAlg.axpy!(a,X,Y) @match(expr, Y = X + Y) => @call Base.LinAlg.axpy!(1.0,X,Y) @@ -208,7 +205,7 @@ macro ger!(expr::Expr) expr = expand(expr) f = @case begin @match(expr, A = alpha*x*y' + A) => identity - @match(expr, A = A - alpha*x*y') => (-) + @match(expr, A = A - alpha*x*y') => neg otherwise => error("No match found") end @call Base.LinAlg.BLAS.ger!(f(alpha),x,y,A) @@ -231,7 +228,7 @@ macro syr!(expr::Expr) @match(expr, A[uplo] = right) || error("No match found") f = @case begin @match(right, alpha*x*x.' + Y) => identity - @match(right, Y - alpha*x*x.') => (-) + @match(right, Y - alpha*x*x.') => neg otherwise => error("No match found") end (@match(Y, Y[uplo]) && (Y == A)) || (Y == A) || error("No match found") @@ -282,7 +279,7 @@ macro syrk!(expr::Expr) @match(expr, C[uplo] = right) || error("No match found") f = @case begin @match(right, alpha*X*Y + D) => identity - @match(right, D - alpha*X*Y) => (-) + @match(right, D - alpha*X*Y) => neg otherwise => error("No match found") end trans = @case begin @@ -312,7 +309,7 @@ macro her!(expr::Expr) @match(expr, A[uplo] = right) || error("No match found") f = @case begin @match(right, alpha*x*x' + Y) => identity - @match(right, Y - alpha*x*x') => (-) + @match(right, Y - alpha*x*x') => neg otherwise => error("No match found") end (@match(Y, Y[uplo]) && (Y == A)) || (Y == A) || error("No match found") @@ -361,7 +358,7 @@ macro herk!(expr::Expr) @match(expr, C[uplo] = right) || error("No match found") f = @case begin @match(right, alpha*X*Y + D) => identity - @match(right, D - alpha*X*Y) => (-) + @match(right, D - alpha*X*Y) => neg otherwise => error("No match found") end trans = @case begin @@ -389,7 +386,7 @@ macro gbmv(expr::Expr) @match(expr, alpha*Y*x) || error("No match found") trans = @match(Y, Y') ? 'T' : 'N' @match(Y, A[kl:ku,h=m]) - @call Base.LinAlg.BLAS.gbmv(trans,m,-kl,ku,alpha,A,x) + @call Base.LinAlg.BLAS.gbmv(trans,m,neg(kl),ku,alpha,A,x) end """ @@ -410,14 +407,14 @@ macro gbmv!(expr::Expr) @match(expr, y = right) || error("No match found") f = @case begin @match(right, alpha*Y*x + w) => identity - @match(right, w - alpha*Y*x) => (-) + @match(right, w - alpha*Y*x) => neg otherwise => error("No match found") end trans = @match(Y, Y') ? 'T' : 'N' @match(Y, A[kl:ku,h=m]) @match(w, beta*w) || (beta = 1.0) (y == w) || error("No match found") - @call Base.LinAlg.BLAS.gbmv!(trans,m,-kl,ku,f(alpha),A,x,beta,y) + @call Base.LinAlg.BLAS.gbmv!(trans,m,neg(kl),ku,f(alpha),A,x,beta,y) end """ @@ -458,7 +455,7 @@ macro sbmv!(expr::Expr) @match(expr, y = right) || error("No match found") f = @case begin @match(right, alpha*A[0:k,uplo]*x + w) => identity - @match(right, w - alpha*A[0:k,uplo]*x) => (-) + @match(right, w - alpha*A[0:k,uplo]*x) => neg otherwise => error("No match found") end @match(w, beta*w) || (beta = 1.0) @@ -519,7 +516,7 @@ macro gemm!(expr::Expr) @match(expr, C = right) || error("No match found") f = @case begin @match(right, alpha*A*B + D) => identity - @match(right, D - alpha*A*B) => (-) + @match(right, D - alpha*A*B) => neg otherwise => error("No match found") end tA = @match(A, A') ? 'T' : 'N' @@ -571,9 +568,9 @@ macro gemv!(expr::Expr) expr = expand(expr) @match(expr, y = right) || error("No match found") f = @case begin - @match(right, alpha*A*x + w) => identity - @match(right, w - alpha*A*x) => (-) - otherwise => error("No match found") + @match(right, alpha*A*x + w) => identity + @match(right, w - alpha*A*x) => neg + otherwise => error("No match found") end tA = @match(A, A') ? 'T' : 'N' @match(w, beta*w) || (beta = 1.0) @@ -634,7 +631,7 @@ macro symm!(expr::Expr) @match(expr, C = right) || error("No match found") f = @case begin @match(right, alpha*A*B + D) => identity - @match(right, D - alpha*A*B) => (-) + @match(right, D - alpha*A*B) => neg otherwise => error("No match found") end side = @case begin @@ -684,7 +681,7 @@ macro symv!(expr::Expr) @match(expr, y = right) || error("No match found") f = @case begin @match(right, alpha*A[uplo]*x + w) => identity - @match(right, w - alpha*A[uplo]*x) => (-) + @match(right, w - alpha*A[uplo]*x) => neg otherwise => error("No match found") end @match(w, beta*w) || (beta = 1.0)