Skip to content

Commit

Permalink
cond synth: add features module
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 10, 2024
1 parent 6f5e9b5 commit 535eeef
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 171 deletions.
174 changes: 174 additions & 0 deletions tools/egraphs-cond-synth/src/features.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Copyright 2024 Cornell University
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>

use crate::samples::{RuleInfo, Samples};
use bitvec::prelude as bv;
use egg::Var;
use patronus::expr::WidthInt;
use rustc_hash::FxHashMap;

/// Applies the features from the ROVER paper to all assignments and returns the result.
pub fn apply_features(rule: &RuleInfo, samples: &Samples) -> FeatureResult {
let labels = get_labels(rule);
let mut results = bv::BitVec::new();

for (assignment, is_equal) in samples.iter() {
let v = FxHashMap::from_iter(assignment);
results.push(is_equal);
for feature in FEATURES.iter() {
(feature.eval)(rule, &v, &mut results);
}
}

FeatureResult { labels, results }
}

pub struct FeatureResult {
labels: Vec<String>,
results: bv::BitVec,
}

impl FeatureResult {
pub fn num_features(&self) -> usize {
self.labels.len()
}
pub fn labels(&self) -> &[String] {
&self.labels
}
pub fn iter(&self) -> impl Iterator<Item = (&bv::BitSlice, bool)> + '_ {
let cs = self.num_features() + 1;
self.results.chunks(cs).map(|c| {
let is_equivalent = c[0];
let features = &c[1..];
(features, is_equivalent)
})
}
}

fn get_labels(rule: &RuleInfo) -> Vec<String> {
FEATURES
.iter()
.map(|f| (f.labels)(rule))
.reduce(|mut a, mut b| {
a.append(&mut b);
a
})
.unwrap_or_default()
}

const FEATURES: &[Feature] = &[
Feature {
name: "is_unsigned", // (13)
labels: |r| {
let mut o = vec![];
for sign in r.signs() {
o.push(format!("!{sign}"));
}
o
},
eval: |r, v, o| {
for sign in r.signs() {
// s_i == unsign
o.push(v[&sign] == 0);
}
},
},
Feature {
name: "is_width_equal", // (14)
labels: |r| {
let mut o = vec![];
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
o.push(format!("{w_i} == {w_j}"));
}
}
}
o
},
eval: |r, v, o| {
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
// w_i == w_j
o.push(v[&w_i] == v[&w_j]);
}
}
}
},
},
Feature {
name: "is_width_smaller", // (15) + (16)
labels: |r| {
let mut o = vec![];
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
o.push(format!("{w_i} < {w_j}"));
o.push(format!("{w_i} + 1 < {w_j}"));
o.push(format!("{w_i} - 1 < {w_j}"));
}
}
}
o
},
eval: |r, v, o| {
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
let (w_i, w_j) = (v[&w_i], v[&w_j]);
// w_i < w_j
o.push(w_i < w_j);
// w_i + 1 < w_j
o.push(w_i + 1 < w_j);
// w_i - 1 < w_j
o.push(w_i - 1 < w_j);
}
}
}
},
},
Feature {
name: "is_width_sum_smaller", // (17) + (18)
labels: |r| {
let mut o = vec![];
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
for w_k in r.widths() {
if w_k != w_i && w_k != w_j {
o.push(format!("{w_i} + {w_j} < {w_k}"));
o.push(format!("{w_i} as u64 + 2u64.pow({w_j}) < {w_k} as u64"));
}
}
}
}
}
o
},
eval: |r, v, o| {
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
for w_k in r.widths() {
if w_k != w_i && w_k != w_j {
let (w_i, w_j, w_k) = (v[&w_i], v[&w_j], v[&w_k]);
// w_i + w_j < w_k
o.push(w_i + w_j < w_k);
// w_i + 2**w_j < w_k
o.push(w_i as u64 + 2u64.pow(w_j) < w_k as u64);
}
}
}
}
}
},
},
];

struct Feature {
name: &'static str,
labels: fn(rule: &RuleInfo) -> Vec<String>,
eval: fn(rule: &RuleInfo, v: &FxHashMap<Var, WidthInt>, out: &mut bv::BitVec),
}
1 change: 1 addition & 0 deletions tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// author: Amelia Dobis <[email protected]>
// author: Mohanna Shahrad <[email protected]>

mod features;
mod samples;
mod summarize;

Expand Down
174 changes: 3 additions & 171 deletions tools/egraphs-cond-synth/src/summarize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>

use crate::features::apply_features;
use crate::samples::{RuleInfo, Samples};
use bitvec::macros::internal::funty::Fundamental;
use bitvec::prelude as bv;
use egg::Var;
use patronus::expr::WidthInt;
use rustc_hash::FxHashMap;

/// generate a simplified re-write condition from samples, using BDDs
pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String {
let results = check_features(rule, samples);
let results = apply_features(rule, samples);

// generate BDD terminals
let mut bdd = boolean_expression::BDD::<String>::new();
Expand All @@ -30,7 +26,7 @@ pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String {
.into_iter()
.enumerate()
.map(|(terminal, is_true)| {
if is_true.as_bool() {
if *is_true {
vars[terminal]
} else {
bdd.not(vars[terminal])
Expand All @@ -46,167 +42,3 @@ pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String {
// extract simplified expression
format!("{:?}", bdd.to_expr(cond))
}

pub fn check_features(rule: &RuleInfo, samples: &Samples) -> FeatureResult {
let labels = get_labels(rule);
let mut results = bv::BitVec::new();

for (assignment, is_equal) in samples.iter() {
let v = FxHashMap::from_iter(assignment);
results.push(is_equal);
for feature in FEATURES.iter() {
(feature.eval)(rule, &v, &mut results);
}
}

FeatureResult { labels, results }
}

pub struct FeatureResult {
labels: Vec<String>,
results: bv::BitVec,
}

impl FeatureResult {
pub fn num_features(&self) -> usize {
self.labels.len()
}
pub fn labels(&self) -> &[String] {
&self.labels
}
pub fn iter(&self) -> impl Iterator<Item = (&bv::BitSlice, bool)> + '_ {
let cs = self.num_features() + 1;
self.results.chunks(cs).map(|c| {
let is_equivalent = c[0];
let features = &c[1..];
(features, is_equivalent)
})
}
}

fn get_labels(rule: &RuleInfo) -> Vec<String> {
FEATURES
.iter()
.map(|f| (f.labels)(rule))
.reduce(|mut a, mut b| {
a.append(&mut b);
a
})
.unwrap_or_default()
}

const FEATURES: &[Feature] = &[
Feature {
name: "is_unsigned", // (13)
labels: |r| {
let mut o = vec![];
for sign in r.signs() {
o.push(format!("!{sign}"));
}
o
},
eval: |r, v, o| {
for sign in r.signs() {
// s_i == unsign
o.push(v[&sign] == 0);
}
},
},
Feature {
name: "is_width_equal", // (14)
labels: |r| {
let mut o = vec![];
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
o.push(format!("{w_i} == {w_j}"));
}
}
}
o
},
eval: |r, v, o| {
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
// w_i == w_j
o.push(v[&w_i] == v[&w_j]);
}
}
}
},
},
Feature {
name: "is_width_smaller", // (15) + (16)
labels: |r| {
let mut o = vec![];
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
o.push(format!("{w_i} < {w_j}"));
o.push(format!("{w_i} + 1 < {w_j}"));
o.push(format!("{w_i} - 1 < {w_j}"));
}
}
}
o
},
eval: |r, v, o| {
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
let (w_i, w_j) = (v[&w_i], v[&w_j]);
// w_i < w_j
o.push(w_i < w_j);
// w_i + 1 < w_j
o.push(w_i + 1 < w_j);
// w_i - 1 < w_j
o.push(w_i - 1 < w_j);
}
}
}
},
},
Feature {
name: "is_width_sum_smaller", // (17) + (18)
labels: |r| {
let mut o = vec![];
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
for w_k in r.widths() {
if w_k != w_i && w_k != w_j {
o.push(format!("{w_i} + {w_j} < {w_k}"));
o.push(format!("{w_i} as u64 + 2u64.pow({w_j}) < {w_k} as u64"));
}
}
}
}
}
o
},
eval: |r, v, o| {
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
for w_k in r.widths() {
if w_k != w_i && w_k != w_j {
let (w_i, w_j, w_k) = (v[&w_i], v[&w_j], v[&w_k]);
// w_i + w_j < w_k
o.push(w_i + w_j < w_k);
// w_i + 2**w_j < w_k
o.push(w_i as u64 + 2u64.pow(w_j) < w_k as u64);
}
}
}
}
}
},
},
];

struct Feature {
name: &'static str,
labels: fn(rule: &RuleInfo) -> Vec<String>,
eval: fn(rule: &RuleInfo, v: &FxHashMap<Var, WidthInt>, out: &mut bv::BitVec),
}

0 comments on commit 535eeef

Please sign in to comment.