From 535eeef623ee802fa43b0fe29543ee7999f9d1d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kevin=20L=C3=A4ufer?= Date: Tue, 10 Dec 2024 12:48:48 -0500 Subject: [PATCH] cond synth: add features module --- tools/egraphs-cond-synth/src/features.rs | 174 ++++++++++++++++++++++ tools/egraphs-cond-synth/src/main.rs | 1 + tools/egraphs-cond-synth/src/summarize.rs | 174 +--------------------- 3 files changed, 178 insertions(+), 171 deletions(-) create mode 100644 tools/egraphs-cond-synth/src/features.rs diff --git a/tools/egraphs-cond-synth/src/features.rs b/tools/egraphs-cond-synth/src/features.rs new file mode 100644 index 0000000..add5c41 --- /dev/null +++ b/tools/egraphs-cond-synth/src/features.rs @@ -0,0 +1,174 @@ +// Copyright 2024 Cornell University +// released under BSD 3-Clause License +// author: Kevin Laeufer + +use crate::samples::{RuleInfo, Samples}; +use bitvec::prelude as bv; +use egg::Var; +use patronus::expr::WidthInt; +use rustc_hash::FxHashMap; + +/// Applies the features from the ROVER paper to all assignments and returns the result. +pub fn apply_features(rule: &RuleInfo, samples: &Samples) -> FeatureResult { + let labels = get_labels(rule); + let mut results = bv::BitVec::new(); + + for (assignment, is_equal) in samples.iter() { + let v = FxHashMap::from_iter(assignment); + results.push(is_equal); + for feature in FEATURES.iter() { + (feature.eval)(rule, &v, &mut results); + } + } + + FeatureResult { labels, results } +} + +pub struct FeatureResult { + labels: Vec, + results: bv::BitVec, +} + +impl FeatureResult { + pub fn num_features(&self) -> usize { + self.labels.len() + } + pub fn labels(&self) -> &[String] { + &self.labels + } + pub fn iter(&self) -> impl Iterator + '_ { + let cs = self.num_features() + 1; + self.results.chunks(cs).map(|c| { + let is_equivalent = c[0]; + let features = &c[1..]; + (features, is_equivalent) + }) + } +} + +fn get_labels(rule: &RuleInfo) -> Vec { + FEATURES + .iter() + .map(|f| (f.labels)(rule)) + .reduce(|mut a, mut b| { + a.append(&mut b); + a + }) + .unwrap_or_default() +} + +const FEATURES: &[Feature] = &[ + Feature { + name: "is_unsigned", // (13) + labels: |r| { + let mut o = vec![]; + for sign in r.signs() { + o.push(format!("!{sign}")); + } + o + }, + eval: |r, v, o| { + for sign in r.signs() { + // s_i == unsign + o.push(v[&sign] == 0); + } + }, + }, + Feature { + name: "is_width_equal", // (14) + labels: |r| { + let mut o = vec![]; + for w_i in r.widths() { + for w_j in r.widths() { + if w_i != w_j { + o.push(format!("{w_i} == {w_j}")); + } + } + } + o + }, + eval: |r, v, o| { + for w_i in r.widths() { + for w_j in r.widths() { + if w_i != w_j { + // w_i == w_j + o.push(v[&w_i] == v[&w_j]); + } + } + } + }, + }, + Feature { + name: "is_width_smaller", // (15) + (16) + labels: |r| { + let mut o = vec![]; + for w_i in r.widths() { + for w_j in r.widths() { + if w_i != w_j { + o.push(format!("{w_i} < {w_j}")); + o.push(format!("{w_i} + 1 < {w_j}")); + o.push(format!("{w_i} - 1 < {w_j}")); + } + } + } + o + }, + eval: |r, v, o| { + for w_i in r.widths() { + for w_j in r.widths() { + if w_i != w_j { + let (w_i, w_j) = (v[&w_i], v[&w_j]); + // w_i < w_j + o.push(w_i < w_j); + // w_i + 1 < w_j + o.push(w_i + 1 < w_j); + // w_i - 1 < w_j + o.push(w_i - 1 < w_j); + } + } + } + }, + }, + Feature { + name: "is_width_sum_smaller", // (17) + (18) + labels: |r| { + let mut o = vec![]; + for w_i in r.widths() { + for w_j in r.widths() { + if w_i != w_j { + for w_k in r.widths() { + if w_k != w_i && w_k != w_j { + o.push(format!("{w_i} + {w_j} < {w_k}")); + o.push(format!("{w_i} as u64 + 2u64.pow({w_j}) < {w_k} as u64")); + } + } + } + } + } + o + }, + eval: |r, v, o| { + for w_i in r.widths() { + for w_j in r.widths() { + if w_i != w_j { + for w_k in r.widths() { + if w_k != w_i && w_k != w_j { + let (w_i, w_j, w_k) = (v[&w_i], v[&w_j], v[&w_k]); + // w_i + w_j < w_k + o.push(w_i + w_j < w_k); + // w_i + 2**w_j < w_k + o.push(w_i as u64 + 2u64.pow(w_j) < w_k as u64); + } + } + } + } + } + }, + }, +]; + +struct Feature { + name: &'static str, + labels: fn(rule: &RuleInfo) -> Vec, + eval: fn(rule: &RuleInfo, v: &FxHashMap, out: &mut bv::BitVec), +} diff --git a/tools/egraphs-cond-synth/src/main.rs b/tools/egraphs-cond-synth/src/main.rs index 14b2e1f..0d9f89a 100644 --- a/tools/egraphs-cond-synth/src/main.rs +++ b/tools/egraphs-cond-synth/src/main.rs @@ -5,6 +5,7 @@ // author: Amelia Dobis // author: Mohanna Shahrad +mod features; mod samples; mod summarize; diff --git a/tools/egraphs-cond-synth/src/summarize.rs b/tools/egraphs-cond-synth/src/summarize.rs index d5fadcd..ecd4cf8 100644 --- a/tools/egraphs-cond-synth/src/summarize.rs +++ b/tools/egraphs-cond-synth/src/summarize.rs @@ -2,16 +2,12 @@ // released under BSD 3-Clause License // author: Kevin Laeufer +use crate::features::apply_features; use crate::samples::{RuleInfo, Samples}; -use bitvec::macros::internal::funty::Fundamental; -use bitvec::prelude as bv; -use egg::Var; -use patronus::expr::WidthInt; -use rustc_hash::FxHashMap; /// generate a simplified re-write condition from samples, using BDDs pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String { - let results = check_features(rule, samples); + let results = apply_features(rule, samples); // generate BDD terminals let mut bdd = boolean_expression::BDD::::new(); @@ -30,7 +26,7 @@ pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String { .into_iter() .enumerate() .map(|(terminal, is_true)| { - if is_true.as_bool() { + if *is_true { vars[terminal] } else { bdd.not(vars[terminal]) @@ -46,167 +42,3 @@ pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String { // extract simplified expression format!("{:?}", bdd.to_expr(cond)) } - -pub fn check_features(rule: &RuleInfo, samples: &Samples) -> FeatureResult { - let labels = get_labels(rule); - let mut results = bv::BitVec::new(); - - for (assignment, is_equal) in samples.iter() { - let v = FxHashMap::from_iter(assignment); - results.push(is_equal); - for feature in FEATURES.iter() { - (feature.eval)(rule, &v, &mut results); - } - } - - FeatureResult { labels, results } -} - -pub struct FeatureResult { - labels: Vec, - results: bv::BitVec, -} - -impl FeatureResult { - pub fn num_features(&self) -> usize { - self.labels.len() - } - pub fn labels(&self) -> &[String] { - &self.labels - } - pub fn iter(&self) -> impl Iterator + '_ { - let cs = self.num_features() + 1; - self.results.chunks(cs).map(|c| { - let is_equivalent = c[0]; - let features = &c[1..]; - (features, is_equivalent) - }) - } -} - -fn get_labels(rule: &RuleInfo) -> Vec { - FEATURES - .iter() - .map(|f| (f.labels)(rule)) - .reduce(|mut a, mut b| { - a.append(&mut b); - a - }) - .unwrap_or_default() -} - -const FEATURES: &[Feature] = &[ - Feature { - name: "is_unsigned", // (13) - labels: |r| { - let mut o = vec![]; - for sign in r.signs() { - o.push(format!("!{sign}")); - } - o - }, - eval: |r, v, o| { - for sign in r.signs() { - // s_i == unsign - o.push(v[&sign] == 0); - } - }, - }, - Feature { - name: "is_width_equal", // (14) - labels: |r| { - let mut o = vec![]; - for w_i in r.widths() { - for w_j in r.widths() { - if w_i != w_j { - o.push(format!("{w_i} == {w_j}")); - } - } - } - o - }, - eval: |r, v, o| { - for w_i in r.widths() { - for w_j in r.widths() { - if w_i != w_j { - // w_i == w_j - o.push(v[&w_i] == v[&w_j]); - } - } - } - }, - }, - Feature { - name: "is_width_smaller", // (15) + (16) - labels: |r| { - let mut o = vec![]; - for w_i in r.widths() { - for w_j in r.widths() { - if w_i != w_j { - o.push(format!("{w_i} < {w_j}")); - o.push(format!("{w_i} + 1 < {w_j}")); - o.push(format!("{w_i} - 1 < {w_j}")); - } - } - } - o - }, - eval: |r, v, o| { - for w_i in r.widths() { - for w_j in r.widths() { - if w_i != w_j { - let (w_i, w_j) = (v[&w_i], v[&w_j]); - // w_i < w_j - o.push(w_i < w_j); - // w_i + 1 < w_j - o.push(w_i + 1 < w_j); - // w_i - 1 < w_j - o.push(w_i - 1 < w_j); - } - } - } - }, - }, - Feature { - name: "is_width_sum_smaller", // (17) + (18) - labels: |r| { - let mut o = vec![]; - for w_i in r.widths() { - for w_j in r.widths() { - if w_i != w_j { - for w_k in r.widths() { - if w_k != w_i && w_k != w_j { - o.push(format!("{w_i} + {w_j} < {w_k}")); - o.push(format!("{w_i} as u64 + 2u64.pow({w_j}) < {w_k} as u64")); - } - } - } - } - } - o - }, - eval: |r, v, o| { - for w_i in r.widths() { - for w_j in r.widths() { - if w_i != w_j { - for w_k in r.widths() { - if w_k != w_i && w_k != w_j { - let (w_i, w_j, w_k) = (v[&w_i], v[&w_j], v[&w_k]); - // w_i + w_j < w_k - o.push(w_i + w_j < w_k); - // w_i + 2**w_j < w_k - o.push(w_i as u64 + 2u64.pow(w_j) < w_k as u64); - } - } - } - } - } - }, - }, -]; - -struct Feature { - name: &'static str, - labels: fn(rule: &RuleInfo) -> Vec, - eval: fn(rule: &RuleInfo, v: &FxHashMap, out: &mut bv::BitVec), -}