Skip to content

Commit

Permalink
Merge pull request #26 from boathit/master
Browse files Browse the repository at this point in the history
add logsigmoid
  • Loading branch information
MikeInnes authored Feb 5, 2018
2 parents d15b558 + 92f1d42 commit 2a20d64
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ module NNlib

using Requires

export σ, sigmoid, relu, leakyrelu, elu, swish, selu, softplus, softsign,
export σ, sigmoid, relu, leakyrelu, elu, swish, selu, softplus, softsign, logσ, logsigmoid,
softmax, logsoftmax, conv2d, maxpool2d, avgpool2d

const libnnlib = Libdl.find_library("nnlib.$(Libdl.dlext)", [joinpath(@__DIR__, "..", "deps")])

include("numeric.jl")
include("activation.jl")
include("logsigmoid.jl")
include("softmax.jl")
include("logsoftmax.jl")
include("linalg.jl")
Expand Down
24 changes: 24 additions & 0 deletions src/logsigmoid.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

"""
    logσ(x)

Return `log(σ(x))`, computed in a numerically stable way that avoids
overflow/underflow of `exp` for large-magnitude `x`.

# Examples
```jldoctest
julia> logσ(0.)
-0.6931471805599453

julia> logσ.([-100, -10, 100.])
3-element Array{Float64,1}:
 -100.0
  -10.0
   -0.0
```
"""
function logσ(x)
    # log σ(x) = -log(1 + e^{-x}); factor out e^{-shift} with
    # shift = max(0, -x) so neither exponential can overflow.
    shift = max(zero(x), -x)
    return -(shift + log(exp(-shift) + exp(-x - shift)))
end

# Backward pass for `logσ`: the incoming sensitivity `Δ` scaled by the
# local derivative d/dx log(σ(x)) = 1 - σ(x).
∇logσ(Δ, x) = Δ * (1 - σ(x))

# ASCII-friendly aliases for the unicode names.
const logsigmoid = logσ
const ∇logsigmoid = ∇logσ
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ include("conv.jl")
# Vector input: softmax/logsoftmax/logsigmoid agree with their naive definitions.
xs = rand(5)
@test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs))
@test logsoftmax(xs) ≈ log.(softmax(xs))
@test logsigmoid.(xs) ≈ log.(sigmoid.(xs))

# Matrix input: softmax normalises over the first dimension (columns).
xs = rand(5,10)
@test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs),1)
@test logsoftmax(xs) ≈ log.(softmax(xs))
@test logsigmoid.(xs) ≈ log.(sigmoid.(xs))

# Numerical stability: naive log(σ(x)) under/overflows at these magnitudes,
# but the stable formulation must return the exact limits.
for T in [:Float32, :Float64]
    @eval @test logsigmoid.($T[-100_000, 100_000.]) ≈ $T[-100_000, 0.]
end

## compare the outputs with the PyTorch nn.LogSoftmax returns
xs = Float32[1, 2, 3000.]
Expand Down

0 comments on commit 2a20d64

Please sign in to comment.