Skip to content

Commit

Permalink
add logsigmoid
Browse files Browse the repository at this point in the history
  • Loading branch information
boathit committed Jan 30, 2018
1 parent d15b558 commit 92f1d42
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ module NNlib

using Requires

export σ, sigmoid, relu, leakyrelu, elu, swish, selu, softplus, softsign,
export σ, sigmoid, relu, leakyrelu, elu, swish, selu, softplus, softsign, logσ, logsigmoid,
softmax, logsoftmax, conv2d, maxpool2d, avgpool2d

const libnnlib = Libdl.find_library("nnlib.$(Libdl.dlext)", [joinpath(@__DIR__, "..", "deps")])

include("numeric.jl")
include("activation.jl")
include("logsigmoid.jl")
include("softmax.jl")
include("logsoftmax.jl")
include("linalg.jl")
Expand Down
24 changes: 24 additions & 0 deletions src/logsigmoid.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

"""
logσ(x)
Return `log(σ(x))` which is computed in a numerically stable way.
julia> logσ(0.)
-0.6931471805599453
julia> logσ.([-100, -10, 100.])
3-element Array{Float64,1}:
-100.0
-10.0
-0.0
"""
function logσ(x)
max_v = max(zero(x), -x)
z = exp(-max_v) + exp(-x-max_v)
-(max_v + log(z))
end

∇logσ(Δ, x) = Δ * (1 - σ(x))

const logsigmoid = logσ
const ∇logsigmoid = ∇logσ
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ include("conv.jl")
xs = rand(5)
@test softmax(xs) exp.(xs) ./ sum(exp.(xs))
@test logsoftmax(xs) log.(softmax(xs))
@test logsigmoid.(xs) log.(sigmoid.(xs))

xs = rand(5,10)
@test softmax(xs) exp.(xs) ./ sum(exp.(xs),1)
@test logsoftmax(xs) log.(softmax(xs))
@test logsigmoid.(xs) log.(sigmoid.(xs))

for T in [:Float32, :Float64]
@eval @test logsigmoid.($T[-100_000, 100_000.]) $T[-100_000, 0.]
end

## compare the outputs with the PyTorch nn.LogSoftmax returns
xs = Float32[1, 2, 3000.]
Expand Down

0 comments on commit 92f1d42

Please sign in to comment.