Skip to content

Commit

Permalink
store State on the heap
Browse files Browse the repository at this point in the history
  • Loading branch information
wartman4404 committed Mar 8, 2016
1 parent b023a88 commit ae0616d
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 52 deletions.
2 changes: 1 addition & 1 deletion examples/quine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ fn repl(mut state: State) {
fresh!(state, result);
//let path = state.make_var_of(Nil);
//let newpath = state.make_var();
let rcstate = Rc::new(state);
let rcstate = Rc::new(state.unwrap());
let tmpstate = State::with_parent(rcstate.clone());
let mut iter = eval(tmpstate, x, env, result);
if let Some(new_state) = iter.into_iter().next() {
Expand Down
3 changes: 2 additions & 1 deletion src/constraints.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::ops::{Add, Sub};
use core::{ToVar, ToConstraint, Constraint, Var, StateProxy, State, ConstraintResult, VarStore, Unifier, VarRetrieve, VarMap, UntypedVar, VarWrapper};
use core::{ToVar, ToConstraint, Constraint, Var, StateProxy, ConstraintResult, VarStore, Unifier, VarRetrieve, VarMap, UntypedVar, VarWrapper};
use core::StateInner as State;
use core::ConstraintResult::*;
use finitedomain::Fd;
use finitedomain::Fd::*;
Expand Down
3 changes: 2 additions & 1 deletion src/core/disequal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::any::TypeId;
use std::fmt::{self, Debug, Formatter};
use core::{ConstraintResult, StateProxy, UntypedVar, Constraint, State, Var, ToVar, FollowRef, Unifier, VarMap, VarWrapper};
use core::{ConstraintResult, StateProxy, UntypedVar, Constraint, Var, ToVar, FollowRef, Unifier, VarMap, VarWrapper};
use core::StateInner as State;
use core::ConstraintResult::*;
use core::ExactVarRef::*;

Expand Down
14 changes: 7 additions & 7 deletions src/core/get_values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::cmp::Ordering::*;
use std::collections::{BTreeSet, HashMap};
use std::collections::hash_map::Entry::*;
use std::rc::Rc;
use core::{UntypedVar, State, FollowRef, VarWrapper, Unifier};
use core::{UntypedVar, State, FollowRef, VarWrapper, Unifier, StateInner};
use core::ExactVarRef::*;
use iter::{StateIter, single, TailIter};

Expand All @@ -25,7 +25,7 @@ impl PartialOrd for CountedVar {

type VarWrapperIter = Box<Iterator<Item=Box<VarWrapper>>>;

fn value_iter(state: Rc<State>, var: UntypedVar, mut iter: VarWrapperIter) -> TailIter {
fn value_iter(state: Rc<StateInner>, var: UntypedVar, mut iter: VarWrapperIter) -> TailIter {
use iter::{TailIterResult, wrap_fn};
wrap_fn(move || {
while let Some(x) = iter.next() {
Expand All @@ -40,18 +40,18 @@ fn value_iter(state: Rc<State>, var: UntypedVar, mut iter: VarWrapperIter) -> Ta
}

struct ParentStateIter<'a> {
state: Option<&'a State>,
state: Option<&'a StateInner>,
}

impl<'a> ParentStateIter<'a> {
fn new(state: &'a State) -> ParentStateIter {
ParentStateIter { state: Some(state) }
ParentStateIter { state: Some(&**state) }
}
}

impl<'a> Iterator for ParentStateIter<'a> {
type Item = &'a State;
fn next(&mut self) -> Option<&'a State> {
type Item = &'a StateInner;
fn next(&mut self) -> Option<&'a StateInner> {
let result = self.state;
self.state = self.state.and_then(|s| {
s.parent.as_ref().map(|x| &**x)
Expand Down Expand Up @@ -137,7 +137,7 @@ fn assign_values_inner(state: State, mut counted: BTreeSet<CountedVar>, mut vars
Some(x) => x,
None => { return assign_values_inner(state, counted, vars); },
};
let iter = TailIterResult(None, Some(value_iter(Rc::new(state), var.0, val)));
let iter = TailIterResult(None, Some(value_iter(Rc::new(state.unwrap()), var.0, val)));

iter
.and(move |state| {
Expand Down
96 changes: 68 additions & 28 deletions src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,51 @@ use std::any::*;
use std::raw::TraitObject;
use std::collections::HashSet;
use std::mem;
use std::ops::{Deref, DerefMut};

//use core::ExactVal::*;
//use core::VarRef::*;
use core::ExactVarRef::*;

#[derive(Debug)]
pub struct State(Box<StateInner>);

impl Deref for State {
type Target = StateInner;
fn deref(&self) -> &StateInner { &self.0 }
}
impl DerefMut for State {
fn deref_mut(&mut self) -> &mut StateInner { &mut self.0 }
}
impl State {
pub fn new() -> State {
State(Box::new(StateInner::new()))
}
pub fn with_parent(parent: Rc<StateInner>) -> State {
State(Box::new(StateInner::with_parent(parent)))
}
pub fn from_inner(state: StateInner) -> State {
State(Box::new(state))
}
pub fn unwrap(self) -> StateInner {
let state: StateInner = *(self.0);
state
}
}
impl VarStore for State {
fn store_value<A>(&mut self, value: A) -> Var<A>
where A : VarWrapper + 'static { self.0.store_value(value) }
fn make_var<A>(&mut self) -> Var<A> where A : VarWrapper { self.0.make_var() }
}
impl VarRetrieve for State {
fn get_value<A>(&self, a: Var<A>) -> Option<&A> where A : VarWrapper { self.0.get_value(a) }
fn get_untyped(&self, var: UntypedVar) -> Option<&VarWrapper> { self.0.get_untyped(var) }
}
impl Unifier for State {
fn unify_vars<A>(&mut self, a: Var<A>, b: Var<A>) -> &mut Self where A : VarWrapper { self.0.unify_vars(a, b); self }
fn fail(&mut self) -> &mut Self { self.0.fail(); self }
fn ok(&self) -> bool { self.0.ok() }
}

///! Wrapper for a usize, used as a unique variable identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
Expand All @@ -42,13 +82,13 @@ pub struct Var<A: VarWrapper> { var: UntypedVar, var_type: PhantomData<A> }
impl<A> Clone for Var<A> where A: VarWrapper { fn clone(&self) -> Var<A> { *self } }
impl<A> Copy for Var<A> where A: VarWrapper { }

///! State is the heart of rust_kanren. It tracks all variable substitutions added by calling
///! StateInner is the heart of rust_kanren. It tracks all variable substitutions added by calling
///! `unify()`, hands out Vars through `make_var()` and `make_var_of()`, and tracks whether any
///! unifications have failed.
pub struct State {
pub struct StateInner {
eqs: VarMap,
// TODO: use a reference instead
parent: Option<Rc<State>>,
parent: Option<Rc<StateInner>>,
constraints: ConstraintStore,
proxy_eqs: VarMap,
}
Expand All @@ -57,7 +97,7 @@ pub struct State {
///! unification, which is necessary for constraints.
#[derive(Debug)]
pub struct StateProxy<'a> {
parent: &'a mut State,
parent: &'a mut StateInner,
}

///! Enum representing the possible outcomes of `Constraint::update()`.
Expand Down Expand Up @@ -85,22 +125,22 @@ pub trait Constraint: Debug + Sized {
fn relevant(&self, _: &VarMap) -> bool;
///! Called to update a constraint's variables when unification has assigned them to be equal
///! to other variables. Should call `state.update_var()` for each variable in the constraint.
fn update_vars(&mut self, _: &State);
fn update_vars(&mut self, _: &StateInner);
///! (Optional) Called to determine whether `update_vars()` needs to be called. Should call
///! `varmap.need_update()` for each variable in the constraint.
fn need_update(&self, vars: &VarMap) -> bool { self.relevant(vars) }
}

///! Trait for creating a `Constraint`, given a `State`.
///! Trait for creating a `Constraint`, given a `StateInner`.
pub trait ToConstraint {
type ConstraintType: Constraint + 'static + Clone;
fn into_constraint(self, state: &mut State) -> Self::ConstraintType;
fn into_constraint(self, state: &mut StateInner) -> Self::ConstraintType;
}

trait BoxedConstraint: Debug {
fn update(&self, _: &mut StateProxy) -> ConstraintResult<Box<BoxedConstraint>>;
fn relevant(&self, _: &VarMap) -> bool;
fn update_vars(&mut self, _: &State);
fn update_vars(&mut self, _: &StateInner);
fn need_update(&self, vars: &VarMap) -> bool;
fn clone_boxed(&self) -> Box<BoxedConstraint>;
}
Expand All @@ -120,7 +160,7 @@ impl<A> BoxedConstraint for ConstraintWrapper<A> where A : Constraint + Clone +
}
}
fn relevant(&self, vars: &VarMap) -> bool { self.0.relevant(vars) }
fn update_vars(&mut self, vars: &State) { self.0.update_vars(vars) }
fn update_vars(&mut self, vars: &StateInner) { self.0.update_vars(vars) }
fn need_update(&self, vars: &VarMap) -> bool { self.0.need_update(vars) }
fn clone_boxed(&self) -> Box<BoxedConstraint> {
Box::new(ConstraintWrapper(self.0.clone()))
Expand Down Expand Up @@ -421,7 +461,7 @@ impl ExactVal {
//}
}

impl VarRetrieve for State {
impl VarRetrieve for StateInner {
fn get_value<A: VarWrapper>(&self, var: Var<A>) -> Option<&A> {
self.get_exact_val_opt(var.var)
}
Expand All @@ -430,13 +470,13 @@ impl VarRetrieve for State {
}
}

impl VarStore for State {
impl VarStore for StateInner {
fn store_value<A>(&mut self, value: A) -> Var<A>
where A : VarWrapper + 'static { self.eqs.store_value(value) }
fn make_var<A>(&mut self) -> Var<A> where A : VarWrapper { self.eqs.make_var() }
}

impl FollowRef for State {
impl FollowRef for StateInner {
fn get_ref(&self, id: UntypedVar) -> &VarRef {
let mut state = self;
loop {
Expand All @@ -453,14 +493,14 @@ impl FollowRef for State {
}
}

impl Unifier for State {
fn unify_vars<A>(&mut self, a: Var<A>, b: Var<A>) -> &mut State
impl Unifier for StateInner {
fn unify_vars<A>(&mut self, a: Var<A>, b: Var<A>) -> &mut StateInner
where A : VarWrapper {
self.untyped_unify(a.var, b.var, TypeId::of::<A>(), needs_occurs_check::<A>());
self
}

fn fail(&mut self) -> &mut State {
fn fail(&mut self) -> &mut StateInner {
self.eqs.ok = false;
self
}
Expand All @@ -470,23 +510,23 @@ impl Unifier for State {
}
}

impl State {
impl StateInner {

///! Create a State with no substitutions and no parent.
pub fn new() -> State {
State {
///! Create a StateInner with no substitutions and no parent.
pub fn new() -> StateInner {
StateInner {
eqs: VarMap::new(),
parent: None,
constraints: ConstraintStore::new(),
proxy_eqs: VarMap::new(),
}
}

///! Create a State which builds on a parent State. This is essential for backtracking: no
///! steps are needed to return to an earlier point beyond dropping the child State.
pub fn with_parent(parent: Rc<State>) -> State {
///! Create a StateInner which builds on a parent StateInner. This is essential for backtracking: no
///! steps are needed to return to an earlier point beyond dropping the child StateInner.
pub fn with_parent(parent: Rc<StateInner>) -> StateInner {
let constraints = parent.constraints.clone();
State {
StateInner {
eqs: VarMap::with_parent(&parent.eqs),
parent: Some(parent.clone()),
constraints: constraints,
Expand Down Expand Up @@ -637,7 +677,7 @@ impl State {
}
}

///! Return type for `State::are_vars_unified`.
///! Return type for `StateInner::are_vars_unified`.
#[derive(Eq, PartialEq, Copy, Clone)]
pub enum Unifiability {
Impossible,
Expand Down Expand Up @@ -687,7 +727,7 @@ impl<'a> Unifier for StateProxy<'a> {
}

impl<'a> StateProxy<'a> {
fn new(parent: &'a mut State) -> StateProxy<'a> {
fn new(parent: &'a mut StateInner) -> StateProxy<'a> {
parent.proxy_eqs.id = parent.eqs.id;
parent.proxy_eqs.ok = parent.eqs.ok;
StateProxy { parent: parent }
Expand Down Expand Up @@ -987,10 +1027,10 @@ where A : VarWrapper {
write!(fmt, "Var({})", self.var.0)
}
}
impl Debug for State {
impl Debug for StateInner {
fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {

fn debug_var_ref(me: &State, val: &VarRef, fmt: &mut Formatter) -> fmt::Result {
fn debug_var_ref(me: &StateInner, val: &VarRef, fmt: &mut Formatter) -> fmt::Result {
match val.split() {
Err(x) => write!(fmt, "EqualTo({:?})", x),
Ok(x) => {
Expand All @@ -1006,7 +1046,7 @@ impl Debug for State {
}
}

try!(writeln!(fmt, "State {{"));
try!(writeln!(fmt, "StateInner {{"));
try!(writeln!(fmt, "\tid: {:?}", self.eqs.id));
try!(writeln!(fmt, "\tok: {:?}", self.eqs.ok));
try!(writeln!(fmt, "\tproxy.id: {:?}", self.proxy_eqs.id));
Expand Down
6 changes: 3 additions & 3 deletions src/finitedomain.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use finitedomain::Fd::*;
use std::collections::HashSet;
use core::{VarWrapper, StateProxy, Var, ToVar, VarStore, VarRetrieve, State, Unifier, UnifyResult};
use core::{VarWrapper, StateProxy, Var, ToVar, VarStore, VarRetrieve, State, Unifier, UnifyResult, StateInner};
use iter::{StateIter, single};
use std::rc::Rc;
use iter::{TailIter, TailIterResult};
Expand Down Expand Up @@ -195,12 +195,12 @@ where A: ToVar<VarType=Fd>, B: ToVar<VarType=usize> {
None => { single(state) },
Some(Values(values)) => {
let valiter = values.into_iter();
TailIterResult(None, Some(fd_value_iter(Rc::new(state), fd, valiter, u)))
TailIterResult(None, Some(fd_value_iter(Rc::new(state.unwrap()), fd, valiter, u)))
}
}
}

fn fd_value_iter(state: Rc<State>, fd: Var<Fd>, mut vals: ::std::vec::IntoIter<usize>, u: Var<usize>) -> TailIter {
fn fd_value_iter(state: Rc<StateInner>, fd: Var<Fd>, mut vals: ::std::vec::IntoIter<usize>, u: Var<usize>) -> TailIter {
use iter::wrap_fn;
wrap_fn(move || {
while let Some(x) = vals.next() {
Expand Down
18 changes: 10 additions & 8 deletions src/iter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use core::{ToVar, State, Var, Unifier, VarRetrieve, VarWrapper};
use core::{ToVar, State, Var, Unifier, VarRetrieve, VarWrapper, StateInner};
use std::rc::Rc;
use std::marker::PhantomData;
use std::any::*;
Expand Down Expand Up @@ -202,7 +202,7 @@ pub type StateIter = TailIterResult;
struct StateFnIter<F>
where F: Fn(usize, State) -> StateIter + 'static {
f: F,
state: Rc<State>,
state: Rc<StateInner>,
len: usize,
pos: usize,
}
Expand Down Expand Up @@ -230,7 +230,7 @@ where F: Fn(usize, State) -> StateIter + 'static {
pub fn conde(self, state: State) -> StateIter {
if !state.ok() { return TailIterResult(None, None); }
let chain = VecDeque::with_capacity(self.len);
let iter = StateFnIter { f: self.f, state: Rc::new(state), len: self.len, pos: 0 };
let iter = StateFnIter { f: self.f, state: Rc::new(state.unwrap()), len: self.len, pos: 0 };
TailIterResult(None, Some(Box::new(ChainManyIter { iter: Some(Box::new(iter)), chain: chain })))
}

Expand All @@ -244,7 +244,7 @@ where F: Fn(usize, State) -> StateIter + 'static {

fn condau(self, state: State, return_more: bool) -> StateIter {
if !state.ok() { return TailIterResult(None, None); }
let iter = StateFnIter { f: self.f, state: Rc::new(state), len: self.len, pos: 0 };
let iter = StateFnIter { f: self.f, state: Rc::new(state.unwrap()), len: self.len, pos: 0 };
TailIterResult(None, Some(Box::new(CondaIter { iter: Box::new(iter), return_more: return_more })))
}
}
Expand Down Expand Up @@ -303,14 +303,14 @@ impl StateIterExt for StateIter {
///! Helper to find all results for a given state and iterator.
pub struct FindAll<F>
where F: Fn(State) -> StateIter + 'static {
state: Rc<State>,
state: Rc<StateInner>,
f: F,
}

impl<F> FindAll<F>
where F: Fn(State) -> StateIter + 'static {
pub fn new(state: State, f: F) -> FindAll<F> {
let state = Rc::new(state);
let state = Rc::new(state.unwrap());
FindAll { state: state, f: f }
}

Expand All @@ -324,7 +324,9 @@ where F: Fn(State) -> StateIter + 'static {

///! Retrieve the wrapped state, destroying the FindAll.
pub fn state(self) -> State {
Rc::try_unwrap(self.state).unwrap_or_else(|state| State::with_parent(state.clone()))
Rc::try_unwrap(self.state)
.map(State::from_inner)
.unwrap_or_else(|state| State::with_parent(state.clone()))
}
}

Expand Down Expand Up @@ -361,7 +363,7 @@ where F: Fn(State) -> StateIter + 'static,

let mut list = state.make_var_of(list);
let var = state.make_var_of(var);
let state = Rc::new(state);
let state = Rc::new(state.unwrap());
let mut return_state = State::with_parent(state.clone());
let findall_state = State::with_parent(state);
for state in FindAll::new(findall_state, state_fn).iter() {
Expand Down
Loading

0 comments on commit ae0616d

Please sign in to comment.