diff --git a/patronus/src/expr.rs b/patronus/src/expr.rs index a2b60dc..f11ddf4 100644 --- a/patronus/src/expr.rs +++ b/patronus/src/expr.rs @@ -25,7 +25,7 @@ pub use parse::parse_expr; pub use serialize::SerializableIrNode; pub(crate) use serialize::{serialize_expr, serialize_expr_ref}; pub(crate) use simplify::simplify; -pub use simplify::simplify_single_expression; +pub use simplify::{simplify_single_expression, Simplifier}; pub use transform::simple_transform_expr; pub(crate) use transform::{do_transform_expr, ExprTransformMode}; pub use types::{TypeCheck, TypeCheckError}; diff --git a/tools/simplify/src/main.rs b/tools/simplify/src/main.rs index 92ce42d..35c0b24 100644 --- a/tools/simplify/src/main.rs +++ b/tools/simplify/src/main.rs @@ -5,7 +5,6 @@ use clap::Parser; use patronus::expr::*; use patronus::smt::{parse_command, serialize_cmd, SmtCommand}; -use patronus::*; use rustc_hash::FxHashMap; use std::io::{BufRead, BufReader, BufWriter}; use std::path::PathBuf; @@ -16,6 +15,8 @@ use std::path::PathBuf; #[command(version)] #[command(about = "Parses a SMT file, simplifies it and writes the simplified version to an output file.", long_about = None)] struct Args { + #[arg(long)] + do_not_simplify: bool, #[arg(value_name = "INPUT", index = 1)] input_file: PathBuf, #[arg(value_name = "OUTPUT", index = 2)] @@ -37,13 +38,34 @@ fn main() { // read and write commands let mut ctx = Context::default(); let mut st = FxHashMap::default(); + let mut simplifier = Simplifier::new(SparseExprMap::default()); while let Some(cmd) = read_cmd(&mut in_reader, &mut ctx, &mut st).expect("failed to read command") { + let cmd = if args.do_not_simplify { + cmd + } else { + simplify(&mut ctx, &mut simplifier, cmd) + }; serialize_cmd(&mut out, Some(&ctx), &cmd).expect("failed to write command"); } } +fn simplify>>( + ctx: &mut Context, + s: &mut Simplifier, + cmd: SmtCommand, +) -> SmtCommand { + match cmd { + SmtCommand::Assert(e) => SmtCommand::Assert(s.simplify(ctx, e)), + SmtCommand::DefineConst(sym, value) => SmtCommand::DefineConst(sym, s.simplify(ctx, value)), + SmtCommand::CheckSatAssuming(e) => { + SmtCommand::CheckSatAssuming(e.into_iter().map(|e| s.simplify(ctx, e)).collect()) + } + other => other, + } +} + type SymbolTable = FxHashMap; fn read_cmd(