Commit 4832bb8

add _rng_compat_array (#458)
* add _rng_compat_array

* Update src/dropout.jl

Co-authored-by: Brian Chen <[email protected]>

mcabbott and ToucheSir authored Jan 7, 2023
1 parent 5f63dbf commit 4832bb8
Showing 2 changed files with 7 additions and 2 deletions.
Project.toml: 2 changes (1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 name = "NNlib"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.8.14"
+version = "0.8.15"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
src/dropout.jl: 7 changes (6 additions & 1 deletion)
@@ -38,6 +38,7 @@ julia> mean(dropout(ones(10^4, 5), 0.3, dims=1), dims=1)
 dropout(A::AbstractArray, p::Real; dims = :) = dropout(_rng_from_array(A), A, p; dims)

 function dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
+    _rng_compat_array(rng, A)
     T = float(eltype(A))
     0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
     if p > 0
@@ -52,7 +53,7 @@ function dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
 end

 """
-    dropout!(B, A, p; dims=:)
+    dropout!(B, A, p; [dims])

 This does exactly `B .= dropout(A, p; dims)`,
 or rather, it's the implementation of out-of-place [`dropout`](@ref).
@@ -62,6 +63,7 @@ dropout!(B::AbstractArray, A::AbstractArray, p::Real; dims = :) = dropout!(_rng_
 function dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real; dims=:)
     size(dst) == size(src) || throw(DimensionMismatch("dropout! expects output array the same size as input"))
     0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
+    _rng_compat_array(rng, src)
     if p > 0
         pT = convert(real(eltype(dst)), p)
         _dropout!(rng, dst, src, pT, dims)
@@ -155,3 +157,6 @@ _rng_from_array(::AbstractArray) = Random.default_rng()

 @non_differentiable _rng_from_array(::Any)

+# This exists because `rand!(default_rng(), CUDA.rand(3))` ignores the RNG,
+# and Flux would prefer an error. NNlibCUDA will overload it to produce that.
+_rng_compat_array(::AbstractRNG, ::AbstractArray) = nothing
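
For context, the new fallback above is a deliberate no-op on the CPU. Below is a minimal sketch of the kind of GPU overload the comment describes; it is not NNlibCUDA's actual method and it assumes CUDA.jl's `CuArray` and `CUDA.RNG` types: CUDA's native RNG is accepted, while pairing a CPU RNG with a GPU array raises the error Flux prefers.

# Hypothetical sketch only; the real NNlibCUDA overload may differ.
using Random, CUDA, NNlib

# CUDA's native RNG works with CuArrays, so keep the no-op.
NNlib._rng_compat_array(::CUDA.RNG, ::CUDA.CuArray) = nothing

# Any other (CPU) RNG would be silently ignored by `rand!` on the GPU,
# so throw instead of drawing the dropout mask from the wrong RNG.
NNlib._rng_compat_array(rng::AbstractRNG, ::CUDA.CuArray) =
    throw(ArgumentError("cannot use $(typeof(rng)) with a CuArray; pass CUDA.default_rng() instead"))

With such an overload loaded, a call like `dropout(Random.default_rng(), CUDA.rand(3), 0.5)` hits this check first and throws, rather than silently using a CPU RNG that `rand!` on the GPU would ignore.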

2 comments on commit 4832bb8

@mcabbott
Member Author

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/75300

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or it can be done manually through the GitHub interface, or via:

git tag -a v0.8.15 -m "<description of version>" 4832bb8513c9a576956dc06fd3bcf0bc4bdda334
git push origin v0.8.15
