diff --git a/src/dropout.jl b/src/dropout.jl index a04222672..d021ee54d 100644 --- a/src/dropout.jl +++ b/src/dropout.jl @@ -63,7 +63,7 @@ dropout!(B::AbstractArray, A::AbstractArray, p::Real; dims = :) = dropout!(_rng_ function dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real; dims=:) size(dst) == size(src) || throw(DimensionMismatch("dropout! expects output array the same size as input")) 0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1")) - _rng_compat_array(rng, A) + _rng_compat_array(rng, src) if p > 0 pT = convert(real(eltype(dst)), p) _dropout!(rng, dst, src, pT, dims)