Skip to content

Commit

Permalink
define features
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 9, 2024
1 parent bda9bfa commit 99a8793
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 8 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ baa = "0.14.6"
egg = "0.9.5"
easy-smt = "0.2.3"
regex = "1.11.1"
boolean_expression = "0.4.4"
clap = { version = "4.x", features = ["derive"] }
patronus = {path = "patronus"}

Expand Down
2 changes: 1 addition & 1 deletion patronus/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fuzzy-matcher = "0.3.7"
lazy_static = "1.4.0"
easy-smt.workspace = true
smallvec = { version = "1.x", features = ["union"] }
boolean_expression = "0.4.4"
boolean_expression.workspace = true
regex.workspace = true
baa.workspace = true
rustc-hash.workspace = true
Expand Down
1 change: 1 addition & 0 deletions tools/egraphs-cond-synth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ egg.workspace = true
clap.workspace = true
rustc-hash.workspace = true
easy-smt.workspace = true
boolean_expression.workspace = true
indicatif = "0.17.9"
rayon = "1.10.0"
thread_local = "1.1.8"
6 changes: 5 additions & 1 deletion tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// author: Mohanna Shahrad <[email protected]>

mod samples;
mod summarize;

use clap::Parser;
use egg::*;
Expand All @@ -23,6 +24,8 @@ struct Args {
print_samples: bool,
#[arg(long)]
dump_smt: bool,
#[arg(long)]
bdd_formula: bool,
#[arg(value_name = "RULE", index = 1)]
rule: String,
}
Expand Down Expand Up @@ -58,7 +61,8 @@ fn main() {
}
};

let samples = samples::generate_samples(&args.rule, rule, args.max_width, true, args.dump_smt);
let (samples, rule_info) =
samples::generate_samples(&args.rule, rule, args.max_width, true, args.dump_smt);
let delta_t = std::time::Instant::now() - start;

println!("Found {} equivalent rewrites.", samples.num_equivalent());
Expand Down
22 changes: 16 additions & 6 deletions tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub fn generate_samples(
max_width: WidthInt,
show_progress: bool,
dump_smt: bool,
) -> Samples {
) -> (Samples, RuleInfo) {
let (lhs, rhs) = extract_patterns(rule).expect("failed to extract patterns from rewrite rule");
println!("{}: {} => {}", rule_name, lhs, rhs);

Expand Down Expand Up @@ -80,9 +80,10 @@ pub fn generate_samples(
.collect::<Vec<_>>();

// merge results from different threads
samples
let samples = samples
.into_par_iter()
.reduce(|| Samples::new(&rule_info), Samples::merge)
.reduce(|| Samples::new(&rule_info), Samples::merge);
(samples, rule_info)
}

fn start_solver(dump_smt: bool) -> easy_smt::Context {
Expand Down Expand Up @@ -209,7 +210,7 @@ fn extract_patterns<L: Language>(
}

#[derive(Debug, Clone, Eq, PartialEq)]
struct RuleInfo {
pub struct RuleInfo {
/// width parameters
widths: Vec<Var>,
/// sign parameters
Expand All @@ -220,21 +221,30 @@ struct RuleInfo {
}

#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
enum VarOrConst {
pub enum VarOrConst {
C(WidthInt),
V(Var),
}

/// a unique symbol in a rule, needs to be replaced with an SMT bit-vector symbol for equivalence checks
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
struct RuleSymbol {
pub struct RuleSymbol {
var: Var,
width: VarOrConst,
sign: VarOrConst,
}

pub type Assignment = Vec<(Var, WidthInt)>;

impl RuleInfo {
pub fn signs(&self) -> impl Iterator<Item = Var> + '_ {
self.signs.iter().cloned()
}
pub fn widths(&self) -> impl Iterator<Item = Var> + '_ {
self.widths.iter().cloned()
}
}

impl RuleInfo {
fn merge(&self, other: &Self) -> Self {
let widths = union_vecs(&self.widths, &other.widths);
Expand Down
106 changes: 106 additions & 0 deletions tools/egraphs-cond-synth/src/summarize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright 2024 Cornell University
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>

use crate::samples::{RuleInfo, Samples};
use egg::Var;
use patronus::expr::WidthInt;
use rustc_hash::FxHashMap;

/// generate a simplified re-writ condition from samples, using BDDs
pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) {
for (assignment, is_equal) in samples.iter() {
let v = FxHashMap::from_iter(assignment);
}
}

const FEATURES: &[Feature] = &[
Feature {
name: "is_unsigned", // (13)
len: |r| Some(r.signs().count()),
eval: |r, v, o| {
for sign in r.signs() {
// s_i == unsign
o.push(v[&sign] == 0);
}
},
},
Feature {
name: "is_width_equal", // (14)
len: |r| {
if r.widths().count() <= 0 {
None
} else {
Some(r.widths().count() * (r.widths().count() - 1))
}
},
eval: |r, v, o| {
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
// w_i == w_j
o.push(v[&w_i] == v[&w_j]);
}
}
}
},
},
Feature {
name: "is_width_smaller", // (15) + (16)
len: |r| {
if r.widths().count() <= 0 {
None
} else {
Some(r.widths().count() * (r.widths().count() - 1) * 3)
}
},
eval: |r, v, o| {
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
let (w_i, w_j) = (v[&w_i], v[&w_j]);
// w_i < w_j
o.push(w_i < w_j);
// w_i + 1 < w_j
o.push(w_i + 1 < w_j);
// w_i - 1 < w_j
o.push(w_i - 1 < w_j);
}
}
}
},
},
Feature {
name: "is_width_sum_smaller", // (17) + (18)
len: |r| {
if r.widths().count() <= 1 {
None
} else {
Some(r.widths().count() * (r.widths().count() - 1) * (r.widths().count() - 2) * 2)
}
},
eval: |r, v, o| {
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
for w_k in r.widths() {
if w_k != w_i && w_k != w_j {
let (w_i, w_j, w_k) = (v[&w_i], v[&w_j], v[&w_k]);
// w_i + w_j < w_k
o.push(w_i + w_j < w_k);
// w_i + 2**w_j < w_k
o.push(w_i as u64 + 2u64.pow(w_j) < w_k as u64);
}
}
}
}
}
},
},
];

struct Feature {
name: &'static str,
len: fn(rule: &RuleInfo) -> Option<usize>,
eval: fn(rule: &RuleInfo, v: &FxHashMap<Var, WidthInt>, out: &mut Vec<bool>),
}

0 comments on commit 99a8793

Please sign in to comment.