diff --git a/Project.toml b/Project.toml index 846a45a..1f44fa5 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ ValSplit = "0625e100-946b-11ec-09cd-6328dd093154" [compat] AutoHashEquals = "0.2, 1.0" IntervalArithmetic = "0.17, 0.18, 0.19, 0.20" -Julog = "0.1.14" +Julog = "0.1.15" ParserCombinator = "2.0.0" ValSplit = "0.1" julia = "1.3" diff --git a/src/grounding/action.jl b/src/grounding/action.jl index 0fb37e7..9eb7c19 100644 --- a/src/grounding/action.jl +++ b/src/grounding/action.jl @@ -32,8 +32,11 @@ end get_name(action::GroundActionGroup) = action.name +"Maximum limit for grounding by enumerating over typed objects." const MAX_GROUND_BY_TYPE_LIMIT = 250 -const MIN_GROUND_BY_PRECOND_LIMIT = 1 + +"Minimum number of static conditions for grounding by static satisfaction." +const MIN_GROUND_BY_STATIC_LIMIT = 1 "Returns an iterator over all ground arguments of an `action`." function groundargs(domain::Domain, state::State, action::Action; @@ -46,10 +49,10 @@ function groundargs(domain::Domain, state::State, action::Action; preconds = flatten_conjs(get_precond(action)) filter!(p -> p.name in statics, preconds) # Decide whether to generate by satisfying static preconditions - n_groundings = prod(length(get_objects(domain, state, ty)) + n_groundings = prod(get_object_count(domain, state, ty) for ty in get_argtypes(action)) use_preconds = n_groundings > MAX_GROUND_BY_TYPE_LIMIT && - length(preconds) >= MIN_GROUND_BY_PRECOND_LIMIT + length(preconds) >= MIN_GROUND_BY_STATIC_LIMIT if use_preconds # Filter using preconditions # Add type conditions for correctness act_vars, act_types = get_argvars(action), get_argtypes(action) @@ -75,8 +78,10 @@ function groundactions(domain::Domain, state::State, action::Action; statics=infer_static_fluents(domain)) ground_acts = GroundAction[] # Dequantify and flatten preconditions and effects - precond = to_nnf(dequantify(get_precond(action), domain, state)) - effects = flatten_conditions(dequantify(get_effect(action), domain, state)) + precond = to_nnf(dequantify(get_precond(action), + domain, state, statics)) + effects = flatten_conditions(dequantify(get_effect(action), + domain, state, statics)) # Iterate over possible groundings for args in groundargs(domain, state, action; statics=statics) # Construct ground action for each set of arguments @@ -99,8 +104,8 @@ end Returns all ground actions for a `domain` and initial `state`. """ -function groundactions(domain::Domain, state::State) - statics = infer_static_fluents(domain) +function groundactions(domain::Domain, state::State; + statics=infer_static_fluents(domain)) iters = (groundactions(domain, state, act; statics=statics) for act in values(get_actions(domain))) return collect(Iterators.flatten(iters)) @@ -114,8 +119,10 @@ is never satisfiable given the `domain` and `state`, return `nothing`. """ function ground(domain::Domain, state::State, action::Action, args; statics=infer_static_fluents(domain), - precond=to_nnf(dequantify(get_precond(action), domain, state)), - effects=flatten_conditions(dequantify(get_effect(action), domain, state)) + precond=to_nnf(dequantify(get_precond(action), + domain, state, statics)), + effects=flatten_conditions(dequantify(get_effect(action), + domain, state, statics)) ) act_name = get_name(action) act_vars = get_argvars(action) diff --git a/src/grounding/utils.jl b/src/grounding/utils.jl index dcd2d27..f4dcc08 100644 --- a/src/grounding/utils.jl +++ b/src/grounding/utils.jl @@ -18,29 +18,98 @@ end Replaces universally or existentially quantified expressions with their corresponding conjuctions or disjunctions over the object types they quantify. """ -function dequantify(term::Term, domain::Domain, state::State) - if term.name in (:forall, :exists) - typeconds, query = flatten_conjs(term.args[1]), term.args[2] - query = dequantify(query, domain, state) - ground_terms = Term[query] - for cond in typeconds - type, var = cond.name, cond.args[1] - ground_terms = map(ground_terms) do gt - return Term[substitute(gt, Subst(var => obj)) - for obj in get_objects(domain, state, type)] - end - ground_terms = reduce(vcat, ground_terms; init=Term[]) +function dequantify(term::Term, domain::Domain, state::State, + statics=infer_static_fluents(domain)) + if is_quantifier(term) + conds, query = flatten_conjs(term.args[1]), term.args[2] + query = dequantify(query, domain, state, statics) + # Dequantify by type if no static fluents + if statics === nothing || isempty(statics) + return dequantify_by_type(term.name, conds, query, domain, state) + else # Dequantify by static conditions otherwise + return dequantify_by_stat_conds(term.name, conds, query, + domain, state, statics) end - op = term.name == :forall ? :and : :or - return Compound(op, ground_terms) elseif term.name in (:and, :or, :imply, :not, :when) - args = Term[dequantify(arg, domain, state) for arg in term.args] + args = Term[dequantify(arg, domain, state, statics) + for arg in term.args] return Compound(term.name, args) else return term end end +"Dequantifies a quantified expression by the types it is quantified over." +function dequantify_by_type( + name::Symbol, typeconds::Vector{Term}, query::Term, + domain::Domain, state::State +) + # Accumulate list of ground terms + stack = Term[] + subterms = Term[query] + for cond in typeconds + # Swap array references + stack, subterms = subterms, stack + # Substitute all objects of each type + type, var = cond.name, cond.args[1] + objects = get_objects(domain, state, type) + while !isempty(stack) + term = pop!(stack) + for obj in objects + push!(subterms, substitute(term, Subst(var => obj))) + end + end + end + # Return conjunction / disjunction of ground terms + if name == :forall + return isempty(subterms) ? Const(true) : Compound(:and, subterms) + else # name == :exists + return isempty(subterms) ? Const(false) : Compound(:or, subterms) + end +end + +"Dequantifies a quantified expression via static satisfaction (where useful)." +function dequantify_by_stat_conds( + name::Symbol, conds::Vector{Term}, query::Term, + domain::Domain, state::State, statics::Vector{Symbol} +) + vars = Var[c.args[1] for c in conds] + types = Symbol[c.name for c in conds] + # Determine conditions that potentially restrict dequantification + if name == :forall + extra_conds = query.name in (:when, :imply) ? query.args[1] : Term[] + else # name == :exists + extra_conds = query + end + # Add static conditions + for c in flatten_conjs(extra_conds) + c.name in statics || continue + push!(conds, c) + end + # Default to dequantifying by types if no static conditions were added + if length(conds) == length(vars) + return dequantify_by_type(name, conds, query, domain, state) + end + # Check if static conditions actually restrict the number of groundings + substs = satisfiers(domain, state, conds) + if prod(get_object_count(domain, state, ty) for ty in types) < length(substs) + conds = resize!(conds, length(vars)) + return dequantify_by_type(name, conds, query, domain, state) + end + # Accumulate list of ground terms + subterms = Term[] + for s in substs + length(s) > length(vars) && filter!(p -> first(p) in vars, s) + push!(subterms, substitute(query, s)) + end + # Return conjunction / disjunction of ground terms + if name == :forall + return isempty(subterms) ? Const(true) : Compound(:and, subterms) + else # name == :exists + return isempty(subterms) ? Const(false) : Compound(:or, subterms) + end +end + """ conds, effects = flatten_conditions(term::Term) diff --git a/src/interface/utils.jl b/src/interface/utils.jl index eb93361..2748b52 100644 --- a/src/interface/utils.jl +++ b/src/interface/utils.jl @@ -20,3 +20,11 @@ function get_all_subtypes(domain::Domain, type::Symbol) return reduce(vcat, [get_all_subtypes(domain, ty) for ty in subtypes], init=subtypes) end + +"Returns number of objects of particular type." +get_object_count(domain::Domain, state::State, type::Symbol) = + length(get_objects(domain, state, type)) + +"Returns number of objects of each type as a dictionary." +get_object_counts(domain::Domain, state::State) = + Dict(ty => get_object_count(domain, state, ty) for ty in get_types(domain))