Skip to content

Commit

Permalink
egraph: debugging rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 20, 2024
1 parent 3fb8ffd commit ed6c034
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 19 deletions.
2 changes: 2 additions & 0 deletions patronus-egraphs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,12 +453,14 @@ pub(crate) fn verification_fig_1(ctx: &mut Context) -> (ExprRef, ExprRef) {
let b = ctx.bv_symbol("B", 16);
let m = ctx.bv_symbol("M", 4);
let n = ctx.bv_symbol("N", 4);
// (A << M) * (B << N)
let spec = ctx.build(|c| {
c.mul(
c.zero_extend(c.shift_left(c.zero_extend(a, 15), c.zero_extend(m, 27)), 32),
c.zero_extend(c.shift_left(c.zero_extend(b, 15), c.zero_extend(n, 27)), 32),
)
});
// (A * B) << (M + N)
let implementation = ctx.build(|c| {
c.shift_left(
c.zero_extend(c.mul(c.zero_extend(a, 16), c.zero_extend(b, 16)), 31),
Expand Down
19 changes: 16 additions & 3 deletions patronus-egraphs/src/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
use crate::{get_const_width_or_sign, is_bin_op, EGraph};
use egg::Language;
use rustc_hash::FxHashMap;
use std::io::Write;
use std::io::{BufWriter, Write};

pub fn to_pdf(filename: &str, egraph: &EGraph) -> std::io::Result<()> {
use std::process::{Command, Stdio};
Expand All @@ -24,10 +24,16 @@ pub fn to_pdf(filename: &str, egraph: &EGraph) -> std::io::Result<()> {
}
}

pub fn to_dot(filename: &str, egraph: &EGraph) -> std::io::Result<()> {
let mut out = BufWriter::new(std::fs::File::create(filename)?);
write_to_dot(&mut out, egraph)?;
Ok(())
}

/// Reimplements egg's `to_dot` functionality.
/// This is necessary because we do not want to show the Width nodes in the graph, because
/// otherwise it becomes very confusing.
pub fn write_to_dot(out: &mut impl Write, egraph: &EGraph) -> std::io::Result<()> {
fn write_to_dot(out: &mut impl Write, egraph: &EGraph) -> std::io::Result<()> {
writeln!(out, "digraph egraph {{")?;

// set compound=true to enable edges to clusters
Expand All @@ -46,8 +52,15 @@ pub fn write_to_dot(out: &mut impl Write, egraph: &EGraph) -> std::io::Result<()
if !widths.contains_key(&class.id) {
writeln!(out, " subgraph cluster_{} {{", class.id)?;
writeln!(out, " style=dotted")?;
writeln!(out, " label=\"{}\"", class.id)?;
for (i, node) in class.iter().enumerate() {
writeln!(out, " {}.{}[label = \"{}\"]", class.id, i, node)?;
let label = if is_bin_op(node) {
let width = widths[&node.children()[0]];
format!("{node} ({width})")
} else {
format!("{node}")
};
writeln!(out, " {}.{}[label = \"{}\"]", class.id, i, label)?;
}
writeln!(out, " }}")?;
}
Expand Down
70 changes: 57 additions & 13 deletions patronus-egraphs/src/rewrites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ pub fn create_rewrites() -> Vec<ArithRewrite> {
vec![
// a + b => b + a
arith_rewrite!("commute-add"; "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(+ ?wo ?wb ?sb ?b ?wa ?sa ?a)"),
// a * b => b * a
arith_rewrite!("commute-mul"; "(* ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(* ?wo ?wb ?sb ?b ?wa ?sa ?a)"),
// (a << b) << x => a << (b + c)
arith_rewrite!("merge-left-shift";
// we require that b, c and (b + c) are all unsigned
Expand Down Expand Up @@ -70,11 +72,34 @@ pub fn create_rewrites() -> Vec<ArithRewrite> {
"(<< ?wo ?wab unsign (* ?wab ?wa unsign ?a ?wb unsign ?b) ?wc unsign ?c)" =>
// we set the width of (a << c) to the result width to satisfy wac >= wo
"(* ?wo ?wo unsign (<< ?wo ?wa unsign ?a ?wc unsign ?c) ?wb unsign ?b)";
// wab >= wo && all_signs_the_same
if["?wab", "?wo"], |w| w[0] >= w[1]),
// we want to determine that there is no overflow
// lhs: wab >= wa + wb && wo >= wab + max_shift(wc)
// rhs: wac >= wa + max_shift(c) && wo >= wac + wb
if["?wab", "?wa", "?wb", "?wo", "?wc"], |w| mul_no_ov(w[0], w[1], w[2]) && lsh_no_ov(w[3], w[0], w[4])),
]
}

/// Determines if there is no overflow possible for this addition.
fn add_no_ov(wo: WidthInt, wa: WidthInt, wb: WidthInt) -> bool {
wo >= max(wa, wb) + 1
}

/// Determines if there is no overflow possible for this multiplication.
fn mul_no_ov(wo: WidthInt, wa: WidthInt, wb: WidthInt) -> bool {
wo >= wa + wb
}

/// Determines if there is no overflow possible for this left shift.
fn lsh_no_ov(wo: WidthInt, wa: WidthInt, wb: WidthInt) -> bool {
if wb >= WidthInt::BITS {
// avoid overflow
false
} else {
let max_shift: WidthInt = (1 << wb) - 1;
wo >= wa + max_shift
}
}

pub struct ArithRewrite {
name: String,
/// most general lhs pattern
Expand Down Expand Up @@ -272,7 +297,7 @@ pub fn create_egg_rewrites() -> Vec<Rewrite<Arith, ()>> {
mod tests {
use super::*;
use crate::arithmetic::verification_fig_1;
use crate::to_arith;
use crate::{to_arith, to_dot, to_pdf};
use patronus::expr::{Context, SerializableIrNode};
#[test]
fn test_data_path_verification_fig_1_rewrites() {
Expand All @@ -289,12 +314,13 @@ mod tests {
let runner = egg::Runner::default()
.with_expr(&spec_e)
.with_expr(&impl_e)
.with_iter_limit(10)
.run(&egg_rewrites);

runner.print_report();

let spec_class = runner.roots[0];
let impl_class = runner.roots[1];
let spec_class = runner.egraph.find(runner.roots[0]);
let impl_class = runner.egraph.find(runner.roots[1]);
println!("{spec_class} {impl_class}");

let left_shift_mult = create_rewrites()
Expand All @@ -307,8 +333,25 @@ mod tests {
println!("{m:?}");
}

// to_pdf("graph.pdf", &runner.egraph).unwrap();
// runner.egraph.dot().to_pdf("full_graph.pdf").unwrap();
to_pdf("graph.pdf", &runner.egraph).unwrap();
to_dot("graph.dot", &runner.egraph).unwrap();
runner.egraph.dot().to_pdf("full_graph.pdf").unwrap();
runner.egraph.dot().to_dot("full_graph.dot").unwrap();

// investigating eclass 26 and 13 which should ideally be the same
println!("{}", inspect_e_class(&runner.egraph, 26));
println!("{}", inspect_e_class(&runner.egraph, 13));
println!("{}", inspect_e_class(&runner.egraph, 25));
println!("{}", inspect_e_class(&runner.egraph, 12));
}

fn inspect_e_class(egraph: &EGraph, id: usize) -> String {
let nodes = egraph[id.into()]
.nodes
.iter()
.map(|n| format!("{n} {:?}", n.children()))
.collect::<Vec<_>>();
format!("Class {id}: {}", nodes.join(", "))
}

#[test]
Expand All @@ -317,22 +360,23 @@ mod tests {
let a = ctx.bv_symbol("A", 16);
let b = ctx.bv_symbol("B", 16);
let in_smt_expr = ctx.add(a, b);
let in_smt_expr_2 = ctx.add(b, a);
assert_eq!(in_smt_expr.serialize_to_str(&ctx), "add(A, B)");

// run egraph operations
let egg_expr_in = to_arith(&ctx, in_smt_expr);
let egg_expr_in_2 = to_arith(&ctx, in_smt_expr_2);
let egg_rewrites = create_egg_rewrites();
let runner = egg::Runner::default()
.with_expr(&egg_expr_in)
.with_expr(&egg_expr_in_2)
.run(&egg_rewrites);

// check how many different nodes are representing the root node now
let root = runner.roots[0];
let root_nodes = &runner.egraph[root].nodes;
let final_eclass_1 = runner.egraph.find(runner.roots[0]);
let final_eclass_2 = runner.egraph.find(runner.roots[1]);
assert_eq!(
root_nodes.len(),
2,
"there should be two nodes if the rule has been applied"
final_eclass_1, final_eclass_2,
"inputs should be equivalent with commute-add"
);
}
}
11 changes: 9 additions & 2 deletions tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,16 @@ fn check_conditions(
_ => {} // ignore
}
}
let num_total = samples.num_total() as f64;
println!("The current implementation of the condition has:");
println!("False positives (BAD): {false_positive: >10}");
println!("False negatives (OK): {false_negative: >10}");
println!(
"False positives (BAD): {false_positive: >10} ({:.1}%)",
(false_positive * 100) as f64 / num_total
);
println!(
"False negatives (OK): {false_negative: >10} ({:.1}%)",
(false_negative * 100) as f64 / num_total
);
if !false_pos_examples.is_empty() {
println!("Some example assignments that are incorrectly classified as OK by our current condition:");
show_assignments(rule, info, &false_pos_examples, 10, CheckSatResponse::Sat);
Expand Down
6 changes: 5 additions & 1 deletion tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,11 @@ impl Samples {
}

pub fn num_unequivalent(&self) -> usize {
self.is_equivalent.len() - self.num_equivalent()
self.num_total() - self.num_equivalent()
}

pub fn num_total(&self) -> usize {
self.is_equivalent.len()
}

pub fn iter(&self) -> impl Iterator<Item = (Assignment, bool)> + '_ {
Expand Down

0 comments on commit ed6c034

Please sign in to comment.