diff --git a/src/Automa.jl b/src/Automa.jl index ef4a9461..bf484896 100644 --- a/src/Automa.jl +++ b/src/Automa.jl @@ -3,13 +3,6 @@ module Automa using Printf: @sprintf import ScanByte: ScanByte, ByteSet -include("sdict.jl") -include("sset.jl") - -# TODO: use StableDict and StableSet only where they are required -const Dict = StableDict -const Set = StableSet - # Encode a byte set into a sequence of non-empty ranges. function range_encode(set::ScanByte.ByteSet) result = UnitRange{UInt8}[] diff --git a/src/edge.jl b/src/edge.jl index 0d2cb00c..095594ee 100644 --- a/src/edge.jl +++ b/src/edge.jl @@ -19,6 +19,37 @@ function Edge(labels::ByteSet, actions::ActionList) return Edge(labels, Precondition(), actions) end +# Don't override isless, because I don't want to figure out how +# to hash correctly. It's fine, we only use this for sorting in order_machine +function in_sort_order(e1::Edge, e2::Edge) + # First check edges + # This could presumably be done much faster + for (i,j) in zip(e1.labels, e2.labels) + i < j && return true + j < i && return false + end + l1, l2 = length(e1.labels), length(e2.labels) + l1 < l2 && return true + l2 < l1 && return false + + # Then check preconditions + p1, p2 = e1.precond, e2.precond + lp1, lp2 = length(p1.names), length(p2.names) + for i in 1:min(lp1, lp2) + isless(p1.names[i], p2.names[i]) && return true + isless(p2.names[i], p1.names[i]) && return false + u1, u2 = convert(UInt8, p1.values[i]), convert(UInt8, p2.values[i]) + u1 < u2 && return true + u2 < u1 && return false + end + lp1 < lp2 && return true + lp2 < lp1 && return false + + # A machine should never have two indistinguishable edges + # so if we reach here, something went wrong + error() +end + """Check if two edges have preconditions that could be disambiguating. I.e. can an FSM distinguish the edges based on their conditions? """ diff --git a/src/machine.jl b/src/machine.jl index 0ed8d1e4..2ceff8ef 100644 --- a/src/machine.jl +++ b/src/machine.jl @@ -71,6 +71,49 @@ function Base.show(io::IO, machine::Machine) ) end +# Reorder machine states so the states are in a completely deterministic manner. +# solves #19, see issue #106. +function reorder_machine(machine::Machine) + # old state index => new state index, in a deterministic manner + old2new = Dict(machine.start.state => 1) + remaining = [machine.start] + while !isempty(remaining) + node = pop!(remaining) + for (_, target) in sort(node.edges; lt=in_sort_order, by=first) + if !haskey(old2new, target.state) + old2new[target.state] = length(old2new) + 1 + push!(remaining, target) + end + end + end + + # Make new nodes complete with edges + new_nodes = Dict(i => Node(i) for i in 1:length(old2new)) + oldnodes = collect(traverse(machine.start)) + @assert length(oldnodes) == length(machine.states) + for old_node in traverse(machine.start) + for (e, t) in old_node.edges + push!( + new_nodes[old2new[old_node.state]].edges, + (e, new_nodes[old2new[t.state]]) + ) + + end + end + for node in values(new_nodes) + sort!(node.edges; by=first, lt=in_sort_order) + end + + # Rebuild machine and return it + Machine( + new_nodes[1], + machine.states, + 1, + Set([old2new[i] for i in machine.final_states]), + Dict{Int, ActionList}(old2new[i] => act for (i, act) in machine.eof_actions) + ) +end + """ compile(re::RegExp; optimize::Bool=true, unambiguous::Bool=true) -> Machine @@ -94,7 +137,8 @@ function compile(re::RegExp.RE; optimize::Bool=true, unambiguous::Bool=true) dfa = remove_dead_nodes(reduce_nodes(dfa)) end validate(dfa) - return dfa2machine(dfa) + machine = dfa2machine(dfa) + return reorder_machine(machine) end function dfa2machine(dfa::DFA) diff --git a/src/sdict.jl b/src/sdict.jl deleted file mode 100644 index 3bb5db1c..00000000 --- a/src/sdict.jl +++ /dev/null @@ -1,198 +0,0 @@ -# Stable Dictionary -# ================= - -mutable struct StableDict{K, V} <: AbstractDict{K, V} - slots::Vector{Int} - keys::Vector{K} - vals::Vector{V} - used::Int - nextidx::Int - - function StableDict{K, V}() where {K, V} - size = 16 - slots = zeros(Int, size) - keys = Vector{K}(undef, size) - vals = Vector{V}(undef, size) - return new{K,V}(slots, keys, vals, 0, 1) - end - - function StableDict(dict::StableDict{K, V}) where {K, V} - copy = StableDict{K, V}() - for (k, v) in dict - copy[k] = v - end - return copy - end -end - -function StableDict(kvs::Pair{K, V}...) where {K, V} - dict = StableDict{K, V}() - for (k, v) in kvs - dict[k] = v - end - return dict -end - -function StableDict{K, V}(kvs) where {K, V} - dict = StableDict{K, V}() - for (k, v) in kvs - dict[k] = v - end - return dict -end - -function StableDict(kvs) - return StableDict([Pair(k, v) for (k, v) in kvs]...) -end - -function StableDict() - return StableDict{Any, Any}() -end - -function Base.copy(dict::StableDict) - return StableDict(dict) -end - -function Base.length(dict::StableDict) - return dict.used -end - -function Base.haskey(dict::StableDict, key) - _, j = indexes(dict, convert(keytype(dict), key)) - return j > 0 -end - -function Base.getindex(dict::StableDict, key) - _, j = indexes(dict, convert(keytype(dict), key)) - if j == 0 - throw(KeyError(key)) - end - return dict.vals[j] -end - -function Base.get!(dict::StableDict, key, default) - if haskey(dict, key) - return dict[key] - end - val = convert(valtype(dict), default) - dict[key] = val - return val -end - -function Base.get!(f::Function, dict::StableDict, key) - if haskey(dict, key) - return dict[key] - end - val = convert(valtype(dict), f()) - dict[key] = val - return val -end - -function Base.setindex!(dict::StableDict, val, key) - k = convert(keytype(dict), key) - v = convert(valtype(dict), val) - @label index - i, j = indexes(dict, k) - if j == 0 - if dict.nextidx > lastindex(dict.keys) - expand!(dict) - @goto index - end - dict.keys[dict.nextidx] = k - dict.vals[dict.nextidx] = v - dict.slots[i] = dict.nextidx - dict.used += 1 - dict.nextidx += 1 - else - dict.slots[i] = j - dict.keys[j] = k - dict.vals[j] = v - end - return dict -end - -function Base.delete!(dict::StableDict, key) - k = convert(keytype(dict), key) - i, j = indexes(dict, k) - if j > 0 - dict.slots[i] = -j - dict.used -= 1 - end - return dict -end - -function Base.pop!(dict::StableDict) - if isempty(dict) - throw(ArgumentError("empty")) - end - i = dict.slots[argmax(dict.slots)] - key = dict.keys[i] - val = dict.vals[i] - delete!(dict, key) - return key => val -end - -function Base.iterate(dict::StableDict) - if length(dict) == 0 - return nothing - end - if dict.used == dict.nextidx - 1 - keys = dict.keys[1:dict.used] - vals = dict.vals[1:dict.used] - else - idx = sort!(dict.slots[dict.slots .> 0]) - @assert length(idx) == length(dict) - keys = dict.keys[idx] - vals = dict.vals[idx] - end - return (keys[1] => vals[1]), (2, keys, vals) -end - -function Base.iterate(dict::StableDict, st) - i = st[1] - if i > length(st[2]) - return nothing - end - return (st[2][i] => st[3][i]), (i + 1, st[2], st[3]) -end - -function hashindex(key, sz) - return (reinterpret(Int, hash(key)) & (sz-1)) + 1 -end - -function indexes(dict, key) - sz = length(dict.slots) - h = hashindex(key, sz) - i = 0 - while i < sz - j = mod1(h + i, sz) - k = dict.slots[j] - if k == 0 - return j, k - elseif k > 0 && isequal(dict.keys[k], key) - return j, k - end - i += 1 - end - return 0, 0 -end - -function expand!(dict) - sz = length(dict.slots) - newsz = sz * 2 - newslots = zeros(Int, newsz) - resize!(dict.keys, newsz) - resize!(dict.vals, newsz) - for i in 1:sz - j = dict.slots[i] - if j > 0 - k = hashindex(dict.keys[j], newsz) - while newslots[mod1(k, newsz)] != 0 - k += 1 - end - newslots[mod1(k, newsz)] = j - end - end - dict.slots = newslots - return dict -end diff --git a/src/sset.jl b/src/sset.jl deleted file mode 100644 index 32964b87..00000000 --- a/src/sset.jl +++ /dev/null @@ -1,91 +0,0 @@ -# Stable Set -# ========== - -mutable struct StableSet{T} <: Base.AbstractSet{T} - dict::StableDict{T, Nothing} - - function StableSet{T}() where T - return new{T}(StableDict{T, Nothing}()) - end -end - -function StableSet(vals) - set = StableSet{eltype(vals)}() - for v in vals - push!(set, v) - end - return set -end - -function Base.copy(set::StableSet) - newset = StableSet{eltype(set)}() - newset.dict = copy(set.dict) - return newset -end - -function Base.length(set::StableSet) - return length(set.dict) -end - -function Base.eltype(::Type{StableSet{T}}) where T - return T -end - -function Base.:(==)(set1::StableSet, set2::StableSet) - if length(set1) == length(set2) - for x in set1 - if x ∉ set2 - return false - end - end - return true - end - return false -end - -function Base.hash(set::StableSet, h::UInt) - h = hash(Base.hashs_seed, h) - for x in set - h = xor(h, hash(x)) - end - return h -end - -function Base.in(val, set::StableSet) - return haskey(set.dict, val) -end - -function Base.push!(set::StableSet, val) - v = convert(eltype(set), val) - if v ∉ set - set.dict[v] = nothing - end - return set -end - -function Base.pop!(set::StableSet) - return pop!(set.dict)[1] -end - -function Base.delete!(set::StableSet, val) - delete!(set.dict, val) - return set -end - -function Base.union!(set::StableSet, xs) - for x in xs - push!(set, x) - end - return set -end - -function Base.union(set::StableSet, xs) - return union!(copy(set), xs) -end - -function Base.iterate(set::StableSet, s=iterate(set.dict)) - if s === nothing - return nothing - end - return s[1][1], iterate(set.dict, s[2]) -end diff --git a/src/traverser.jl b/src/traverser.jl index 423920e1..c3fc5452 100644 --- a/src/traverser.jl +++ b/src/traverser.jl @@ -19,7 +19,7 @@ end function Base.iterate(t::Traverser{T}, state=nothing) where T if state === nothing - state = (visited = Set{T}(), unvisited = Set([t.start])) + state = (visited = Set{T}(), unvisited = [t.start]) end if isempty(state.unvisited) return nothing @@ -27,7 +27,7 @@ function Base.iterate(t::Traverser{T}, state=nothing) where T s = pop!(state.unvisited) push!(state.visited, s) for (_, t) in s.edges - if t ∉ state.visited + if t ∉ state.visited && t ∉ state.unvisited push!(state.unvisited, t) end end