Skip to content

Commit

Permalink
cond synth: better insight into false positives
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 11, 2024
1 parent 5cb01eb commit b786140
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 14 deletions.
2 changes: 1 addition & 1 deletion tools/egraphs-cond-synth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ clap.workspace = true
rustc-hash.workspace = true
easy-smt.workspace = true
boolean_expression.workspace = true
baa.workspace = true
indicatif = "0.17.9"
rayon = "1.10.0"
thread_local = "1.1.8"
bitvec = "1.0.1"
serde_json = "1.0.133"
serde = { version = "1.0.215", features = ["derive"] }
57 changes: 52 additions & 5 deletions tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@ mod summarize;

use crate::features::{apply_features, FeatureResult};
use crate::rewrites::{create_rewrites, ArithRewrite};
use crate::samples::{get_rule_info, get_var_name, to_smt, RuleInfo, Samples};
use crate::samples::{
check_eq, find_symbols_in_expr, get_rule_info, get_var_name, start_solver, to_smt, RuleInfo,
Samples,
};
use crate::summarize::bdd_summarize;
use baa::BitVecOps;
use clap::Parser;
use patronus::expr::*;
use patronus::mc::get_smt_value;
use std::io::Write;
use std::path::{Path, PathBuf};

Expand Down Expand Up @@ -250,16 +255,58 @@ fn check_conditions(rule: &ArithRewrite, samples: &Samples, info: &RuleInfo) {
if !false_pos_examples.is_empty() {
println!("Some example assignments that are incorrectly classified as OK by our current condition:");
let mut ctx = Context::default();
let mut smt_ctx = start_solver(false);
for a in false_pos_examples {
println!("{a:?}");

// generate smt expressions
let (lhs, rhs) = rule.patterns();
let lhs = to_smt(&mut ctx, lhs, info, &a);
let rhs = to_smt(&mut ctx, rhs, info, &a);
let lhs_expr = to_smt(&mut ctx, lhs, info, &a);
let rhs_expr = to_smt(&mut ctx, rhs, info, &a);

// run SMT solver to get a counter example
smt_ctx.push_many(1).unwrap();
let resp = check_eq(&mut ctx, &mut smt_ctx, lhs_expr, rhs_expr);
assert_eq!(resp, easy_smt::Response::Sat);

// get assignments to variables
let is_eq = ctx.equal(lhs_expr, rhs_expr);
let vars = find_symbols_in_expr(&ctx, is_eq);
let mut values: Vec<String> = vars
.into_iter()
.map(|v| {
let name = ctx.get_symbol_name(v).unwrap();
let value = get_value(&ctx, &mut smt_ctx, v);
format!("{name}={value}")
})
.collect();
values.push(format!(
"lhs_result={}",
get_value(&ctx, &mut smt_ctx, lhs_expr)
));
values.push(format!(
"rhs_result={}",
get_value(&ctx, &mut smt_ctx, rhs_expr)
));
smt_ctx.pop_many(1).unwrap();

println!(
" {} =/= {}",
lhs.serialize_to_str(&ctx),
rhs.serialize_to_str(&ctx)
lhs_expr.serialize_to_str(&ctx),
rhs_expr.serialize_to_str(&ctx)
);
println!(" with: {}", values.join(", "));
}
}
}

fn get_value(ctx: &Context, smt_ctx: &mut easy_smt::Context, expr: ExprRef) -> String {
let tpe = expr.get_type(&ctx);
let v = patronus::smt::convert_expr(smt_ctx, &ctx, expr, &|_| None);
let value = get_smt_value(smt_ctx, v, tpe).unwrap();
if let baa::Value::BitVec(v) = value {
format!("{}", v.to_u64().unwrap())
} else {
unreachable!("no arrays!")
}
}
30 changes: 22 additions & 8 deletions tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,9 @@ pub fn generate_samples(
let assignment = rule_info.get_assignment(max_width, assignment_index);
let lhs_expr = to_smt(&mut ctx, lhs, &lhs_info, &assignment);
let rhs_expr = to_smt(&mut ctx, rhs, &rhs_info, &assignment);
let is_eq = ctx.equal(lhs_expr, rhs_expr);
let is_not_eq = ctx.not(is_eq);
let smt_expr = patronus::smt::convert_expr(&smt_ctx, &ctx, is_not_eq, &|_| None);

smt_ctx.push_many(1).unwrap();
declare_vars(&mut smt_ctx, &ctx, is_not_eq);
smt_ctx.assert(smt_expr).unwrap();
let resp = smt_ctx.check().unwrap();
let resp = check_eq(&mut ctx, &mut smt_ctx, lhs_expr, rhs_expr);
smt_ctx.pop_many(1).unwrap();

match resp {
Expand All @@ -89,7 +84,21 @@ pub fn generate_samples(
.reduce(|| Samples::new(&rule_info), Samples::merge)
}

fn start_solver(dump_smt: bool) -> easy_smt::Context {
pub fn check_eq(
ctx: &mut Context,
smt_ctx: &mut easy_smt::Context,
lhs_expr: ExprRef,
rhs_expr: ExprRef,
) -> easy_smt::Response {
let is_eq = ctx.equal(lhs_expr, rhs_expr);
let is_not_eq = ctx.not(is_eq);
let smt_expr = patronus::smt::convert_expr(&smt_ctx, &ctx, is_not_eq, &|_| None);
declare_vars(smt_ctx, &ctx, is_not_eq);
smt_ctx.assert(smt_expr).unwrap();
smt_ctx.check().unwrap()
}

pub fn start_solver(dump_smt: bool) -> easy_smt::Context {
let solver: patronus::mc::SmtSolverCmd = patronus::mc::BITWUZLA_CMD;
let dump_file = if dump_smt {
Some(std::fs::File::create("replay.smt").unwrap())
Expand Down Expand Up @@ -225,7 +234,7 @@ impl<'a> Iterator for SamplesIter<'a> {
}
}

fn declare_vars(smt_ctx: &mut easy_smt::Context, ctx: &Context, expr: ExprRef) {
pub fn find_symbols_in_expr(ctx: &Context, expr: ExprRef) -> Vec<ExprRef> {
// find all variables in the expression
let mut vars = FxHashSet::default();
patronus::expr::traversal::top_down(ctx, expr, |ctx, e| {
Expand All @@ -238,6 +247,11 @@ fn declare_vars(smt_ctx: &mut easy_smt::Context, ctx: &Context, expr: ExprRef) {
// declare them
let mut vars = Vec::from_iter(vars);
vars.sort();
vars
}

fn declare_vars(smt_ctx: &mut easy_smt::Context, ctx: &Context, expr: ExprRef) {
let vars = find_symbols_in_expr(ctx, expr);
for v in vars.into_iter() {
let expr = &ctx[v];
let tpe = patronus::smt::convert_tpe(smt_ctx, expr.get_type(ctx));
Expand Down

0 comments on commit b786140

Please sign in to comment.