Skip to content

Commit

Permalink
Mitigate combinatorial explosion in dequantify with static conditions.
Browse files Browse the repository at this point in the history
  • Loading branch information
ztangent committed Aug 1, 2022
1 parent 13e76e3 commit d268a96
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ ValSplit = "0625e100-946b-11ec-09cd-6328dd093154"
[compat]
AutoHashEquals = "0.2, 1.0"
IntervalArithmetic = "0.17, 0.18, 0.19, 0.20"
Julog = "0.1.14"
Julog = "0.1.15"
ParserCombinator = "2.0.0"
ValSplit = "0.1"
julia = "1.3"
Expand Down
25 changes: 16 additions & 9 deletions src/grounding/action.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ end

get_name(action::GroundActionGroup) = action.name

"Maximum limit for grounding by enumerating over typed objects."
const MAX_GROUND_BY_TYPE_LIMIT = 250
const MIN_GROUND_BY_PRECOND_LIMIT = 1

"Minimum number of static conditions for grounding by static satisfaction."
const MIN_GROUND_BY_STATIC_LIMIT = 1

"Returns an iterator over all ground arguments of an `action`."
function groundargs(domain::Domain, state::State, action::Action;
Expand All @@ -46,10 +49,10 @@ function groundargs(domain::Domain, state::State, action::Action;
preconds = flatten_conjs(get_precond(action))
filter!(p -> p.name in statics, preconds)
# Decide whether to generate by satisfying static preconditions
n_groundings = prod(length(get_objects(domain, state, ty))
n_groundings = prod(get_object_count(domain, state, ty)
for ty in get_argtypes(action))
use_preconds = n_groundings > MAX_GROUND_BY_TYPE_LIMIT &&
length(preconds) >= MIN_GROUND_BY_PRECOND_LIMIT
length(preconds) >= MIN_GROUND_BY_STATIC_LIMIT
if use_preconds # Filter using preconditions
# Add type conditions for correctness
act_vars, act_types = get_argvars(action), get_argtypes(action)
Expand All @@ -75,8 +78,10 @@ function groundactions(domain::Domain, state::State, action::Action;
statics=infer_static_fluents(domain))
ground_acts = GroundAction[]
# Dequantify and flatten preconditions and effects
precond = to_nnf(dequantify(get_precond(action), domain, state))
effects = flatten_conditions(dequantify(get_effect(action), domain, state))
precond = to_nnf(dequantify(get_precond(action),
domain, state, statics))
effects = flatten_conditions(dequantify(get_effect(action),
domain, state, statics))
# Iterate over possible groundings
for args in groundargs(domain, state, action; statics=statics)
# Construct ground action for each set of arguments
Expand All @@ -99,8 +104,8 @@ end
Returns all ground actions for a `domain` and initial `state`.
"""
function groundactions(domain::Domain, state::State)
statics = infer_static_fluents(domain)
function groundactions(domain::Domain, state::State;
statics=infer_static_fluents(domain))
iters = (groundactions(domain, state, act; statics=statics)
for act in values(get_actions(domain)))
return collect(Iterators.flatten(iters))
Expand All @@ -114,8 +119,10 @@ is never satisfiable given the `domain` and `state`, return `nothing`.
"""
function ground(domain::Domain, state::State, action::Action, args;
statics=infer_static_fluents(domain),
precond=to_nnf(dequantify(get_precond(action), domain, state)),
effects=flatten_conditions(dequantify(get_effect(action), domain, state))
precond=to_nnf(dequantify(get_precond(action),
domain, state, statics)),
effects=flatten_conditions(dequantify(get_effect(action),
domain, state, statics))
)
act_name = get_name(action)
act_vars = get_argvars(action)
Expand Down
99 changes: 84 additions & 15 deletions src/grounding/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,98 @@ end
Replaces universally or existentially quantified expressions with their
corresponding conjuctions or disjunctions over the object types they quantify.
"""
function dequantify(term::Term, domain::Domain, state::State)
if term.name in (:forall, :exists)
typeconds, query = flatten_conjs(term.args[1]), term.args[2]
query = dequantify(query, domain, state)
ground_terms = Term[query]
for cond in typeconds
type, var = cond.name, cond.args[1]
ground_terms = map(ground_terms) do gt
return Term[substitute(gt, Subst(var => obj))
for obj in get_objects(domain, state, type)]
end
ground_terms = reduce(vcat, ground_terms; init=Term[])
function dequantify(term::Term, domain::Domain, state::State,
statics=infer_static_fluents(domain))
if is_quantifier(term)
conds, query = flatten_conjs(term.args[1]), term.args[2]
query = dequantify(query, domain, state, statics)
# Dequantify by type if no static fluents
if statics === nothing || isempty(statics)
return dequantify_by_type(term.name, conds, query, domain, state)
else # Dequantify by static conditions otherwise
return dequantify_by_stat_conds(term.name, conds, query,
domain, state, statics)
end
op = term.name == :forall ? :and : :or
return Compound(op, ground_terms)
elseif term.name in (:and, :or, :imply, :not, :when)
args = Term[dequantify(arg, domain, state) for arg in term.args]
args = Term[dequantify(arg, domain, state, statics)
for arg in term.args]
return Compound(term.name, args)
else
return term
end
end

"Dequantifies a quantified expression by the types it is quantified over."
function dequantify_by_type(
name::Symbol, typeconds::Vector{Term}, query::Term,
domain::Domain, state::State
)
# Accumulate list of ground terms
stack = Term[]
subterms = Term[query]
for cond in typeconds
# Swap array references
stack, subterms = subterms, stack
# Substitute all objects of each type
type, var = cond.name, cond.args[1]
objects = get_objects(domain, state, type)
while !isempty(stack)
term = pop!(stack)
for obj in objects
push!(subterms, substitute(term, Subst(var => obj)))
end
end
end
# Return conjunction / disjunction of ground terms
if name == :forall
return isempty(subterms) ? Const(true) : Compound(:and, subterms)
else # name == :exists
return isempty(subterms) ? Const(false) : Compound(:or, subterms)
end
end

"Dequantifies a quantified expression via static satisfaction (where useful)."
function dequantify_by_stat_conds(
name::Symbol, conds::Vector{Term}, query::Term,
domain::Domain, state::State, statics::Vector{Symbol}
)
vars = Var[c.args[1] for c in conds]
types = Symbol[c.name for c in conds]
# Determine conditions that potentially restrict dequantification
if name == :forall
extra_conds = query.name in (:when, :imply) ? query.args[1] : Term[]
else # name == :exists
extra_conds = query
end
# Add static conditions
for c in flatten_conjs(extra_conds)
c.name in statics || continue
push!(conds, c)
end
# Default to dequantifying by types if no static conditions were added
if length(conds) == length(vars)
return dequantify_by_type(name, conds, query, domain, state)
end
# Check if static conditions actually restrict the number of groundings
substs = satisfiers(domain, state, conds)
if prod(get_object_count(domain, state, ty) for ty in types) < length(substs)
conds = resize!(conds, length(vars))
return dequantify_by_type(name, conds, query, domain, state)
end
# Accumulate list of ground terms
subterms = Term[]
for s in substs
length(s) > length(vars) && filter!(p -> first(p) in vars, s)
push!(subterms, substitute(query, s))
end
# Return conjunction / disjunction of ground terms
if name == :forall
return isempty(subterms) ? Const(true) : Compound(:and, subterms)
else # name == :exists
return isempty(subterms) ? Const(false) : Compound(:or, subterms)
end
end

"""
conds, effects = flatten_conditions(term::Term)
Expand Down
8 changes: 8 additions & 0 deletions src/interface/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ function get_all_subtypes(domain::Domain, type::Symbol)
return reduce(vcat, [get_all_subtypes(domain, ty) for ty in subtypes],
init=subtypes)
end

"Returns number of objects of particular type."
get_object_count(domain::Domain, state::State, type::Symbol) =
length(get_objects(domain, state, type))

"Returns number of objects of each type as a dictionary."
get_object_counts(domain::Domain, state::State) =
Dict(ty => get_object_count(domain, state, ty) for ty in get_types(domain))

0 comments on commit d268a96

Please sign in to comment.