Skip to content

Commit

Permalink
add bdd based formula generator
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 9, 2024
1 parent 99a8793 commit 271fb04
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ jobs:
run: cargo build --verbose --release -p patronus-egraphs-cond-synth
- name: synthesize commute-add condition
run: |
cargo run --release -p patronus-egraphs-cond-synth -- commute-add
cargo run --release -p patronus-egraphs-cond-synth -- --bdd-formula commute-add
semver:
name: Check Semantic Versioning of Patronus
Expand Down
8 changes: 8 additions & 0 deletions tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
mod samples;
mod summarize;

use crate::summarize::bdd_summarize;
use clap::Parser;
use egg::*;
use patronus::expr::*;
Expand Down Expand Up @@ -80,4 +81,11 @@ fn main() {
println!("{:?}", sample);
}
}

if args.bdd_formula {
let summarize_start = std::time::Instant::now();
let formula = bdd_summarize(&rule_info, &samples);
let summarize_delta_t = std::time::Instant::now() - summarize_start;
println!("Generated formula in {summarize_delta_t:?}:\n{}", formula);
}
}
103 changes: 84 additions & 19 deletions tools/egraphs-cond-synth/src/summarize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,63 @@ use egg::Var;
use patronus::expr::WidthInt;
use rustc_hash::FxHashMap;

/// generate a simplified re-writ condition from samples, using BDDs
pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) {
/// generate a simplified re-write condition from samples, using BDDs
pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String {
// generate all labels and the corresponding BDD terminals
let labels = get_labels(rule);
let mut bdd = boolean_expression::BDD::<usize>::new();
let vars: Vec<_> = (0..labels.len()).map(|ii| bdd.terminal(ii)).collect();

// start condition as trivially `true`
let mut cond = boolean_expression::BDD_ONE;
for (assignment, is_equal) in samples.iter() {
let v = FxHashMap::from_iter(assignment);
let mut outputs = vec![];
for feature in FEATURES.iter() {
(feature.eval)(rule, &v, &mut outputs);
}
let lits = outputs
.into_iter()
.enumerate()
.map(|(terminal, is_true)| {
if is_true {
vars[terminal]
} else {
bdd.not(vars[terminal])
}
})
.collect::<Vec<_>>();
let term = lits.into_iter().reduce(|a, b| bdd.and(a, b)).unwrap();
let term = if is_equal { term } else { bdd.not(term) };

cond = bdd.and(cond, term);
}

// extract simplified expression
format!("{:?}", bdd.to_expr(cond))
}

fn get_labels(rule: &RuleInfo) -> Vec<String> {
FEATURES
.iter()
.map(|f| (f.labels)(rule))
.reduce(|mut a, mut b| {
a.append(&mut b);
a
})
.unwrap_or_default()
}

const FEATURES: &[Feature] = &[
Feature {
name: "is_unsigned", // (13)
len: |r| Some(r.signs().count()),
labels: |r| {
let mut o = vec![];
for sign in r.signs() {
o.push(format!("!{sign}"));
}
o
},
eval: |r, v, o| {
for sign in r.signs() {
// s_i == unsign
Expand All @@ -27,12 +73,16 @@ const FEATURES: &[Feature] = &[
},
Feature {
name: "is_width_equal", // (14)
len: |r| {
if r.widths().count() <= 0 {
None
} else {
Some(r.widths().count() * (r.widths().count() - 1))
labels: |r| {
let mut o = vec![];
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
o.push(format!("{w_i} == {w_j}"));
}
}
}
o
},
eval: |r, v, o| {
for w_i in r.widths() {
Expand All @@ -47,12 +97,18 @@ const FEATURES: &[Feature] = &[
},
Feature {
name: "is_width_smaller", // (15) + (16)
len: |r| {
if r.widths().count() <= 0 {
None
} else {
Some(r.widths().count() * (r.widths().count() - 1) * 3)
labels: |r| {
let mut o = vec![];
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
o.push(format!("{w_i} < {w_j}"));
o.push(format!("{w_i} + 1 < {w_j}"));
o.push(format!("{w_i} - 1 < {w_j}"));
}
}
}
o
},
eval: |r, v, o| {
for w_i in r.widths() {
Expand All @@ -72,12 +128,21 @@ const FEATURES: &[Feature] = &[
},
Feature {
name: "is_width_sum_smaller", // (17) + (18)
len: |r| {
if r.widths().count() <= 1 {
None
} else {
Some(r.widths().count() * (r.widths().count() - 1) * (r.widths().count() - 2) * 2)
labels: |r| {
let mut o = vec![];
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
for w_k in r.widths() {
if w_k != w_i && w_k != w_j {
o.push(format!("{w_i} + {w_j} < {w_k}"));
o.push(format!("{w_i} as u64 + 2u64.pow({w_j}) < {w_k} as u64"));
}
}
}
}
}
o
},
eval: |r, v, o| {
for w_i in r.widths() {
Expand All @@ -101,6 +166,6 @@ const FEATURES: &[Feature] = &[

struct Feature {
name: &'static str,
len: fn(rule: &RuleInfo) -> Option<usize>,
labels: fn(rule: &RuleInfo) -> Vec<String>,
eval: fn(rule: &RuleInfo, v: &FxHashMap<Var, WidthInt>, out: &mut Vec<bool>),
}

0 comments on commit 271fb04

Please sign in to comment.