From 3e56f3d6d12686771e6072ea548da812a1d3c85d Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Fri, 8 Feb 2019 19:45:18 +0900 Subject: [PATCH] Rewrite `im2col()` for greater performance The performance gain of this PR comes mainly from identifying that the conditionals within the `im2col()` inner loops are really only needed towards the edges of the image being `im2col()`'ed. By separating the loops into two separate segments, we are able to achieve very high memory bandwidth utilizaiton for the majority of the wallclock time of the function. Eventually we will want to parallelize this loop via threading, but until that is possible, this should be a nice improvement. In my testing, this speeds up `im2col()` by a factor of 2-3x. This change also begins the work of pushing some convolutional parameters into the type domain. This is especially helpful for `im2col()`, as it makes it possible that the compiler could completely elide the second set of for loops in the case of no padding (this is not tested; I don't know if the compiler actually is able to elide the loops) but pushing these parameters into the type domain gives a substantial speed boost. --- src/impl/conv.jl | 190 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 142 insertions(+), 48 deletions(-) diff --git a/src/impl/conv.jl b/src/impl/conv.jl index 5feb58e39..cf0639b0c 100644 --- a/src/impl/conv.jl +++ b/src/impl/conv.jl @@ -10,32 +10,112 @@ function psize(p, x) end end -function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, width::Int, height::Int, channels::Int, - kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int, stride_w::Int, stride_h::Int, - dil_w::Int, dil_h::Int, mode::Int) where T - - height_col = div(height + 2pad_h - (kernel_h - 1) * dil_h - 1, stride_h) + 1 - width_col = div(width + 2pad_w - (kernel_w - 1) * dil_w - 1, stride_w) + 1 - channels_col = channels * kernel_h * kernel_w +# Type system-level information about convolution dimensions. Critical for things like +# im2col_2d!() to generate efficient code. +struct ConvDims{img, kernel, channels, stride, padding, dilation, flipkernel} end +img_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = I + +# Calculate the output dimensions of this convolution +function output_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} + O_w = div(I[1] + P[1] + P[2] - (K[1] - 1) * D[1] - 1, S[1]) + 1 + O_h = div(I[2] + P[3] + P[4] - (K[1] - 1) * D[1] - 1, S[1]) + 1 + return (O_w, O_h) +end +kernel_size(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = K +img_channels(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = C +stride(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = S +padding(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = P +dilation(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = D +flipkernel(c::ConvDims{I,K,C,S,P,D,F}) where {I, K, C, S, P, D, F} = F + +function im2col_2d!(img::AbstractArray{T,3}, col::AbstractArray{T,2}, cdims::ConvDims) where T + width, height = img_size(cdims) + kernel_w, kernel_h = kernel_size(cdims) + channels = img_channels(cdims) + pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi = padding(cdims) + dil_w, dil_h = dilation(cdims) + stride_w, stride_h = stride(cdims) + width_col, height_col = output_size(cdims) + + if flipkernel(cdims) + flipk = (w, h) -> (kernel_w - w + 1, kernel_h - h + 1) + else + flipk = (w, h) -> (w, h) + end - #pragma omp parallel for - for c = 1:channels_col - w_offset = (c - 1) % kernel_w - h_offset = div(c - 1, kernel_w) % kernel_h - c_im = div(c - 1, kernel_h * kernel_w) - if mode == 0 - w_offset = kernel_w - 1 - w_offset - h_offset = kernel_h - 1 - h_offset + # Reshape col for easy access. + col_reshaped = reshape(col, (width_col, height_col, kernel_w, kernel_h, channels)) + + # Let us first calculate the number of rows/columns within which we must zero out some + # portion of the image patches we're copying over. Note the subtractions on the `_hi` + # variants are due to us needing to account for padding that is completely ignored due + # to stride/dilation/kernel size combinations. + spill_w_lo = ceil(Int, pad_w_lo/stride_w) + spill_w_hi = width_col - div(width + pad_w_lo - (kernel_w - 1)*dil_w, stride_w) + spill_h_lo = ceil(Int, pad_h_lo/stride_h) + spill_h_hi = height_col - div(height + pad_h_lo - (kernel_h - 1)*dil_h, stride_h) + spill_w_hi_abs = width_col - spill_w_hi + 1 + spill_h_hi_abs = height_col - spill_h_hi + 1 + + # First, a helper function to project from output (w, h) to input (input_w, input_h) + project(idx, stride, pad) = (idx - 1)*stride - pad + 1 + + # These are the regions we're going to have to run with cognizance of padding + padded_regions = ( + (1:width_col, 1:spill_h_lo), + (1:spill_w_lo, (spill_h_lo+1):(spill_h_hi_abs-1)), + (spill_w_hi_abs:width_col, (spill_h_lo+1):(spill_h_hi_abs-1)), + (1:width_col, spill_h_hi_abs:height_col), + ) + + # We begin by copying the central region of the image which requires no padding at all. + # Eliminating the branches of the fully generalized version below gives us a nice + # speedup on the majority of the data. + for c in 1:channels + for kh in 1:kernel_h + for kw in 1:kernel_w + for h in (spill_h_lo+1):(height_col - spill_h_hi) + input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h + + @inbounds for w in (spill_w_lo+1):(width_col - spill_w_hi) + input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w + col_reshaped[w, h, flipk(kw, kh)..., c] = img[input_kw, input_kh, c] + end + end + end end - for h = 1:height_col - for w = 1:width_col - h_pad = (h - 1) * stride_h - pad_h + h_offset * dil_h - w_pad = (w - 1) * stride_w - pad_w + w_offset * dil_w - if h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width - col[((c - 1)*height_col+h-1) * width_col + w] = - img[(c_im * height + h_pad) * width + w_pad + 1] - else - col[((c - 1)*height_col+h - 1) * width_col + w] = 0 + end + + # For each "padded region", we run the fully general version + for (w_region, h_region) in padded_regions + for c in 1:channels + for kh in 1:kernel_h + for kw in 1:kernel_w + @inbounds for h in h_region + input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h + + # If this column is off the edge, then deal with the entire thing + # in one fell swoop, like a ravenous flock of crows. CAW CAW. + if input_kh <= 0 || input_kh > height + for w in w_region + col_reshaped[w, h, flipk(kw, kh)..., c] = zero(eltype(col_reshaped)) + end + continue + end + + @inbounds for w in w_region + input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w + + # If this pixel is off the edge of the map, clear it out. + if input_kw <= 0 || input_kw > width + col_reshaped[w, h, flipk(kw, kh)..., c] = zero(eltype(col_reshaped)) + continue + end + + # Copy the data over + col_reshaped[w, h, flipk(kw, kh)..., c] = img[input_kw, input_kh, c] + end + end end end end @@ -256,26 +336,41 @@ function depthwiseconv2d_grad_x!(dx::AbstractArray{T,4}, x::AbstractArray{T,4}, return dx end +function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}, + cdims::ConvDims; alpha=T(1)) where T + Wx, Hx = img_size(cdims) + Ww, Hw = kernel_size(cdims) + Wy, Hy = output_size(cdims) + Cx = img_channels(cdims) + M, N, K, Y = Wy*Hy, size(y,4), Ww*Hw*Cx, Wy*Hy*size(y, 4) + + x2 = similar(x, im2col_dims(w, y)) + @inbounds for n in 1:size(x,4) + im2col_2d!(view(x, :, :, :, n), x2, cdims) + gemm!('N','N',M,N,K,alpha,pointer(x2),pointer(w),T(0),pointer(y,(n - 1)*Y + 1)) + end + return y +end + function conv2d!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4}; padding=0, stride=1, dilation=1, mode=0, alpha=T(1)) where T - if mode != 0 && mode != 1; throw(ArgumentError("conv2d only supports mode=0 or 1.")); end + if mode != 0 && mode != 1 + throw(ArgumentError("conv2d only supports mode=0 or 1.")) + end Wx,Hx,Cx,Nx = size(x) Ww,Hw,C1,C2 = size(w) - if Cx!=C1; throw(DimensionMismatch()); end - Wy,Hy,Cy,Ny = size(y) - x2dims = im2col_dims(w,y) - x2 = similar(x, x2dims) + + # Check that the number of channels in `x` matches the number of channels in each + # kernel of `w`. IF it doesn't, throw a DimensionMismatch() + if Cx != C1 + throw(DimensionMismatch()) + end (p1,p2) = psize(padding,x) (s1,s2) = psize(stride,x) (d1,d2) = psize(dilation, x) - M,N,K,Y = Wy*Hy,Cy,Ww*Hw*Cx,Wy*Hy*Cy - yidx = 1 - @inbounds for n in 1:Nx - im2col2d!(w, x, x2, n, p1, p2, s1, s2, d1, d2, mode) - gemm!('N','N',M,N,K,alpha,pointer(x2),pointer(w),T(0),pointer(y,yidx)) - yidx += Y - end - return y + + cdims = ConvDims{(Wx,Hx),(Ww,Hw),Cx,(s1,s2),(p1,p1,p2,p2),(d1,d2), mode == 0}() + return conv2d!(y, x, w, cdims; alpha=alpha) end function conv2d_grad_w!(dw::AbstractArray{T,4}, x::AbstractArray{T,4}, dy::AbstractArray{T,4}; @@ -332,8 +427,9 @@ function im2col2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T, n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T Wx,Hx,Cx,Nx = size(x) Ww,Hw,C1,C2 = w - xn = x[:, :, :, n] - im2col_2d!(xn,x2,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,mode) + xn = view(x, :, :, :, n) + cdims = ConvDims{(Wx,Hx),(Ww,Hw),Cx,(s1,s2),(p1,p1,p2,p2),(1,1), mode == 0}() + im2col_2d!(xn,x2,cdims) return x2 end @@ -341,8 +437,9 @@ function im2col2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArr n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int, mode::Int) where T Wx,Hx,Cx,Nx = size(x) Ww,Hw,C1,C2 = size(w) - xn = x[:, :, :, n] - im2col_2d!(xn,x2,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,d1,d2,mode) + xn = view(x, :, :, :, n) + cdims = ConvDims{(Wx,Hx),(Ww,Hw),Cx,(s1,s2),(p1,p1,p2,p2),(d1,d2), mode == 0}() + im2col_2d!(xn,x2,cdims) return x2 end @@ -350,9 +447,8 @@ function col2im2d!(w::NTuple{4,Int}, x::AbstractArray{T,4}, x2::AbstractArray{T, n::Int, p1::Int, p2::Int, s1::Int, s2::Int, mode::Int) where T Wx,Hx,Cx,Nx = size(x) Ww,Hw,C1,C2 = w - xn = x[:, :, :, n] + xn = view(x, :, :, :, n) col2im_2d!(x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,1,1,mode) - x[:, :, :, n] .= xn return x end @@ -360,9 +456,8 @@ function col2im2d!(w::AbstractArray{T,4}, x::AbstractArray{T,4}, x2::AbstractArr n::Int, p1::Int, p2::Int, s1::Int, s2::Int, d1::Int, d2::Int, mode::Int) where T Wx,Hx,Cx,Nx = size(x) Ww,Hw,C1,C2 = size(w) - xn = x[:, :, :, n] + xn = view(x, :, :, :, n) col2im_2d!(x2,xn,Wx,Hx,Cx,Ww,Hw,p1,p2,s1,s2,d1,d2,mode) - x[:, :, :, n] .= xn return x end @@ -445,7 +540,7 @@ function im2col3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArr s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T Wx,Hx,Dx,Cx,Nx = size(x) Ww,Hw,Dw,C1,C2 = size(w) - xn = x[:, :, :, :, n] + xn = view(x, :, :, :, :, n) im2col_3d!(xn,x2,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode) return x2 end @@ -455,8 +550,7 @@ function col2im3d!(w::AbstractArray{T,5}, x::AbstractArray{T,5}, x2::AbstractArr s3::Int, d1::Int, d2::Int, d3::Int, mode::Int) where T Wx,Hx,Dx,Cx,Nx = size(x) Ww,Hw,Dw,C1,C2 = size(w) - xn = x[:, :, :, :, n] + xn = view(x, :, :, :, :, n) col2im_3d!(x2,xn,Wx,Hx,Dx,Cx,Ww,Hw,Dw,p1,p2,p3,s1,s2,s3,d1,d2,d3,mode) - x[:, :, :, :, n] = xn return x end