Skip to content

Commit

Permalink
cond synth: generate conditional rewrite rules
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 11, 2024
1 parent 2a1238b commit 4417398
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 46 deletions.
1 change: 1 addition & 0 deletions tools/egraphs-cond-synth/src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ const FEATURES: &[Feature] = &[
];

struct Feature {
#[allow(dead_code)]
name: &'static str,
labels: fn(rule: &RuleInfo) -> Vec<String>,
eval: fn(rule: &RuleInfo, v: &FxHashMap<Var, WidthInt>, out: &mut bv::BitVec),
Expand Down
4 changes: 2 additions & 2 deletions tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ fn main() {

// find rule and extract both sides
let rewrites = create_rewrites();
let rule = match rewrites.iter().find(|r| r.name.as_str() == args.rule) {
let rule = match rewrites.iter().find(|r| r.name() == args.rule) {
Some(r) => r,
None => {
let available = rewrites.iter().map(|r| r.name.as_str()).collect::<Vec<_>>();
let available = rewrites.iter().map(|r| r.name()).collect::<Vec<_>>();
panic!(
"Failed to find rewrite rule `{}`!\nAvailable rules are: {:?}",
args.rule, available
Expand Down
135 changes: 96 additions & 39 deletions tools/egraphs-cond-synth/src/rewrites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,76 +8,133 @@ use patronus_egraphs::*;

/// our version of the egg re-write macro
macro_rules! arith_rewrite {
(
$name:expr;
$lhs:expr => $rhs:expr
) => {{
ArithRewrite::new::<&str>($name, $lhs, $rhs, [], None)
}};
(
$name:expr;
$lhs:expr => $rhs:expr;
if $cond:expr
if $vars:expr, $cond:expr
) => {{
ArithRewrite::new($name, $lhs, $rhs)
ArithRewrite::new($name, $lhs, $rhs, $vars, Some($cond))
}};
}

struct ArithRewrite {
pub struct ArithRewrite {
name: String,
/// most general lhs pattern
lhs: Pattern<Arith>,
rhs: Pattern<Arith>,
/// rhs pattern with all widths derived from the lhs, maybe be the same as rhs
rhs_derived: Pattern<Arith>,
/// variables use by the condition
cond_vars: Vec<String>,
/// condition of the re_write
cond: Option<fn(&[WidthInt]) -> bool>,
}

impl ArithRewrite {
fn new(name: &str, lhs: &str, rhs: &str) -> Self {
fn new<S: AsRef<str>>(
name: &str,
lhs: &str,
rhs_derived: &str,
cond_vars: impl IntoIterator<Item = S>,
cond: Option<fn(&[WidthInt]) -> bool>,
) -> Self {
let cond_vars = cond_vars
.into_iter()
.map(|n| n.as_ref().to_string())
.collect();
Self {
name: name.to_string(),
lhs: lhs.parse::<_>().unwrap(),
rhs: rhs.parse::<_>().unwrap(),
rhs_derived: rhs_derived.parse::<_>().unwrap(),
cond,
cond_vars,
}
}

fn to_egg(&self) -> Rewrite<Arith, ()> {
Rewrite::new(self.name.clone(), self.lhs.clone(), self.rhs.clone()).unwrap()
pub fn name(&self) -> &str {
&self.name
}

pub fn patterns(&self) -> (&PatternAst<Arith>, &PatternAst<Arith>) {
(&self.lhs.ast, &self.rhs_derived.ast)
}

pub fn to_egg(&self) -> Vec<Rewrite<Arith, ()>> {
// TODO: support bi-directional rules
if let Some(cond) = self.cond {
let vars: Vec<Var> = self.cond_vars.iter().map(|n| n.parse().unwrap()).collect();
let condition = move |egraph: &mut EGraph, _, subst: &Subst| {
let values: Vec<WidthInt> = vars
.iter()
.map(|v| get_width_from_e_graph(egraph, subst, *v))
.collect();
cond(values.as_slice())
};
let cond_app = ConditionalApplier {
condition,
applier: self.rhs_derived.clone(),
};
vec![Rewrite::new(self.name.clone(), self.lhs.clone(), cond_app).unwrap()]
} else {
vec![Rewrite::new(
self.name.clone(),
self.lhs.clone(),
self.rhs_derived.clone(),
)
.unwrap()]
}
}
}

pub fn create_rewrites() -> Vec<Rewrite<Arith, ()>> {
let rewrites = vec![
arith_rewrite!("commute-add"; "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(+ ?wo ?wb ?sb ?b ?wa ?sa ?a)"; if true),
pub fn create_rewrites() -> Vec<ArithRewrite> {
vec![
arith_rewrite!("commute-add"; "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(+ ?wo ?wb ?sb ?b ?wa ?sa ?a)"),
arith_rewrite!("merge-left-shift";
// we require that b, c and (b + c) are all unsigned
"(<< ?wo ?wab ?sab (<< ?wab ?wa ?sa ?a ?wb 0 ?b) ?wc 0 ?c)" =>
// note: in this version we set the width of (b + c) on the RHS to be the width of the
// result (w_o)
"(<< ?wo ?wa ?sa ?a ?wo 0 (+ ?wo ?wb 0 ?b ?wc 0 ?c))"; if merge_left_shift_cond("?wo", "?wa", "?sa", "?wb", "?wc")),
];
rewrites.into_iter().map(|r| r.to_egg()).collect()
"(<< ?wo ?wa ?sa ?a ?wo 0 (+ ?wo ?wb 0 ?b ?wc 0 ?c))";
if["?wo", "?wa", "?sa", "?wb", "?wc"], |w| w[1] == w[2] && w[0] >= w[1]),
]
}

type EGraph = egg::EGraph<Arith, ()>;

fn merge_left_shift_cond(
wo: &'static str,
wa: &'static str,
sa: &'static str,
wb: &'static str,
wc: &'static str,
) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
let wo = wo.parse().unwrap();
let wa = wa.parse().unwrap();
let sa = sa.parse().unwrap();
let wb = wb.parse().unwrap();
let wc = wc.parse().unwrap();
move |egraph, _, subst| {
let wo = get_width_from_e_graph(egraph, subst, wo);
let wa = get_width_from_e_graph(egraph, subst, wa);
let sa = get_width_from_e_graph(egraph, subst, sa);
let wb = get_width_from_e_graph(egraph, subst, wb);
let wc = get_width_from_e_graph(egraph, subst, wc);
// actual condition
wa == wb && wo >= wa
}
fn get_width_from_e_graph(egraph: &EGraph, subst: &Subst, v: Var) -> WidthInt {
egraph[subst[v]]
.nodes
.iter()
.flat_map(|n| {
if let Arith::WidthConst(w) = n {
Some(*w)
} else {
None
}
})
.next()
.expect("failed to find constant width")
}

fn get_width_from_e_graph(egraph: &mut EGraph, subst: &Subst, v: Var) -> WidthInt {
match egraph[subst[v]].nodes.as_slice() {
[Arith::WidthConst(w)] => *w,
_ => unreachable!("expected a width!"),
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_to_egg() {
let egg_rewrites = create_rewrites()
.into_iter()
.map(|r| r.to_egg())
.reduce(|mut a, mut b| {
a.append(&mut b);
a
})
.unwrap();
assert_eq!(egg_rewrites.len(), 2);
}
}
11 changes: 6 additions & 5 deletions tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,31 @@
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>

use crate::rewrites::ArithRewrite;
use egg::*;
use indicatif::ProgressBar;
use patronus::expr::traversal::TraversalCmd;
use patronus::expr::{Context, ExprRef, TypeCheck, WidthInt};
use patronus_egraphs::*;
use rayon::prelude::*;
use rustc_hash::{FxHashMap, FxHashSet};
use serde::{Deserialize, Serialize, Serializer};
use serde::{Deserialize, Serialize};

pub fn get_rule_info(rule: &Rewrite<Arith, ()>) -> RuleInfo {
let (lhs, rhs) = extract_patterns(rule).expect("failed to extract patterns from rewrite rule");
pub fn get_rule_info(rule: &ArithRewrite) -> RuleInfo {
let (lhs, rhs) = rule.patterns();
let lhs_info = analyze_pattern(lhs);
let rhs_info = analyze_pattern(rhs);
lhs_info.merge(&rhs_info)
}

pub fn generate_samples(
rule: &Rewrite<Arith, ()>,
rule: &ArithRewrite,
max_width: WidthInt,
show_progress: bool,
dump_smt: bool,
check_cond: bool,
) -> Samples {
let (lhs, rhs) = extract_patterns(rule).expect("failed to extract patterns from rewrite rule");
let (lhs, rhs) = rule.patterns();
let lhs_info = analyze_pattern(lhs);
let rhs_info = analyze_pattern(rhs);
let rule_info = lhs_info.merge(&rhs_info);
Expand Down

0 comments on commit 4417398

Please sign in to comment.