Merge pull request #22 from boathit/master
Adding logsoftmax
MikeInnes authored Jan 23, 2018
2 parents c3dfc0d + c0cc507 commit d15b558
Showing 3 changed files with 44 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/NNlib.jl
@@ -3,13 +3,14 @@ module NNlib
using Requires

export σ, sigmoid, relu, leakyrelu, elu, swish, selu, softplus, softsign,
-    softmax, conv2d, maxpool2d, avgpool2d
+    softmax, logsoftmax, conv2d, maxpool2d, avgpool2d

const libnnlib = Libdl.find_library("nnlib.$(Libdl.dlext)", [joinpath(@__DIR__, "..", "deps")])

include("numeric.jl")
include("activation.jl")
include("softmax.jl")
include("logsoftmax.jl")
include("linalg.jl")
include("conv.jl")

34 changes: 34 additions & 0 deletions src/logsoftmax.jl
@@ -0,0 +1,34 @@
using Base.Threads

function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
    @threads for j = 1:size(xs, 2)
        @inbounds begin
            # find the per-column maximum for the log-sum-exp shift
            xi_max = xs[1, j]
            for i = 1:size(out, 1)
                xi_max = max(xi_max, xs[i, j])
            end
            # sum of the shifted exponentials
            s = zero(eltype(out))
            for i = 1:size(out, 1)
                s += exp(xs[i, j] - xi_max)
            end
            # logsoftmax(x)_i = (x_i - max(x)) - log(sum(exp.(x .- max(x))))
            for i = 1:size(out, 1)
                out[i, j] = xs[i, j] - log(s) - xi_max
            end
        end
    end
    out
end


logsoftmax!(xs) = logsoftmax!(xs, xs)
logsoftmax(xs) = logsoftmax!(similar(xs), xs)

∇logsoftmax(Δ, xs) = ∇softmax(Δ ./ softmax(xs), xs)

"""
logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs)))
logsoftmax computes the log of softmax(xs) and is more numerically stable
than taking the log of softmax directly, e.g. when computing the cross-entropy loss.
"""
logsoftmax
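
The stability claim in the docstring rests on the max-subtraction (log-sum-exp) trick used in logsoftmax! above: subtracting the column maximum before exponentiating keeps every exp argument at or below zero. A minimal illustrative sketch of the difference (not part of the diff), using the same values as the Float32 test below:

# Illustrative only: the naive formula overflows in Float32, the shifted form does not.
xs = Float32[1, 2, 3000]

naive = log.(exp.(xs) ./ sum(exp.(xs)))        # exp(3000f0) == Inf, so this yields -Inf/NaN

m = maximum(xs)
stable = (xs .- m) .- log(sum(exp.(xs .- m)))  # ≈ [-2999, -2998, 0], matching logsoftmax(xs)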
8 changes: 8 additions & 0 deletions test/runtests.jl
@@ -8,8 +8,16 @@ include("conv.jl")

xs = rand(5)
@test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs))
@test logsoftmax(xs) ≈ log.(softmax(xs))

xs = rand(5,10)
@test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs),1)
@test logsoftmax(xs) ≈ log.(softmax(xs))

## compare the outputs with what PyTorch's nn.LogSoftmax returns
xs = Float32[1, 2, 3000.]
@test logsoftmax(xs) ≈ [-2999, -2998, 0]

xs = Float32[1 2 3; 1000 2000 3000]
@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.]
end
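
The ∇logsoftmax definition above composes ∇softmax with the derivative of log; for a vector input it works out to Δ .- softmax(xs) .* sum(Δ). A small finite-difference sketch of that identity (illustrative only, not part of the commit's tests; assumes logsoftmax is already loaded from NNlib, and basis is a hypothetical helper defined here):

# Illustrative gradient check: compare the closed-form vector-Jacobian product
# of logsoftmax against central finite differences.
x, Δ = rand(5), rand(5)
sm = exp.(x) ./ sum(exp.(x))
analytic = Δ .- sm .* sum(Δ)

ϵ = 1e-6
basis(i) = [j == i ? 1.0 : 0.0 for j = 1:5]
numeric = [(sum(Δ .* logsoftmax(x .+ ϵ .* basis(i))) -
            sum(Δ .* logsoftmax(x .- ϵ .* basis(i)))) / (2ϵ) for i = 1:5]

# analytic ≈ numeric (up to roughly 1e-6)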
