diff --git a/faces.lua b/faces.lua index 11c7a95..f321c1a 100644 --- a/faces.lua +++ b/faces.lua @@ -31,12 +31,48 @@ local faces = {} -- STATIC FUNCTIONS -- ---------------------- +-- inspired by Penlight string_lambda: +-- http://stevedonovan.github.io/Penlight/api/libraries/pl.utils.html#string_lambda +local lambda +do + local memory = {} + lambda = setmetatable({ clear = function() memory = {} end, }, + { + __call = function(self,fstr) + local fun = memory[fstr] + if not fun then + local args,code = fstr:match("^%s*|(.+)|(.+)$") + assert(args and code, "Needs args and code sections, e.g., |args|code") + local fun_src = ("return function(%s) return %s end"):format(args,code) + fun = assert(loadstring(fun_src))() + memory[fstr] = fun + end + return fun + end, + }) +end + +-- transform a user function string into a lambda function +local function lambda_transform(func_str) + return function(vars) + local values,keys = {},{} + for k,v in pairs(vars) do + values[#values+1] = v + keys[#keys+1] = k + end + local code = ('|%s|%s'):format(table.concat(keys,","),func_str) + local f = lambda(code) + return f(table.unpack(values)) + end +end + -- converts to in-mutable the given table argument local function inmutable(tbl) return setmetatable({}, { __index = function(_,k) return tbl[k] end, __newindex = function() error("Unable to modify an in-mutable table") end, __len = function() return #tbl end, + __pairs = function() return pairs(tbl) end, }) end @@ -131,15 +167,16 @@ do end assign_variables = function(self, vars, patterns, sequence, var_matches, - user_clauses) + user_clauses, fact_vars) for i,pat in ipairs(patterns) do local fid = sequence[i] local fact = self.fact_list[fid] if not assign_fact_vars(vars, pat, fact, var_matches) then return false end end + for vname,i in pairs(fact_vars) do vars[vname] = sequence[i] end local inmutable_vars = inmutable(vars) for _,func in ipairs(user_clauses) do - if not func(sequence, inmutable_vars) then return false end + if not func(inmutable_vars) then return false end end return true end @@ -186,7 +223,8 @@ local function regenerate_agenda(self) if not rule_entailements[sequence] then local seq_vars = {} if assign_variables(self, seq_vars, rule.patterns, sequence, - rule.var_matches, rule.user_clauses) then + rule.var_matches, rule.user_clauses, + rule.fact_vars) then table.insert(combinations, sequence) table.insert(variables, seq_vars) end @@ -231,7 +269,7 @@ local function fire_rule(self, rule_name, args, vars) for i,v in ipairs(args) do self.fact_entailment[v] = args end -- execute rule actions for _,action in ipairs(rule.actions) do - action(args, inmutable(vars)) + action(inmutable(vars)) end end @@ -464,14 +502,27 @@ end -- declares a new rule in the knowledge base function faces_methods:defrule(rule_name) local rule = { patterns={}, user_clauses = {}, - actions={}, salience=0, var_matches = {} } + actions={}, salience=0, var_matches = {}, fact_vars = {} } self.kb_table[rule_name] = rule - local rule_builder = { + local rule_builder + rule_builder = { pattern = function(rule_builder, pattern) table.insert(rule.patterns, tuple(pattern)) return rule_builder end, + var = function(rule_builder, varname) + varname = assert(varname:match("%?([^%s]+)"), + string.format("Incorrect variable name: %s", varname)) + rule.fact_vars[varname] = #rule.patterns + 1 + return { + pattern = function(_,...) + return rule_builder.pattern(rule_builder,...) + end + } + end, u = function(rule_builder, func) + -- user_func receives one argument: vars + if type(func) == "string" then func = lambda_transform(func) end table.insert(rule.user_clauses, func) return rule_builder end, @@ -502,7 +553,8 @@ function faces_methods:defrule(rule_name) __index = function(rule_builder, key) if key == "u" then return function(rule_builder, user_func) - -- user_func receives two arguments (fact_ids, vars) + -- user_func receives one argument: vars + if type(func) == "string" then func = lambda_transform(func) end table.insert(rule.actions, user_func) return rule_builder end @@ -514,7 +566,7 @@ function faces_methods:defrule(rule_name) local args = table.pack(...) for i=1,args.n do args[i] = tuple(args[i]) end table.insert(rule.actions, - function(fact_ids, vars) + function(vars) local new_args = replace_variables(args, vars) return self[key](self, table.unpack(new_args)) end) diff --git a/tests/animals.lua b/tests/animals.lua index 1c446ef..d824af2 100644 --- a/tests/animals.lua +++ b/tests/animals.lua @@ -78,11 +78,11 @@ kb:defrule("Error"): -- Retracts error rule in case any classification is asserted kb:defrule("RetractError"): salience(100): - pattern{ "ERROR" }: + var("?f1"):pattern{ "ERROR" }: pattern{ ANIMAL_IS, ".*" }: ENTAILS("=>"): - u(function(fact_ids, vars) - kb:retract(fact_ids[1]) + u(function(vars) + kb:retract(vars.f1) end) -- Initial rule, asks the first question @@ -149,7 +149,7 @@ kb:defrule("Show"): salience(-10): pattern{ ANIMAL_IS, "?x" }: ENTAILS("=>"): - u(function(fact_ids,vars) + u(function(vars) print("The animal is a: " .. vars.x) end) diff --git a/tests/factorial.lua b/tests/factorial.lua index 8e20c6e..5cb85ad 100644 --- a/tests/factorial.lua +++ b/tests/factorial.lua @@ -4,11 +4,11 @@ local kb = faces() kb:defrule("expand"): salience(100): pattern{ "Factorial", "?x" }: - u(function(fact_ids, vars) + u(function(vars) return vars.x > 1 end): ENTAILS("=>"): - u(function(fact_ids, vars) + u(function(vars) kb:fassert{ "Factorial", vars.x-1 } end) @@ -16,11 +16,11 @@ kb:defrule("compute"): salience(0): pattern{ "Factorial", "?x" }: pattern{ "Result", "?y", "?z" }: - u(function(fact_ids, vars) + u(function(vars) return vars.x == vars.y+1 end): ENTAILS("=>"): - u(function(fact_ids, vars) + u(function(vars) kb:fassert{ "Result", vars.x, vars.x * vars.z } end) diff --git a/tests/test.lua b/tests/test.lua index 45f234d..9a42e3f 100644 --- a/tests/test.lua +++ b/tests/test.lua @@ -58,12 +58,12 @@ kb:fassert{ "duck sound", "poww" } kb:defrule("duck2"): salience(100): - pattern{ "duck sound", { "?name1", "?name2" } }: + var("?f1"):pattern{ "duck sound", { "?name1", "?name2" } }: match("?name1", "q.*"): ENTAILS("=>"): fassert{ "sound is", { "?name1", "?name2" } }: - u(function(fact_ids, vars) - print("DEBUG", kb:consult(fact_ids[1]), vars.name1, vars.name2) + u(function(vars) + print("DEBUG", kb:consult(vars.f1), vars.name1, vars.name2) end) kb:defrule("init"): @@ -99,7 +99,7 @@ kb:facts() kb:defrule("MultiValuated"): pattern{ "duck sound", "$?p" }: ENTAILS("=>"): - u(function(fact_ids, vars) + u(function(vars) print(vars.p) end) kb:agenda() diff --git a/tests/test2.lua b/tests/test2.lua index 7f9f66f..6db23f1 100644 --- a/tests/test2.lua +++ b/tests/test2.lua @@ -12,9 +12,7 @@ kb:defrule("user"): pattern{ "AnimalIs", "?x" }: numeric("?y"): numeric("?z"): - u(function(fact_ids, vars) - return (vars.z == vars.y*4) and (vars.y % 2)==0 - end): + u("(z == y*4) and (y%2)==0"): ENTAILS("=>"): fassert{ "EvenAnimal", "?x", "?y", "?z" }