Skip to content

Commit

Permalink
cond synth: add option to check rewrite condition
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 11, 2024
1 parent e312022 commit 0622fb8
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
25 changes: 23 additions & 2 deletions tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ fn main() {
} else {
// remember start time
let start = std::time::Instant::now();
let samples =
samples::generate_samples(rule, args.max_width, true, args.dump_smt, args.check_cond);
let samples = samples::generate_samples(rule, args.max_width, true, args.dump_smt);
let delta_t = std::time::Instant::now() - start;
println!(
"Took {delta_t:?} on {} threads.",
Expand All @@ -99,6 +98,28 @@ fn main() {
samples.num_unequivalent()
);

if args.check_cond {
// false positive => our current condition says it is equivalent, while it actually is not
let mut false_positive = 0u64;
// false negative => our current condition says the rule does not apply, while it actually could
let mut false_negative = 0u64;
for (a, is_eq) in samples.iter() {
let condition_res = rule.eval_condition(&a);
match (condition_res, is_eq) {
(true, false) => {
false_positive += 1;
}
(false, true) => {
false_negative += 1;
}
_ => {} // ignore
}
}
println!("The current implementation of the condition has:");
println!("False positives (BAD): {false_positive: >10}");
println!("False negatives (OK): {false_negative: >10}");
}

if let Some(out_filename) = args.write_assignments {
let mut file = std::fs::File::create(&out_filename).expect("failed to open output JSON");
samples
Expand Down
21 changes: 18 additions & 3 deletions tools/egraphs-cond-synth/src/rewrites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>

use crate::samples::Assignment;
use egg::*;
use patronus::expr::*;
use patronus_egraphs::*;
Expand Down Expand Up @@ -30,7 +31,7 @@ pub struct ArithRewrite {
/// rhs pattern with all widths derived from the lhs, maybe be the same as rhs
rhs_derived: Pattern<Arith>,
/// variables use by the condition
cond_vars: Vec<String>,
cond_vars: Vec<Var>,
/// condition of the re_write
cond: Option<fn(&[WidthInt]) -> bool>,
}
Expand All @@ -45,7 +46,7 @@ impl ArithRewrite {
) -> Self {
let cond_vars = cond_vars
.into_iter()
.map(|n| n.as_ref().to_string())
.map(|n| n.as_ref().parse().unwrap())
.collect();
Self {
name: name.to_string(),
Expand All @@ -67,7 +68,7 @@ impl ArithRewrite {
pub fn to_egg(&self) -> Vec<Rewrite<Arith, ()>> {
// TODO: support bi-directional rules
if let Some(cond) = self.cond {
let vars: Vec<Var> = self.cond_vars.iter().map(|n| n.parse().unwrap()).collect();
let vars: Vec<Var> = self.cond_vars.clone();
let condition = move |egraph: &mut EGraph, _, subst: &Subst| {
let values: Vec<WidthInt> = vars
.iter()
Expand All @@ -89,6 +90,20 @@ impl ArithRewrite {
.unwrap()]
}
}

pub fn eval_condition(&self, a: &Assignment) -> bool {
if let Some(cond) = self.cond {
let values: Vec<WidthInt> = self
.cond_vars
.iter()
.map(|v| a.iter().find(|(k, _)| k == v).unwrap().1)
.collect();
cond(values.as_slice())
} else {
// unconditional rewrite
true
}
}
}

pub fn create_rewrites() -> Vec<ArithRewrite> {
Expand Down
1 change: 0 additions & 1 deletion tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ pub fn generate_samples(
max_width: WidthInt,
show_progress: bool,
dump_smt: bool,
check_cond: bool,
) -> Samples {
let (lhs, rhs) = rule.patterns();
let lhs_info = analyze_pattern(lhs);
Expand Down

0 comments on commit 0622fb8

Please sign in to comment.